Refine logging and fix some issues

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-04-02 19:11:35 -07:00
parent a84da5b698
commit 8591687561
3 changed files with 60 additions and 35 deletions

View File

@@ -1508,7 +1508,7 @@ class FlashCausalLM(Model):
num_blocks * BLOCK_SIZE,
)
self.bucketing_ctx.num_hpu_blocks = num_blocks
if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true":
if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
logger.info("skip warmup hpu graph, not recommmended")
del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
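
For context, the renamed gate above is a plain boolean environment-variable check; a minimal sketch of the same behaviour (the helper name is illustrative, not part of this commit):

import os

def should_skip_warmup() -> bool:
    # Only the literal string "true" (case-insensitive) enables the skip,
    # mirroring the os.getenv(...).lower() == "true" check in the diff above.
    return os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
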
@@ -1524,23 +1524,26 @@ class FlashCausalLM(Model):
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets)
):
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size)
self.warmup_prefill(seq_len, batch_size, batch)
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num)
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
def warmup_prefill(self, prompt_len: int, bs: int):
logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashCausalLMBatch):
input_ids = torch.zeros(
prompt_len, dtype=torch.int64, device=self.device
prompt_len, dtype=batch.input_ids.dtype, device=self.device
).repeat(bs)
position_ids = torch.arange(
prompt_len, dtype=torch.int32, device=self.device
prompt_len, dtype=batch.position_ids.dtype, device=self.device
).repeat(bs)
max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
block_tables = torch.arange(
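
The dtype changes in this hunk follow one pattern: warmup tensors now take their dtypes from a live batch instead of hard-coded int64/int32, so the compiled HPU graph matches what the server sees at runtime. A condensed sketch with an illustrative helper name:

import torch

def make_prefill_warmup_inputs(batch, prompt_len, bs, device):
    # Dummy prefill inputs whose dtypes mirror the real batch tensors.
    input_ids = torch.zeros(
        prompt_len, dtype=batch.input_ids.dtype, device=device
    ).repeat(bs)
    position_ids = torch.arange(
        prompt_len, dtype=batch.position_ids.dtype, device=device
    ).repeat(bs)
    return input_ids, position_ids
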
@@ -1552,7 +1555,7 @@ class FlashCausalLM(Model):
for b in block_tables[i]:
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
slot_acc.extend(slots[:prompt_len])
slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
input_lengths = (
torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
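
The slot accumulation above expands each block id into BLOCK_SIZE consecutive slot indices and keeps the first prompt_len of them; a small worked example (BLOCK_SIZE here is an arbitrary value for illustration, not the backend's real constant):

BLOCK_SIZE = 4
prompt_len = 6
blocks = [3, 7]
slots = []
for b in blocks:
    slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
print(slots[:prompt_len])  # [12, 13, 14, 15, 28, 29]
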
@@ -1581,10 +1584,13 @@ class FlashCausalLM(Model):
hpu_attention_meta=None,
)
def warmup_decode(self, batch_size: int, block_num: int):
logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch):
input_ids = torch.zeros(
batch_size, dtype=batch.input_ids.dtype, device=self.device
)
position_ids = torch.arange(
batch_size, dtype=batch.position_ids.dtype, device=self.device
)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
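
The decode warmup splits block_num KV-cache blocks as evenly as possible across the batch, giving the remainder to the first sequence; a quick worked example:

block_num, batch_size = 10, 4
blocks = [block_num // batch_size for _ in range(batch_size)]  # [2, 2, 2, 2]
blocks[0] += block_num % batch_size                            # [4, 2, 2, 2]
assert sum(blocks) == block_num
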
@@ -1599,7 +1605,7 @@ class FlashCausalLM(Model):
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device
@@ -1725,7 +1731,6 @@ class FlashCausalLM(Model):
padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]:
orig_bs = input_lengths.shape[0]
padded_input_lengths = F.pad(
input_lengths,
(0, padded_bs - orig_bs),
@@ -1754,7 +1759,7 @@ class FlashCausalLM(Model):
position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
)
slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
if lm_head_indices is not None:
lm_head_indices = F.pad(
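
The padding hunks now fill the padded slot entries with 0 (a valid slot index) rather than -1. A minimal sketch of padding flat per-token tensors from orig_bs sequences up to a bucketed padded_bs (all sizes here are made up):

import torch
import torch.nn.functional as F

orig_bs, padded_bs, seq_len = 2, 4, 3
slots = torch.arange(orig_bs * seq_len)
padded_slots = F.pad(slots, (0, (padded_bs - orig_bs) * seq_len), value=0)
print(padded_slots.shape)  # torch.Size([12])
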

View File

@@ -383,9 +383,12 @@ class FlashVlmCausalLM(FlashCausalLM):
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
):
logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
input_ids = torch.zeros(
batch_size, dtype=batch.input_ids.dtype, device=self.device
)
position_ids = torch.arange(
batch_size, dtype=batch.position_ids.dtype, device=self.device
)
if batch.position_ids is not None and batch.position_ids.dim() == 2:
# qwen2_vl and qwen2_5_vl case
position_ids = position_ids.unsqueeze(-1).repeat(
@@ -405,7 +408,7 @@ class FlashVlmCausalLM(FlashCausalLM):
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device
@@ -438,9 +441,12 @@ class FlashVlmCausalLM(FlashCausalLM):
kv_cache=self.kv_cache,
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
lm_head_indices=None,
adapter_data=None,
hpu_attention_meta=hpu_attention_meta,
lm_head_indices=None,
pixel_values=None,
pixel_attention_mask=None,
image_sizes=None,
image_grid_thw=None,
)
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
@@ -450,6 +456,9 @@ class FlashVlmCausalLM(FlashCausalLM):
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
@@ -546,8 +555,8 @@ class FlashVlmCausalLM(FlashCausalLM):
)
else:
padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]:
orig_bs = input_lengths.shape[0]
padded_input_lengths = F.pad(
input_lengths,
(0, padded_bs - orig_bs),
@@ -586,7 +595,7 @@ class FlashVlmCausalLM(FlashCausalLM):
value=1,
)
slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
if lm_head_indices is not None:
lm_head_indices = F.pad(
@@ -621,4 +630,6 @@ class FlashVlmCausalLM(FlashCausalLM):
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
return logits, speculative_logits
return logits[:orig_bs], (
speculative_logits[:orig_bs] if speculative_logits is not None else None
)
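
Because the batch may have been padded up to a bucket size, only the first orig_bs rows of the output correspond to real requests; a tiny sketch of the slicing the new return statement performs (shapes are illustrative):

import torch

orig_bs, padded_bs, vocab = 2, 4, 8
logits = torch.randn(padded_bs, vocab)
speculative_logits = None
trimmed = logits[:orig_bs]
trimmed_spec = speculative_logits[:orig_bs] if speculative_logits is not None else None
assert trimmed.shape == (orig_bs, vocab)
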

View File

@@ -27,6 +27,7 @@ from text_generation_server.utils.import_utils import (
synchronize,
)
import torch.nn.functional as F
from text_generation_server.utils.log import log_master
tracer = trace.get_tracer(__name__)
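
log_master is used so warmup messages are emitted once rather than once per shard; a rough, hypothetical sketch of such a rank-0-only helper (the real implementation lives in text_generation_server.utils.log and may differ):

import os

def log_master(log, msg: str):
    # Assumption: only the process with RANK == 0 actually writes the message.
    if int(os.getenv("RANK", "0")) == 0:
        log(msg)
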
@@ -208,9 +209,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
):
logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
input_ids = torch.zeros(
batch_size, dtype=batch.input_ids.dtype, device=self.device
)
position_ids = torch.arange(
batch_size, dtype=batch.position_ids.dtype, device=self.device
)
blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size
past_len = []
@@ -225,7 +229,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i]
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device
@@ -266,12 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
)
def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch):
logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
input_ids = torch.zeros(
prompt_len, dtype=torch.int64, device=self.device
prompt_len, dtype=batch.input_ids.dtype, device=self.device
).repeat(bs)
position_ids = torch.arange(
prompt_len, dtype=torch.int32, device=self.device
prompt_len, dtype=batch.position_ids.dtype, device=self.device
).repeat(bs)
max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
block_tables = torch.arange(
@@ -283,7 +286,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
for b in block_tables[i]:
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
slot_acc.extend(slots[:prompt_len])
slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
input_lengths = (
torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
@@ -320,12 +323,16 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets)
):
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch)
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets)
):
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device)
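
The warmup loops above now log each bucket once (via log_master) and then replay the warmup call warmup_times times; a schematic sketch of that pattern with illustrative names:

def warm_decode_buckets(decode_buckets, warmup_times, warmup_fn, log_fn):
    # decode_buckets: iterable of (batch_size, block_num) pairs
    for batch_size, block_num in reversed(decode_buckets):
        log_fn(f"warmup decode bs {batch_size} block_num {block_num}")
        for _ in range(warmup_times):
            warmup_fn(batch_size, block_num)
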
@@ -425,8 +432,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
)
else:
padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]:
orig_bs = input_lengths.shape[0]
padded_input_lengths = F.pad(
input_lengths,
(0, padded_bs - orig_bs),
@@ -455,7 +462,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
)
slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
)
if lm_head_indices is not None:
lm_head_indices = F.pad(
@@ -484,4 +491,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
)
if batch.pixel_values is not None:
batch.pixel_values = None
return logits, speculative_logits
return logits[:orig_bs], (
speculative_logits[:orig_bs] if speculative_logits is not None else None
)