add syncs

This commit is contained in:
OlivierDehaene 2023-07-18 17:03:29 +02:00
parent 160a50af77
commit 1686a7c0dc

View File

@@ -733,7 +733,7 @@ class FlashCausalLM(Model):
 # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
 # Calculate the number of blocks that can be allocated with the
 # profiled peak memory.
-torch.cuda.synchronize()
+torch.cuda.synchronize(self.device)
 peak_memory = torch.cuda.max_memory_allocated(self.device)
 dtype_size = torch.tensor([], dtype=self.dtype).element_size()
@@ -755,6 +755,7 @@ class FlashCausalLM(Model):
 del CACHE_MANAGER
 del batch
 torch.cuda.empty_cache()
+torch.cuda.synchronize(self.device)
 CACHE_MANAGER = CacheManager(
     num_blocks,