use max_memory_reserved

OlivierDehaene 2023-07-18 18:06:46 +02:00
parent 1686a7c0dc
commit 36a9bddde4


@@ -734,7 +734,7 @@ class FlashCausalLM(Model):
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
         torch.cuda.synchronize(self.device)
-        peak_memory = torch.cuda.max_memory_allocated(self.device)
+        peak_memory = torch.cuda.max_memory_reserved(self.device)
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
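
A minimal standalone sketch (not part of the commit) contrasting the two profiling calls. max_memory_reserved is at least as large as max_memory_allocated, because the caching allocator rounds requests up to its block sizes and keeps freed segments around for reuse, which makes it the more conservative figure to subtract from total GPU memory before sizing the KV cache. The tensor sizes below are arbitrary.

# Sketch only: compare peak allocated vs. peak reserved memory.
import torch

device = torch.device("cuda")
torch.cuda.reset_peak_memory_stats(device)

# Allocate and free a few odd-sized tensors to create allocator overhead.
buffers = [torch.empty(int(3e7) + i, dtype=torch.float16, device=device) for i in range(4)]
del buffers

torch.cuda.synchronize(device)
print("peak allocated:", torch.cuda.max_memory_allocated(device))  # tensor bytes only
print("peak reserved: ", torch.cuda.max_memory_reserved(device))   # allocator segments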
@@ -742,12 +742,8 @@ class FlashCausalLM(Model):
         total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
-        # FIXME:
-        # remove wiggle room
-        # when world size > 1, some aggregation ops end up taking more memory than expected
-        safety = 1 - (0.02 * self.world_size)
         num_blocks = (
-            int((total_gpu_memory * safety - peak_memory) // total_cache_size)
+            int((total_gpu_memory - peak_memory) // total_cache_size)
             # Add batch.blocks as we allocated it above, so it is included in the peak memory.
             + batch.blocks
         )
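
A rough, self-contained sketch of the block-count arithmetic above, using made-up numbers; BLOCK_SIZE, head counts, layer count, and the memory figures are illustrative placeholders, not values taken from any particular model or GPU.

# Sketch only: how num_blocks falls out of the profiled peak memory.
BLOCK_SIZE = 16          # tokens per cache block (placeholder)
num_kv_heads = 32        # placeholder
head_size = 128          # placeholder
num_layers = 32          # placeholder
dtype_size = 2           # fp16 bytes per element

cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # key + value

total_gpu_memory = 80 * 1024**3   # e.g. an 80 GB card
peak_memory = 18 * 1024**3        # peak reserved during the warmup forward pass
batch_blocks = 512                # blocks already allocated for the warmup batch

# Everything not claimed at peak can hold KV-cache blocks; the warmup batch's
# blocks are added back because they were part of the measured peak.
num_blocks = int((total_gpu_memory - peak_memory) // total_cache_size) + batch_blocks
print(num_blocks)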
@@ -755,7 +751,6 @@ class FlashCausalLM(Model):
         del CACHE_MANAGER
         del batch
         torch.cuda.empty_cache()
-        torch.cuda.synchronize(self.device)
         CACHE_MANAGER = CacheManager(
             num_blocks,