use less memory

This commit is contained in:
OlivierDehaene 2023-07-19 00:42:15 +02:00
parent 05d2a77e4c
commit 0111869ad0

View File

@@ -743,8 +743,9 @@ class FlashCausalLM(Model):
         total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+        # 0.98 to add some wiggle room
         num_blocks = (
-            int((total_gpu_memory - peak_memory) // total_cache_size)
+            int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
             + batch.blocks
         )