From ff5bc1bbd105297f1c6306a64f61e49748ec953c Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Mon, 5 May 2025 23:26:39 -0700
Subject: [PATCH] refine free memory and bypass graph logic

Signed-off-by: Wang, Yi A
---
 .../models/flash_causal_lm.py                 | 24 ++++++++++++++-----
 .../models/flash_vlm_causal_lm.py             |  2 +-
 .../models/mllama_causal_lm.py                | 13 ++++++----
 .../text_generation_server/utils/dist.py      |  2 +-
 .../utils/import_utils.py                     | 18 +++++++-------
 5 files changed, 39 insertions(+), 20 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 79626233..5aadff0d 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1445,9 +1445,10 @@ class FlashCausalLM(Model):
         self.use_contiguous_pa = (
             os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
         )
-        self.limit_hpu_graphs = (
-            os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true"
+        self.limit_hpu_graph = (
+            os.environ.get("LIMIT_HPU_GRAPH", "true").lower() == "true"
         )
+        self.max_seq_len_to_capture = 8192
         super().__init__(
             model_id=model_id,
             model=model,
@@ -1564,6 +1565,7 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
+        self.max_batch_prefill_tokens = max_input_tokens * len(batch)
 
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
         if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
@@ -1592,15 +1594,23 @@ class FlashCausalLM(Model):
 
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
+    def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
+        if self.limit_hpu_graph:
+            return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
+        return False
+
     def warmup_hpu_graph(self, batch):
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
+        synchronize(self.device)
 
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
@@ -1613,7 +1623,7 @@ class FlashCausalLM(Model):
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
-            synchronize(self.device)
+        synchronize(self.device)
 
     def warmup_prefill(
         self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
@@ -1644,7 +1654,9 @@ class FlashCausalLM(Model):
         lm_head_indices = input_lengths - 1
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                True, input_ids.shape[0]
+            )
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1793,8 +1805,8 @@ class FlashCausalLM(Model):
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = (
-                batch.prefilling if self.limit_hpu_graphs else False
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                batch.prefilling, input_ids.shape[0]
             )
 
         logits, speculative_logits = self.model.forward(
diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
index 1776b219..a1a7ca4d 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
@@ -453,7 +453,7 @@ class FlashVlmCausalLM(FlashCausalLM):
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
-            synchronize(self.device)
+        synchronize(self.device)
 
     def forward(
         self,
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index 5de9bca8..dac65fea 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -325,7 +325,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         )
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                True, input_ids.shape[0]
+            )
         self.model.forward(
             input_ids=_async_h2d_tensor_copy(input_ids),
             position_ids=_async_h2d_tensor_copy(position_ids),
@@ -348,9 +350,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
+        synchronize(self.device)
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
@@ -362,7 +367,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
-            synchronize(self.device)
+        synchronize(self.device)
 
     def forward(
         self,
@@ -438,8 +443,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = (
-                batch.prefilling if self.limit_hpu_graphs else False
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                batch.prefilling, input_ids.shape[0]
             )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
diff --git a/backends/gaudi/server/text_generation_server/utils/dist.py b/backends/gaudi/server/text_generation_server/utils/dist.py
index 1c45713e..9866710b 100644
--- a/backends/gaudi/server/text_generation_server/utils/dist.py
+++ b/backends/gaudi/server/text_generation_server/utils/dist.py
@@ -7,7 +7,7 @@ from loguru import logger
 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
-MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
+MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9"))
 
 
 class FakeBarrier:
diff --git a/backends/gaudi/server/text_generation_server/utils/import_utils.py b/backends/gaudi/server/text_generation_server/utils/import_utils.py
index 22560dd7..39156140 100644
--- a/backends/gaudi/server/text_generation_server/utils/import_utils.py
+++ b/backends/gaudi/server/text_generation_server/utils/import_utils.py
@@ -1,17 +1,19 @@
 import torch
 from loguru import logger
+import habana_frameworks.torch as htorch
+import os
 
 
 def get_hpu_free_memory(device, memory_fraction):
-    from habana_frameworks.torch.hpu import memory_stats
-
-    device_id = device.index
-    mem_stats = memory_stats(device_id)
-    logger.info(f"mem_stats: {mem_stats}")
-    total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
-    free_memory = max(
-        0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
+    graph_reserved_mem = (
+        float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
+        if htorch.utils.internal.is_lazy()
+        else 0
     )
+    free_memory = int(
+        torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
+    )
+    logger.info(f"Free memory on device {device}: {free_memory} bytes, ")
     return free_memory