diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 75d6be33..7aff73b8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -734,7 +734,7 @@ class FlashCausalLM(Model):
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
         torch.cuda.synchronize(self.device)
-        peak_memory = torch.cuda.max_memory_allocated(self.device)
+        peak_memory = torch.cuda.max_memory_reserved(self.device)

         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
@@ -742,12 +742,8 @@ class FlashCausalLM(Model):

         total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory

-        # FIXME:
-        # remove wiggle room
-        # when world size > 1, some aggregation ops end up taking more memory than expected
-        safety = 1 - (0.02 * self.world_size)
         num_blocks = (
-            int((total_gpu_memory * safety - peak_memory) // total_cache_size)
+            int((total_gpu_memory - peak_memory) // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
             + batch.blocks
         )
@@ -755,7 +751,6 @@ class FlashCausalLM(Model):
         del CACHE_MANAGER
         del batch
         torch.cuda.empty_cache()
-        torch.cuda.synchronize(self.device)

         CACHE_MANAGER = CacheManager(
             num_blocks,
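
For reference, below is a minimal standalone sketch of the block-budgeting math this diff touches. The function name `estimate_num_blocks` and all default parameter values (block size, head counts, layer count) are illustrative assumptions, not values taken from the actual model config or from this repository; only the torch calls and the arithmetic mirror the hunks above.

```python
import torch


def estimate_num_blocks(
    device: torch.device,
    dtype: torch.dtype = torch.float16,
    block_size: int = 16,      # tokens per KV-cache block (assumed)
    num_kv_heads: int = 32,    # assumed
    head_size: int = 128,      # assumed
    num_layers: int = 32,      # assumed
    warmup_blocks: int = 0,    # blocks already allocated during warmup
) -> int:
    # Peak *reserved* memory includes the CUDA caching allocator's overhead,
    # which is why the diff switches away from max_memory_allocated.
    torch.cuda.synchronize(device)
    peak_memory = torch.cuda.max_memory_reserved(device)

    dtype_size = torch.tensor([], dtype=dtype).element_size()
    # One block holds block_size tokens x kv heads x head dim,
    # stored twice (K and V) for every layer.
    cache_block_size = block_size * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size

    total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
    # Memory outside the profiled peak can hold KV-cache blocks; blocks
    # allocated during warmup are added back since they sit inside the peak.
    return int((total_gpu_memory - peak_memory) // total_cache_size) + warmup_blocks
```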