diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index c8a8f36c..75d6be33 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -733,7 +733,7 @@ class FlashCausalLM(Model):
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
-        torch.cuda.synchronize()
+        torch.cuda.synchronize(self.device)
         peak_memory = torch.cuda.max_memory_allocated(self.device)

         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
@@ -755,6 +755,7 @@ class FlashCausalLM(Model):
         del CACHE_MANAGER
         del batch
         torch.cuda.empty_cache()
+        torch.cuda.synchronize(self.device)

         CACHE_MANAGER = CacheManager(
             num_blocks,
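Context for the change above: a bare torch.cuda.synchronize() only waits on the current CUDA device, so in a sharded multi-GPU setup the peak-memory reading could be taken before the warmup kernels on self.device have finished. The sketch below is a minimal, hypothetical illustration (not the TGI implementation) of how a device-scoped synchronize fits into a vLLM-style calculation of the number of allocatable KV-cache blocks; names such as block_size, num_layers, num_kv_heads, head_dim and memory_fraction are assumptions for the example.

import torch

def estimate_num_blocks(device, dtype, block_size, num_layers, num_kv_heads, head_dim,
                        memory_fraction=0.9):
    # Wait for all work on this specific device so max_memory_allocated
    # reflects the full warmup forward pass (assumption: warmup already ran).
    torch.cuda.synchronize(device)
    peak_memory = torch.cuda.max_memory_allocated(device)

    # Bytes needed for one cache block: keys and values (factor 2) for
    # block_size tokens across all layers and KV heads.
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    block_bytes = 2 * block_size * num_layers * num_kv_heads * head_dim * dtype_size

    # Budget = fraction of total device memory minus the profiled peak.
    total_memory = torch.cuda.get_device_properties(device).total_memory
    free_memory = int(total_memory * memory_fraction) - peak_memory
    return max(free_memory // block_bytes, 0)

The second hunk adds a synchronize after torch.cuda.empty_cache() for the same reason: it makes sure the cache-manager teardown has actually completed on self.device before the new CacheManager allocates num_blocks blocks against the freed memory.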