diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 4a3ad6fb..517fba68 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -743,8 +743,9 @@ class FlashCausalLM(Model):
         total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
 
+        # 0.98 to add some wiggle room
         num_blocks = (
-            int((total_gpu_memory - peak_memory) // total_cache_size)
+            int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
             + batch.blocks
         )
 
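For a sense of scale, here is a minimal, self-contained sketch of the arithmetic this hunk changes. Every concrete number (an 80 GB device, 8 GB of warmup peak memory, 2 MiB of KV cache per block, 16 warmup blocks) is a hypothetical stand-in for values the server measures at warmup, and batch_blocks stands in for batch.blocks:

# Illustrative sketch only -- every number below is a hypothetical
# stand-in for values the server measures during warmup.
total_gpu_memory = 80 * 1024**3    # e.g. an 80 GB device
peak_memory = 8 * 1024**3          # peak allocation observed during warmup
total_cache_size = 2 * 1024**2     # assumed KV-cache bytes per block
batch_blocks = 16                  # stands in for batch.blocks

# Before the change: hand every free byte to the KV cache.
old = int((total_gpu_memory - peak_memory) // total_cache_size) + batch_blocks

# After: reserve ~2% of the device as headroom against fragmentation
# and allocator overhead.
new = int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size) + batch_blocks

print(old - new)  # 820 blocks, i.e. ~1.6 GB held back in this example

On these assumed figures the 2% factor gives up about 820 KV-cache blocks (roughly 1.6 GB), trading a sliver of cache capacity for a safety margin against out-of-memory errors once the cache is fully utilized.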