refine free memory and bypass graph logic

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date: 2025-05-05 23:26:39 -07:00
Parent: 1cda91135e
Commit: ff5bc1bbd1
5 changed files with 39 additions and 20 deletions


@@ -1445,9 +1445,10 @@ class FlashCausalLM(Model):
         self.use_contiguous_pa = (
             os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
         )
-        self.limit_hpu_graphs = (
-            os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true"
+        self.limit_hpu_graph = (
+            os.environ.get("LIMIT_HPU_GRAPH", "true").lower() == "true"
         )
+        self.max_seq_len_to_capture = 8192
         super().__init__(
             model_id=model_id,
             model=model,
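
The hunk above replaces the old LIMIT_HPU_GRAPHS flag (default false) with LIMIT_HPU_GRAPH (default true) and introduces an 8192-token capture cap. A minimal sketch of what the two knobs mean together, using only the names visible in this hunk; the actual decision lives in the bypass_hpu_graphs helper added further down, so this is illustration only:

import os

limit_hpu_graph = os.environ.get("LIMIT_HPU_GRAPH", "true").lower() == "true"
max_seq_len_to_capture = 8192

# With the defaults, a prefill whose flattened input exceeds 8192 tokens is
# expected to bypass HPU graph capture, while shorter prefills are captured.
print(limit_hpu_graph and 9216 > max_seq_len_to_capture)  # True  -> bypass graphs
print(limit_hpu_graph and 4096 > max_seq_len_to_capture)  # False -> use graphs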
@@ -1564,6 +1565,7 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
+        self.max_batch_prefill_tokens = max_input_tokens * len(batch)
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
         if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
@@ -1592,15 +1594,23 @@ class FlashCausalLM(Model):
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

+    def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
+        if self.limit_hpu_graph:
+            return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
+        return False
+
     def warmup_hpu_graph(self, batch):
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
+        synchronize(self.device)
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
         ):
@@ -1613,7 +1623,7 @@ class FlashCausalLM(Model):
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
-            synchronize(self.device)
+        synchronize(self.device)

     def warmup_prefill(
         self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
@@ -1644,7 +1654,9 @@ class FlashCausalLM(Model):
         lm_head_indices = input_lengths - 1
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                True, input_ids.shape[0]
+            )

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1793,8 +1805,8 @@ class FlashCausalLM(Model):
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = (
-                batch.prefilling if self.limit_hpu_graphs else False
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                batch.prefilling, input_ids.shape[0]
             )

         logits, speculative_logits = self.model.forward(
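
Taken together, the hunks above centralize the bypass decision: warmup_prefill always passes prefill=True with the flattened prompt length, the main forward passes batch.prefilling, and decode steps never bypass. A self-contained sketch of that decision, assuming the attribute defaults from the first hunk (not part of the diff):

class BypassSketch:
    limit_hpu_graph = True          # LIMIT_HPU_GRAPH=true, the new default
    max_seq_len_to_capture = 8192   # set in __init__ above

    def bypass_hpu_graphs(self, prefill, seq_len):
        # Only prefill shapes longer than the capture cap fall back to eager
        # execution; everything else keeps running under HPU graphs.
        if self.limit_hpu_graph:
            return prefill and seq_len > self.max_seq_len_to_capture
        return False


m = BypassSketch()
assert m.bypass_hpu_graphs(True, 16384) is True    # long prefill: run eagerly
assert m.bypass_hpu_graphs(True, 2048) is False    # short prefill: capture graph
assert m.bypass_hpu_graphs(False, 16384) is False  # decode: always use graphs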


@@ -453,7 +453,7 @@ class FlashVlmCausalLM(FlashCausalLM):
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
-            synchronize(self.device)
+        synchronize(self.device)

     def forward(
         self,
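
The only change in this file mirrors the decode-warmup hunk above: reading the indentation shift, synchronize(self.device) now runs once after the whole decode warmup loop rather than once per bucket. A small sketch of that pattern with hypothetical helpers (warm_one, buckets, synchronize are stand-ins, not names from the diff):

def warmup_decode_buckets(buckets, warmup_times, warm_one, synchronize, device):
    # Warm every decode bucket first, then issue a single device barrier,
    # instead of synchronizing after each bucket.
    for batch_size, block_num in reversed(buckets):
        for _ in range(warmup_times):
            warm_one(batch_size, block_num)
    synchronize(device)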


@@ -325,7 +325,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         )
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                True, input_ids.shape[0]
+            )
         self.model.forward(
             input_ids=_async_h2d_tensor_copy(input_ids),
             position_ids=_async_h2d_tensor_copy(position_ids),
@@ -348,9 +350,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
+        synchronize(self.device)
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
@@ -362,7 +367,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             )
             for index in range(warmup_times):
                 self.warmup_decode(batch_size, block_num, batch)
-            synchronize(self.device)
+        synchronize(self.device)

     def forward(
         self,
@@ -438,8 +443,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = (
-                batch.prefilling if self.limit_hpu_graphs else False
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                batch.prefilling, input_ids.shape[0]
             )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)
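
As in FlashCausalLM above, the prompt-bucket loop now skips any bucket whose padded size would exceed the prefill token budget recorded during warmup (max_batch_prefill_tokens = max_input_tokens * len(batch)). A short sketch with made-up numbers to show which buckets survive the filter:

max_batch_prefill_tokens = 4 * 2048  # e.g. max_input_tokens=2048, warmup batch of 4
prompt_buckets = [(1, 2048), (4, 2048), (8, 2048), (16, 2048)]  # (batch_size, seq_len)

for batch_size, seq_len in reversed(prompt_buckets):
    if batch_size * seq_len > max_batch_prefill_tokens:
        continue  # (16, 2048) and (8, 2048) are never warmed up or captured
    print(f"warmup prefill seq {seq_len} bs {batch_size}")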


@@ -7,7 +7,7 @@ from loguru import logger
 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
-MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
+MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9"))


 class FakeBarrier:


@@ -1,17 +1,19 @@
 import torch
 from loguru import logger
+import habana_frameworks.torch as htorch
+import os


 def get_hpu_free_memory(device, memory_fraction):
-    from habana_frameworks.torch.hpu import memory_stats
-    device_id = device.index
-    mem_stats = memory_stats(device_id)
-    logger.info(f"mem_stats: {mem_stats}")
-    total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
-    free_memory = max(
-        0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
+    graph_reserved_mem = (
+        float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
+        if htorch.utils.internal.is_lazy()
+        else 0
+    )
+    free_memory = int(
+        torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
     )
     logger.info(f"Free memory on device {device}: {free_memory} bytes, ")
     return free_memory
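
With the HPU_MEMORY_FRACTION default raised to 0.9 in the earlier hunk and TGI_GRAPH_RESERVED_MEM defaulting to 0.1 under lazy mode, the rewritten helper reports roughly 81% of the memory the device currently says is free. A worked example with illustrative numbers (the 96 GiB figure is made up):

free_bytes = 96 * 1024**3    # pretend torch.hpu.mem_get_info()[0] returned 96 GiB
memory_fraction = 0.9        # HPU_MEMORY_FRACTION, new default
graph_reserved_mem = 0.1     # TGI_GRAPH_RESERVED_MEM, applied only in lazy mode

free_memory = int(free_bytes * memory_fraction * (1 - graph_reserved_mem))
print(free_memory / 1024**3)  # ~77.8 GiB reported back to the caller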