Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
fix: skip cuda graphs that will oom and improve free memory logging
parent 358ceb67dd
commit 8b4cd2a9fc
@@ -1231,6 +1231,13 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
 
     def warmup(self, batch: FlashCausalLMBatch):
+        inital_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+        log_master(
+            logger.info,
+            f"Free memory before the warmup: {inital_free_memory/1024/1024:.2f} MB",
+        )
+
         # The warmup batch is the biggest batch we could ever receive
         empty_cache()
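For context, `get_free_memory(self.device, MEMORY_FRACTION)` reports how many bytes the server may still use on the device, and `log_master` emits the message from the master rank only. The snippet below is a minimal, hypothetical stand-in for that probe (the real helper lives elsewhere in the repository and may differ); it only illustrates the byte count and the MB conversion used in the new log line.

```python
import torch

# Hypothetical stand-in for the get_free_memory(device, memory_fraction) helper
# called in the diff: how much CUDA memory may still be used once the reserved
# share of the GPU is excluded. Not the repository's actual implementation.
def get_free_memory(device: torch.device, memory_fraction: float) -> int:
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    # Memory the server is never allowed to touch (1 - MEMORY_FRACTION of the GPU).
    reserved = (1.0 - memory_fraction) * total_bytes
    return max(0, int(free_bytes - reserved))

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    free = get_free_memory(device, memory_fraction=0.9)  # 0.9 mimics MEMORY_FRACTION
    # Same MB conversion as the new log line in warmup().
    print(f"Free memory before the warmup: {free / 1024 / 1024:.2f} MB")
```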
@@ -1284,6 +1291,15 @@ class FlashCausalLM(Model):
             self.device,
         )
 
+        # cuda graphs must fit within the new memory limit. In order to avoid an OOM, we
+        # need to exit early if there is not enough memory to fit a particular cuda graph
+        free_memory_post_alloc = get_free_memory(self.device, MEMORY_FRACTION)
+        log_master(
+            logger.info,
+            f"Free memory after allocating the cache: {free_memory_post_alloc/1024/1024:.2f} MB",
+        )
+
         if SYSTEM == "rocm":
             if (
                 os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
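The two new comments spell out the motivation: once the KV cache has been allocated, whatever memory remains is the entire budget for CUDA graph capture, so free memory is re-read at exactly that point. A toy, self-contained illustration of the before/after measurement (not TGI code; the tensor shape is arbitrary):

```python
import torch

# Toy illustration: after a large, cache-like allocation the driver-level free
# memory drops, and what is left is the budget available for CUDA graphs.
if torch.cuda.is_available():
    device = torch.device("cuda:0")

    free_before, _ = torch.cuda.mem_get_info(device)
    # Stand-in for the KV-cache allocation performed just above in the diff.
    kv_cache_stand_in = torch.empty((64, 1024, 1024), dtype=torch.float16, device=device)
    free_after, _ = torch.cuda.mem_get_info(device)

    print(f"Free memory after allocating the cache: {free_after / 1024 / 1024:.2f} MB")
    print(f"Cache claimed roughly {(free_before - free_after) / 1024 / 1024:.2f} MB")
```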
@@ -1341,9 +1357,37 @@ class FlashCausalLM(Model):
                     logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}"
                 )
                 # Warmup cuda graphs
+                last_allocation_amount = 0
+                last_available_memory = free_memory_post_alloc
+                last_bs = 0
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
+                        expected_memory = int(
+                            last_allocation_amount * (bs / last_bs if last_bs else 2)
+                        )
+                        if expected_memory > last_available_memory:
+                            skipped_graphs = [str(k) for k in CUDA_GRAPHS if k <= bs]
+                            log_master(
+                                logger.warning,
+                                f"Avoiding CUDA graph warmup for sizes {', '.join(skipped_graphs)} due to insufficient memory.",
+                            )
+                            break
+
                         self.cuda_graph_warmup(bs, max_s, max_bt)
+                        current_available_memory = get_free_memory(
+                            self.device, MEMORY_FRACTION
+                        )
+                        last_allocation_amount = (
+                            last_available_memory - current_available_memory
+                        )
+                        last_available_memory = current_available_memory
+                        last_bs = bs
+                # report the total memory used
+                total_cuda_graph_memory = free_memory_post_alloc - last_available_memory
+                log_master(
+                    logger.info,
+                    f"Total memory used for CUDA graphs: {total_cuda_graph_memory/1024/1024:.2f} MB",
+                )
             except torch.cuda.OutOfMemoryError:
                 logger.exception("Decode cuda graph warmup failed")
         else:
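The new skip logic extrapolates the cost of the next graph from the last one that was captured, scaled by the batch-size ratio (expected_memory = last_allocation_amount * bs / last_bs); the very first graph is always attempted because both counters start at zero. For example, if capturing the bs=16 graph consumed 300 MB, the bs=32 graph is expected to need about 300 * (32 / 16) = 600 MB, and the loop stops before capturing it if less than that remains, then reports the total used. The sketch below is a self-contained re-statement of that heuristic with a fake memory probe; `probe_free_memory`, `warmup_graph`, and the per-graph cost model are invented for illustration, and only the extrapolation and bookkeeping mirror the diff.

```python
# Self-contained sketch of the skip heuristic, with a fake memory probe instead
# of real CUDA measurements. All quantities are in MB.
CUDA_GRAPHS = [1, 2, 4, 8, 16, 32]


def probe_free_memory(state) -> int:
    return state["free"]


def warmup_graph(state, bs: int) -> None:
    # Pretend each captured graph costs roughly 20 MB per sequence in the batch.
    state["free"] -= bs * 20


def warmup_cuda_graphs(state) -> None:
    free_memory_post_alloc = probe_free_memory(state)
    last_allocation_amount = 0
    last_available_memory = free_memory_post_alloc
    last_bs = 0
    for bs in CUDA_GRAPHS:
        # Estimate the next graph's cost from the previous one, scaled by the
        # batch-size ratio; the first graph is always attempted because
        # last_allocation_amount is still zero.
        expected_memory = int(last_allocation_amount * (bs / last_bs if last_bs else 2))
        if expected_memory > last_available_memory:
            print(
                f"Skipping graphs from bs={bs}: expected ~{expected_memory} MB, "
                f"only {last_available_memory} MB left"
            )
            break
        warmup_graph(state, bs)
        current_available_memory = probe_free_memory(state)
        last_allocation_amount = last_available_memory - current_available_memory
        last_available_memory = current_available_memory
        last_bs = bs
    print(
        f"Total memory used for CUDA graphs: "
        f"{free_memory_post_alloc - last_available_memory} MB"
    )


# With 500 MB of headroom after the KV cache, the bs=16 graph is predicted to
# need ~320 MB while only 200 MB remain, so bs=16 and bs=32 are skipped.
warmup_cuda_graphs({"free": 500})
```

Bailing out on the prediction keeps warmup from running into the OutOfMemoryError path for predictably oversized graphs, while every graph that fits is still captured and the running measurements feed the next estimate.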