refine free memory and bypass graph logic

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date: 2025-05-05 23:26:39 -07:00
parent 1cda91135e
commit ff5bc1bbd1
5 changed files with 39 additions and 20 deletions

View File

@@ -1445,9 +1445,10 @@ class FlashCausalLM(Model):
         self.use_contiguous_pa = (
             os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
         )
-        self.limit_hpu_graphs = (
-            os.environ.get("LIMIT_HPU_GRAPHS", "false").lower() == "true"
+        self.limit_hpu_graph = (
+            os.environ.get("LIMIT_HPU_GRAPH", "true").lower() == "true"
         )
+        self.max_seq_len_to_capture = 8192
         super().__init__(
             model_id=model_id,
             model=model,
@@ -1564,6 +1565,7 @@ class FlashCausalLM(Model):
             self.kv_cache_dtype,
             self.device,
         )
+        self.max_batch_prefill_tokens = max_input_tokens * len(batch)
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE", 128))
         if os.getenv("VLLM_PROMPT_SEQ_BUCKET_MAX") is None:
@@ -1592,15 +1594,23 @@ class FlashCausalLM(Model):
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

+    def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
+        if self.limit_hpu_graph:
+            return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
+        return False
+
     def warmup_hpu_graph(self, batch):
         warmup_times = 3
         self.bucketing_ctx.generate_prompt_buckets()
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
+            synchronize(self.device)
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
for i, (batch_size, block_num) in enumerate( for i, (batch_size, block_num) in enumerate(
@@ -1644,7 +1654,9 @@ class FlashCausalLM(Model):
         lm_head_indices = input_lengths - 1
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                True, input_ids.shape[0]
+            )
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1793,8 +1805,8 @@ class FlashCausalLM(Model):
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = (
-                batch.prefilling if self.limit_hpu_graphs else False
-            )
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                batch.prefilling, input_ids.shape[0]
+            )
         logits, speculative_logits = self.model.forward(

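In plain terms, the new `bypass_hpu_graphs` helper skips HPU graph execution only for prefill batches whose padded token count exceeds the hardcoded 8192 capture cap, and only while graph limiting is enabled. A minimal standalone sketch of that decision (the function name `should_bypass_hpu_graphs` and the sample calls are illustrative, not part of the commit; at the real call sites the second argument is `input_ids.shape[0]`):

    import os

    # Mirrors the value the commit hardcodes as self.max_seq_len_to_capture.
    MAX_SEQ_LEN_TO_CAPTURE = 8192

    def should_bypass_hpu_graphs(prefill: bool, padded_token_count: int) -> bool:
        # New default: graph limiting is on unless LIMIT_HPU_GRAPH=false.
        limit_hpu_graph = os.environ.get("LIMIT_HPU_GRAPH", "true").lower() == "true"
        if limit_hpu_graph:
            # Only long prefill shapes run eagerly; decode and short prefills
            # keep going through captured HPU graphs.
            return prefill and padded_token_count > MAX_SEQ_LEN_TO_CAPTURE
        return False

    # should_bypass_hpu_graphs(True, 16384)  -> True  (long prefill, run eagerly)
    # should_bypass_hpu_graphs(False, 16384) -> False (decode keeps using graphs)
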
View File

@@ -325,7 +325,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         )
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.limit_hpu_graphs
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                True, input_ids.shape[0]
+            )
         self.model.forward(
             input_ids=_async_h2d_tensor_copy(input_ids),
             position_ids=_async_h2d_tensor_copy(position_ids),
@@ -348,9 +350,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         for i, (batch_size, seq_len) in enumerate(
             reversed(self.bucketing_ctx.prompt_buckets)
         ):
+            if batch_size * seq_len > self.max_batch_prefill_tokens:
+                continue
             log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
             for index in range(warmup_times):
                 self.warmup_prefill(seq_len, batch_size, batch)
+            synchronize(self.device)
         self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
@@ -438,8 +443,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = (
-                batch.prefilling if self.limit_hpu_graphs else False
-            )
+            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
+                batch.prefilling, input_ids.shape[0]
+            )
         if batch.prefill_cache_indices is not None:
             slots_pad = torch.zeros_like(input_ids)

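The warmup change is the same two-line guard in both model classes: prompt buckets whose batch_size * seq_len would exceed the prefill token budget are skipped, and the device is synchronized after each warmed bucket. A small sketch with made-up buckets and a hypothetical budget (only the guard itself mirrors the commit):

    # Hypothetical bucket list; in the commit the budget is computed as
    # max_input_tokens * len(batch) and stored on the model.
    prompt_buckets = [(1, 1024), (4, 2048), (8, 8192), (16, 8192)]  # (batch_size, seq_len)
    max_batch_prefill_tokens = 32 * 1024

    for batch_size, seq_len in reversed(prompt_buckets):
        if batch_size * seq_len > max_batch_prefill_tokens:
            # 8 * 8192 and 16 * 8192 exceed the budget, so those buckets
            # are skipped during warmup.
            continue
        print(f"warmup prefill seq {seq_len} bs {batch_size}")
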
View File

@@ -7,7 +7,7 @@ from loguru import logger
 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
-MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
+MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.9"))


 class FakeBarrier:

View File

@@ -1,17 +1,19 @@
 import torch
 from loguru import logger
+import habana_frameworks.torch as htorch
+import os


 def get_hpu_free_memory(device, memory_fraction):
-    from habana_frameworks.torch.hpu import memory_stats
-
-    device_id = device.index
-    mem_stats = memory_stats(device_id)
-    logger.info(f"mem_stats: {mem_stats}")
-    total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
-    free_memory = max(
-        0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
+    graph_reserved_mem = (
+        float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
+        if htorch.utils.internal.is_lazy()
+        else 0
     )
+    free_memory = int(
+        torch.hpu.mem_get_info()[0] * memory_fraction * (1 - graph_reserved_mem)
+    )
+    logger.info(f"Free memory on device {device}: {free_memory} bytes, ")
     return free_memory
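
Combined with the HPU_MEMORY_FRACTION default moving from 0.8 to 0.9, the new estimate reduces to: take what the driver reports as free, keep memory_fraction of it, and in lazy (graph) mode carve out a TGI_GRAPH_RESERVED_MEM slice of that for HPU graph capture. A standalone sketch of the arithmetic (the wrapper function is illustrative; the 0.9 and 0.1 defaults come from the diff):

    import os

    def estimate_usable_free_memory(reported_free_bytes: int, lazy_mode: bool,
                                    memory_fraction: float = 0.9) -> int:
        # Reserve a slice for HPU graph capture only when graphs are in use.
        graph_reserved_mem = (
            float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1")) if lazy_mode else 0.0
        )
        return int(reported_free_bytes * memory_fraction * (1 - graph_reserved_mem))

    # With 64 GiB reported free and lazy mode on:
    # 64 GiB * 0.9 * (1 - 0.1) ≈ 51.8 GiB left for the server to allocate.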