From 36a9bddde4fd7ab69640e7fc2368ad4ec667e526 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 18 Jul 2023 18:06:46 +0200 Subject: [PATCH] use max_memory_reserved --- server/text_generation_server/models/flash_causal_lm.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 75d6be33..7aff73b8 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -734,7 +734,7 @@ class FlashCausalLM(Model): # Calculate the number of blocks that can be allocated with the # profiled peak memory. torch.cuda.synchronize(self.device) - peak_memory = torch.cuda.max_memory_allocated(self.device) + peak_memory = torch.cuda.max_memory_reserved(self.device) dtype_size = torch.tensor([], dtype=self.dtype).element_size() cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size @@ -742,12 +742,8 @@ class FlashCausalLM(Model): total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory - # FIXME: - # remove wiggle room - # when world size > 1, some aggregation ops end up taking more memory than expected - safety = 1 - (0.02 * self.world_size) num_blocks = ( - int((total_gpu_memory * safety - peak_memory) // total_cache_size) + int((total_gpu_memory - peak_memory) // total_cache_size) # Add batch.blocks as we allocated it above, so it is included in the peak memory. + batch.blocks ) @@ -755,7 +751,6 @@ class FlashCausalLM(Model): del CACHE_MANAGER del batch torch.cuda.empty_cache() - torch.cuda.synchronize(self.device) CACHE_MANAGER = CacheManager( num_blocks,