use max_memory_reserved

OlivierDehaene 2023-07-18 18:06:46 +02:00
parent 1686a7c0dc
commit 36a9bddde4


@@ -734,7 +734,7 @@ class FlashCausalLM(Model):
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
         torch.cuda.synchronize(self.device)
-        peak_memory = torch.cuda.max_memory_allocated(self.device)
+        peak_memory = torch.cuda.max_memory_reserved(self.device)
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
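Note on this hunk: PyTorch's caching allocator reserves CUDA memory in segments that are often larger than the tensors actually live in them, so max_memory_allocated under-reports the footprint the driver sees, and sizing the KV cache against it can overshoot the card. max_memory_reserved is the conservative upper bound. A minimal standalone sketch of the difference (illustrative only, not code from this repo):

    import torch

    device = torch.device("cuda:0")
    torch.cuda.reset_peak_memory_stats(device)

    # Dummy workload standing in for the profiling forward pass.
    x = torch.empty((4096, 4096), dtype=torch.float16, device=device)
    y = x @ x
    torch.cuda.synchronize(device)

    # Peak bytes held in live tensors vs. peak bytes held by the caching
    # allocator (tensors plus reserved-but-unused segments).
    allocated = torch.cuda.max_memory_allocated(device)
    reserved = torch.cuda.max_memory_reserved(device)
    assert reserved >= allocated  # reserved is the safer bound for budgeting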
@@ -742,12 +742,8 @@ class FlashCausalLM(Model):
         total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
-        # FIXME:
-        # remove wiggle room
-        # when world size > 1, some aggregation ops end up taking more memory than expected
-        safety = 1 - (0.02 * self.world_size)
         num_blocks = (
-            int((total_gpu_memory * safety - peak_memory) // total_cache_size)
+            int((total_gpu_memory - peak_memory) // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
             + batch.blocks
         )
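Note on this hunk: because the reserved-based peak already covers allocator overhead, the per-rank 2% wiggle room (safety = 1 - 0.02 * world_size) can be dropped. The remaining arithmetic divides whatever the peak left free by the byte size of one KV-cache block, then adds batch.blocks back, since the profiling batch's blocks were allocated before the peak was read and are therefore counted inside it. A worked sketch with made-up numbers; the model dimensions and the total_cache_size formula (keys and values across all layers) are assumptions for illustration, not values from this commit:

    # Hypothetical 7B-class model dimensions, float16 cache.
    BLOCK_SIZE = 16          # tokens per KV-cache block
    num_kv_heads = 32
    head_size = 128
    num_layers = 32
    dtype_size = 2           # bytes per float16 element

    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    # One block stores both keys and values (factor 2) in every layer.
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # 8 MiB

    total_gpu_memory = 24 * 1024**3  # e.g. a 24 GiB card
    peak_memory = 15 * 1024**3       # profiled max_memory_reserved
    batch_blocks = 8                 # blocks the profiling batch already holds

    num_blocks = int((total_gpu_memory - peak_memory) // total_cache_size) + batch_blocks
    print(num_blocks)  # 1160 blocks -> room for ~18.5k cached tokens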
@@ -755,7 +751,6 @@ class FlashCausalLM(Model):
         del CACHE_MANAGER
         del batch
         torch.cuda.empty_cache()
-        torch.cuda.synchronize(self.device)
         CACHE_MANAGER = CacheManager(
             num_blocks,
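Note on this hunk: the dropped synchronize sat in the teardown between the profiling pass and the final allocation. empty_cache() is a host-side call into the caching allocator, and the CacheManager's subsequent allocations are ordered on the same stream, so a device-wide sync appears redundant here; the synchronize that matters, before reading the peak counter, is kept in the first hunk. A hedged sketch of the teardown-then-rebuild pattern; the factory callable and its arguments are placeholders, since CacheManager's full signature is not shown in this hunk:

    import torch

    def rebuild_cache(num_blocks, make_cache_manager, profiling_batch):
        # Drop everything the profiling pass allocated...
        del profiling_batch
        # ...and hand the freed segments back so the real KV cache can be
        # carved out of a clean pool. empty_cache() runs on the host, and
        # later allocations on the same stream stay correctly ordered, so
        # no device synchronization is inserted in between.
        torch.cuda.empty_cache()
        return make_cache_manager(num_blocks)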