refine log and fix some issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-04-02 19:11:35 -07:00
parent a84da5b698
commit 8591687561
3 changed files with 60 additions and 35 deletions

View File

@ -1508,7 +1508,7 @@ class FlashCausalLM(Model):
num_blocks * BLOCK_SIZE, num_blocks * BLOCK_SIZE,
) )
self.bucketing_ctx.num_hpu_blocks = num_blocks self.bucketing_ctx.num_hpu_blocks = num_blocks
if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true": if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
logger.info("skip warmup hpu graph, not recommmended") logger.info("skip warmup hpu graph, not recommmended")
del _batch, batch del _batch, batch
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
@ -1524,23 +1524,26 @@ class FlashCausalLM(Model):
for i, (batch_size, seq_len) in enumerate( for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets) reversed(self.bucketing_ctx.prompt_buckets)
): ):
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times): for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size) self.warmup_prefill(seq_len, batch_size, batch)
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate( for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets) reversed(self.bucketing_ctx.decode_buckets)
): ):
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
for index in range(warmup_times): for index in range(warmup_times):
self.warmup_decode(batch_size, block_num) self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device) synchronize(self.device)
def warmup_prefill(self, prompt_len: int, bs: int): def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashCausalLMBatch):
logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
input_ids = torch.zeros( input_ids = torch.zeros(
prompt_len, dtype=torch.int64, device=self.device prompt_len, dtype=batch.input_ids.dtype, device=self.device
).repeat(bs) ).repeat(bs)
position_ids = torch.arange( position_ids = torch.arange(
prompt_len, dtype=torch.int32, device=self.device prompt_len, dtype=batch.position_ids.dtype, device=self.device
).repeat(bs) ).repeat(bs)
max_bt = (prompt_len // BLOCK_SIZE + 1) * bs max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
block_tables = torch.arange( block_tables = torch.arange(
@ -1552,7 +1555,7 @@ class FlashCausalLM(Model):
for b in block_tables[i]: for b in block_tables[i]:
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
slot_acc.extend(slots[:prompt_len]) slot_acc.extend(slots[:prompt_len])
slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device) slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
input_lengths = ( input_lengths = (
torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
@ -1581,10 +1584,13 @@ class FlashCausalLM(Model):
hpu_attention_meta=None, hpu_attention_meta=None,
) )
def warmup_decode(self, batch_size: int, block_num: int): def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch):
logger.info(f"warmup decode bs {batch_size} block_num {block_num}") input_ids = torch.zeros(
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device) batch_size, dtype=batch.input_ids.dtype, device=self.device
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device) )
position_ids = torch.arange(
batch_size, dtype=batch.position_ids.dtype, device=self.device
)
blocks = [block_num // batch_size for _ in range(batch_size)] blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size blocks[0] += block_num % batch_size
past_len = [] past_len = []
@ -1599,7 +1605,7 @@ class FlashCausalLM(Model):
block_tables.append(block_array) block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1) past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i] start_idx += blocks[i]
slots = torch.tensor(slots, dtype=torch.int64, device=self.device) slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor( cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device past_len, dtype=torch.int32, device=self.device
@ -1725,7 +1731,6 @@ class FlashCausalLM(Model):
padded_bs = input_lengths.shape[0] padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0] orig_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]: if padded_bs != input_lengths.shape[0]:
orig_bs = input_lengths.shape[0]
padded_input_lengths = F.pad( padded_input_lengths = F.pad(
input_lengths, input_lengths,
(0, padded_bs - orig_bs), (0, padded_bs - orig_bs),
@ -1754,7 +1759,7 @@ class FlashCausalLM(Model):
position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1 position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
) )
slots = F.pad( slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1 slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
) )
if lm_head_indices is not None: if lm_head_indices is not None:
lm_head_indices = F.pad( lm_head_indices = F.pad(

View File

@ -383,9 +383,12 @@ class FlashVlmCausalLM(FlashCausalLM):
def warmup_decode( def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
): ):
logger.info(f"warmup decode bs {batch_size} block_num {block_num}") input_ids = torch.zeros(
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device) batch_size, dtype=batch.input_ids.dtype, device=self.device
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device) )
position_ids = torch.arange(
batch_size, dtype=batch.position_ids.dtype, device=self.device
)
if batch.position_ids is not None and batch.position_ids.dim() == 2: if batch.position_ids is not None and batch.position_ids.dim() == 2:
# qwen2_vl and qwen2_5_vl case # qwen2_vl and qwen2_5_vl case
position_ids = position_ids.unsqueeze(-1).repeat( position_ids = position_ids.unsqueeze(-1).repeat(
@ -405,7 +408,7 @@ class FlashVlmCausalLM(FlashCausalLM):
block_tables.append(block_array) block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1) past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i] start_idx += blocks[i]
slots = torch.tensor(slots, dtype=torch.int64, device=self.device) slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor( cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device past_len, dtype=torch.int32, device=self.device
@ -438,9 +441,12 @@ class FlashVlmCausalLM(FlashCausalLM):
kv_cache=self.kv_cache, kv_cache=self.kv_cache,
slots=slots, slots=slots,
seqlen=trim_seqlen_metadata(seqlen), seqlen=trim_seqlen_metadata(seqlen),
lm_head_indices=None,
adapter_data=None,
hpu_attention_meta=hpu_attention_meta, hpu_attention_meta=hpu_attention_meta,
lm_head_indices=None,
pixel_values=None,
pixel_attention_mask=None,
image_sizes=None,
image_grid_thw=None,
) )
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch): def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
@ -450,6 +456,9 @@ class FlashVlmCausalLM(FlashCausalLM):
for i, (batch_size, block_num) in enumerate( for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets) reversed(self.bucketing_ctx.decode_buckets)
): ):
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
for index in range(warmup_times): for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch) self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device) synchronize(self.device)
@ -546,8 +555,8 @@ class FlashVlmCausalLM(FlashCausalLM):
) )
else: else:
padded_bs = input_lengths.shape[0] padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]: if padded_bs != input_lengths.shape[0]:
orig_bs = input_lengths.shape[0]
padded_input_lengths = F.pad( padded_input_lengths = F.pad(
input_lengths, input_lengths,
(0, padded_bs - orig_bs), (0, padded_bs - orig_bs),
@ -586,7 +595,7 @@ class FlashVlmCausalLM(FlashCausalLM):
value=1, value=1,
) )
slots = F.pad( slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1 slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
) )
if lm_head_indices is not None: if lm_head_indices is not None:
lm_head_indices = F.pad( lm_head_indices = F.pad(
@ -621,4 +630,6 @@ class FlashVlmCausalLM(FlashCausalLM):
batch.image_sizes = None batch.image_sizes = None
if batch.image_grid_thw is not None: if batch.image_grid_thw is not None:
batch.image_grid_thw = None batch.image_grid_thw = None
return logits, speculative_logits return logits[:orig_bs], (
speculative_logits[:orig_bs] if speculative_logits is not None else None
)

View File

@ -27,6 +27,7 @@ from text_generation_server.utils.import_utils import (
synchronize, synchronize,
) )
import torch.nn.functional as F import torch.nn.functional as F
from text_generation_server.utils.log import log_master
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -208,9 +209,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
def warmup_decode( def warmup_decode(
self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
): ):
logger.info(f"warmup decode bs {batch_size} block_num {block_num}") input_ids = torch.zeros(
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device) batch_size, dtype=batch.input_ids.dtype, device=self.device
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device) )
position_ids = torch.arange(
batch_size, dtype=batch.position_ids.dtype, device=self.device
)
blocks = [block_num // batch_size for _ in range(batch_size)] blocks = [block_num // batch_size for _ in range(batch_size)]
blocks[0] += block_num % batch_size blocks[0] += block_num % batch_size
past_len = [] past_len = []
@ -225,7 +229,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
block_tables.append(block_array) block_tables.append(block_array)
past_len.append(blocks[i] * BLOCK_SIZE - 1) past_len.append(blocks[i] * BLOCK_SIZE - 1)
start_idx += blocks[i] start_idx += blocks[i]
slots = torch.tensor(slots, dtype=torch.int64, device=self.device) slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device) input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.tensor( cache_lengths_tensor = torch.tensor(
past_len, dtype=torch.int32, device=self.device past_len, dtype=torch.int32, device=self.device
@ -266,12 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
) )
def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch): def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch):
logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
input_ids = torch.zeros( input_ids = torch.zeros(
prompt_len, dtype=torch.int64, device=self.device prompt_len, dtype=batch.input_ids.dtype, device=self.device
).repeat(bs) ).repeat(bs)
position_ids = torch.arange( position_ids = torch.arange(
prompt_len, dtype=torch.int32, device=self.device prompt_len, dtype=batch.position_ids.dtype, device=self.device
).repeat(bs) ).repeat(bs)
max_bt = (prompt_len // BLOCK_SIZE + 1) * bs max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
block_tables = torch.arange( block_tables = torch.arange(
@ -283,7 +286,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
for b in block_tables[i]: for b in block_tables[i]:
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)) slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
slot_acc.extend(slots[:prompt_len]) slot_acc.extend(slots[:prompt_len])
slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device) slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
input_lengths = ( input_lengths = (
torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
@ -320,12 +323,16 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
for i, (batch_size, seq_len) in enumerate( for i, (batch_size, seq_len) in enumerate(
reversed(self.bucketing_ctx.prompt_buckets) reversed(self.bucketing_ctx.prompt_buckets)
): ):
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
for index in range(warmup_times): for index in range(warmup_times):
self.warmup_prefill(seq_len, batch_size, batch) self.warmup_prefill(seq_len, batch_size, batch)
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks) self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
for i, (batch_size, block_num) in enumerate( for i, (batch_size, block_num) in enumerate(
reversed(self.bucketing_ctx.decode_buckets) reversed(self.bucketing_ctx.decode_buckets)
): ):
log_master(
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
)
for index in range(warmup_times): for index in range(warmup_times):
self.warmup_decode(batch_size, block_num, batch) self.warmup_decode(batch_size, block_num, batch)
synchronize(self.device) synchronize(self.device)
@ -425,8 +432,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
) )
else: else:
padded_bs = input_lengths.shape[0] padded_bs = input_lengths.shape[0]
orig_bs = input_lengths.shape[0]
if padded_bs != input_lengths.shape[0]: if padded_bs != input_lengths.shape[0]:
orig_bs = input_lengths.shape[0]
padded_input_lengths = F.pad( padded_input_lengths = F.pad(
input_lengths, input_lengths,
(0, padded_bs - orig_bs), (0, padded_bs - orig_bs),
@ -455,7 +462,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1 position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
) )
slots = F.pad( slots = F.pad(
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1 slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
) )
if lm_head_indices is not None: if lm_head_indices is not None:
lm_head_indices = F.pad( lm_head_indices = F.pad(
@ -484,4 +491,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
) )
if batch.pixel_values is not None: if batch.pixel_values is not None:
batch.pixel_values = None batch.pixel_values = None
return logits, speculative_logits return logits[:orig_bs], (
speculative_logits[:orig_bs] if speculative_logits is not None else None
)