mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
use max_memory_reserved
This commit is contained in:
parent
1686a7c0dc
commit
36a9bddde4
@ -734,7 +734,7 @@ class FlashCausalLM(Model):
|
|||||||
# Calculate the number of blocks that can be allocated with the
|
# Calculate the number of blocks that can be allocated with the
|
||||||
# profiled peak memory.
|
# profiled peak memory.
|
||||||
torch.cuda.synchronize(self.device)
|
torch.cuda.synchronize(self.device)
|
||||||
peak_memory = torch.cuda.max_memory_allocated(self.device)
|
peak_memory = torch.cuda.max_memory_reserved(self.device)
|
||||||
|
|
||||||
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
|
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
|
||||||
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
||||||
@ -742,12 +742,8 @@ class FlashCausalLM(Model):
|
|||||||
|
|
||||||
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
|
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
|
||||||
|
|
||||||
# FIXME:
|
|
||||||
# remove wiggle room
|
|
||||||
# when world size > 1, some aggregation ops end up taking more memory than expected
|
|
||||||
safety = 1 - (0.02 * self.world_size)
|
|
||||||
num_blocks = (
|
num_blocks = (
|
||||||
int((total_gpu_memory * safety - peak_memory) // total_cache_size)
|
int((total_gpu_memory - peak_memory) // total_cache_size)
|
||||||
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
|
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
|
||||||
+ batch.blocks
|
+ batch.blocks
|
||||||
)
|
)
|
||||||
@ -755,7 +751,6 @@ class FlashCausalLM(Model):
|
|||||||
del CACHE_MANAGER
|
del CACHE_MANAGER
|
||||||
del batch
|
del batch
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.synchronize(self.device)
|
|
||||||
|
|
||||||
CACHE_MANAGER = CacheManager(
|
CACHE_MANAGER = CacheManager(
|
||||||
num_blocks,
|
num_blocks,
|
||||||
|
Loading…
Reference in New Issue
Block a user