mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-01 15:02:09 +00:00
refine log and fix some issue
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
a84da5b698
commit
8591687561
@ -1508,7 +1508,7 @@ class FlashCausalLM(Model):
|
||||
num_blocks * BLOCK_SIZE,
|
||||
)
|
||||
self.bucketing_ctx.num_hpu_blocks = num_blocks
|
||||
if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true":
|
||||
if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
|
||||
logger.info("skip warmup hpu graph, not recommmended")
|
||||
del _batch, batch
|
||||
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
|
||||
@ -1524,23 +1524,26 @@ class FlashCausalLM(Model):
|
||||
for i, (batch_size, seq_len) in enumerate(
|
||||
reversed(self.bucketing_ctx.prompt_buckets)
|
||||
):
|
||||
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
|
||||
for index in range(warmup_times):
|
||||
self.warmup_prefill(seq_len, batch_size)
|
||||
self.warmup_prefill(seq_len, batch_size, batch)
|
||||
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
|
||||
for i, (batch_size, block_num) in enumerate(
|
||||
reversed(self.bucketing_ctx.decode_buckets)
|
||||
):
|
||||
log_master(
|
||||
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
|
||||
)
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num)
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
|
||||
def warmup_prefill(self, prompt_len: int, bs: int):
|
||||
logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
|
||||
def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashCausalLMBatch):
|
||||
input_ids = torch.zeros(
|
||||
prompt_len, dtype=torch.int64, device=self.device
|
||||
prompt_len, dtype=batch.input_ids.dtype, device=self.device
|
||||
).repeat(bs)
|
||||
position_ids = torch.arange(
|
||||
prompt_len, dtype=torch.int32, device=self.device
|
||||
prompt_len, dtype=batch.position_ids.dtype, device=self.device
|
||||
).repeat(bs)
|
||||
max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
|
||||
block_tables = torch.arange(
|
||||
@ -1552,7 +1555,7 @@ class FlashCausalLM(Model):
|
||||
for b in block_tables[i]:
|
||||
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
|
||||
slot_acc.extend(slots[:prompt_len])
|
||||
slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
|
||||
slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
|
||||
|
||||
input_lengths = (
|
||||
torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
|
||||
@ -1581,10 +1584,13 @@ class FlashCausalLM(Model):
|
||||
hpu_attention_meta=None,
|
||||
)
|
||||
|
||||
def warmup_decode(self, batch_size: int, block_num: int):
|
||||
logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
|
||||
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
|
||||
def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch):
|
||||
input_ids = torch.zeros(
|
||||
batch_size, dtype=batch.input_ids.dtype, device=self.device
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
batch_size, dtype=batch.position_ids.dtype, device=self.device
|
||||
)
|
||||
blocks = [block_num // batch_size for _ in range(batch_size)]
|
||||
blocks[0] += block_num % batch_size
|
||||
past_len = []
|
||||
@ -1599,7 +1605,7 @@ class FlashCausalLM(Model):
|
||||
block_tables.append(block_array)
|
||||
past_len.append(blocks[i] * BLOCK_SIZE - 1)
|
||||
start_idx += blocks[i]
|
||||
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
|
||||
slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
|
||||
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
|
||||
cache_lengths_tensor = torch.tensor(
|
||||
past_len, dtype=torch.int32, device=self.device
|
||||
@ -1725,7 +1731,6 @@ class FlashCausalLM(Model):
|
||||
padded_bs = input_lengths.shape[0]
|
||||
orig_bs = input_lengths.shape[0]
|
||||
if padded_bs != input_lengths.shape[0]:
|
||||
orig_bs = input_lengths.shape[0]
|
||||
padded_input_lengths = F.pad(
|
||||
input_lengths,
|
||||
(0, padded_bs - orig_bs),
|
||||
@ -1754,7 +1759,7 @@ class FlashCausalLM(Model):
|
||||
position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
|
||||
)
|
||||
slots = F.pad(
|
||||
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
|
||||
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
lm_head_indices = F.pad(
|
||||
|
@ -383,9 +383,12 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
def warmup_decode(
|
||||
self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
|
||||
):
|
||||
logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
|
||||
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
|
||||
input_ids = torch.zeros(
|
||||
batch_size, dtype=batch.input_ids.dtype, device=self.device
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
batch_size, dtype=batch.position_ids.dtype, device=self.device
|
||||
)
|
||||
if batch.position_ids is not None and batch.position_ids.dim() == 2:
|
||||
# qwen2_vl and qwen2_5_vl case
|
||||
position_ids = position_ids.unsqueeze(-1).repeat(
|
||||
@ -405,7 +408,7 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
block_tables.append(block_array)
|
||||
past_len.append(blocks[i] * BLOCK_SIZE - 1)
|
||||
start_idx += blocks[i]
|
||||
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
|
||||
slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
|
||||
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
|
||||
cache_lengths_tensor = torch.tensor(
|
||||
past_len, dtype=torch.int32, device=self.device
|
||||
@ -438,9 +441,12 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
kv_cache=self.kv_cache,
|
||||
slots=slots,
|
||||
seqlen=trim_seqlen_metadata(seqlen),
|
||||
lm_head_indices=None,
|
||||
adapter_data=None,
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
lm_head_indices=None,
|
||||
pixel_values=None,
|
||||
pixel_attention_mask=None,
|
||||
image_sizes=None,
|
||||
image_grid_thw=None,
|
||||
)
|
||||
|
||||
def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
|
||||
@ -450,6 +456,9 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
for i, (batch_size, block_num) in enumerate(
|
||||
reversed(self.bucketing_ctx.decode_buckets)
|
||||
):
|
||||
log_master(
|
||||
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
|
||||
)
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
@ -546,8 +555,8 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
)
|
||||
else:
|
||||
padded_bs = input_lengths.shape[0]
|
||||
orig_bs = input_lengths.shape[0]
|
||||
if padded_bs != input_lengths.shape[0]:
|
||||
orig_bs = input_lengths.shape[0]
|
||||
padded_input_lengths = F.pad(
|
||||
input_lengths,
|
||||
(0, padded_bs - orig_bs),
|
||||
@ -586,7 +595,7 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
value=1,
|
||||
)
|
||||
slots = F.pad(
|
||||
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
|
||||
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
lm_head_indices = F.pad(
|
||||
@ -621,4 +630,6 @@ class FlashVlmCausalLM(FlashCausalLM):
|
||||
batch.image_sizes = None
|
||||
if batch.image_grid_thw is not None:
|
||||
batch.image_grid_thw = None
|
||||
return logits, speculative_logits
|
||||
return logits[:orig_bs], (
|
||||
speculative_logits[:orig_bs] if speculative_logits is not None else None
|
||||
)
|
||||
|
@ -27,6 +27,7 @@ from text_generation_server.utils.import_utils import (
|
||||
synchronize,
|
||||
)
|
||||
import torch.nn.functional as F
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
@ -208,9 +209,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
def warmup_decode(
|
||||
self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
|
||||
):
|
||||
logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
|
||||
input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
|
||||
input_ids = torch.zeros(
|
||||
batch_size, dtype=batch.input_ids.dtype, device=self.device
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
batch_size, dtype=batch.position_ids.dtype, device=self.device
|
||||
)
|
||||
blocks = [block_num // batch_size for _ in range(batch_size)]
|
||||
blocks[0] += block_num % batch_size
|
||||
past_len = []
|
||||
@ -225,7 +229,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
block_tables.append(block_array)
|
||||
past_len.append(blocks[i] * BLOCK_SIZE - 1)
|
||||
start_idx += blocks[i]
|
||||
slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
|
||||
slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
|
||||
input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
|
||||
cache_lengths_tensor = torch.tensor(
|
||||
past_len, dtype=torch.int32, device=self.device
|
||||
@ -266,12 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
)
|
||||
|
||||
def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch):
|
||||
logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
|
||||
input_ids = torch.zeros(
|
||||
prompt_len, dtype=torch.int64, device=self.device
|
||||
prompt_len, dtype=batch.input_ids.dtype, device=self.device
|
||||
).repeat(bs)
|
||||
position_ids = torch.arange(
|
||||
prompt_len, dtype=torch.int32, device=self.device
|
||||
prompt_len, dtype=batch.position_ids.dtype, device=self.device
|
||||
).repeat(bs)
|
||||
max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
|
||||
block_tables = torch.arange(
|
||||
@ -283,7 +286,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
for b in block_tables[i]:
|
||||
slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
|
||||
slot_acc.extend(slots[:prompt_len])
|
||||
slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
|
||||
slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
|
||||
|
||||
input_lengths = (
|
||||
torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
|
||||
@ -320,12 +323,16 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
for i, (batch_size, seq_len) in enumerate(
|
||||
reversed(self.bucketing_ctx.prompt_buckets)
|
||||
):
|
||||
log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
|
||||
for index in range(warmup_times):
|
||||
self.warmup_prefill(seq_len, batch_size, batch)
|
||||
self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
|
||||
for i, (batch_size, block_num) in enumerate(
|
||||
reversed(self.bucketing_ctx.decode_buckets)
|
||||
):
|
||||
log_master(
|
||||
logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
|
||||
)
|
||||
for index in range(warmup_times):
|
||||
self.warmup_decode(batch_size, block_num, batch)
|
||||
synchronize(self.device)
|
||||
@ -425,8 +432,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
)
|
||||
else:
|
||||
padded_bs = input_lengths.shape[0]
|
||||
orig_bs = input_lengths.shape[0]
|
||||
if padded_bs != input_lengths.shape[0]:
|
||||
orig_bs = input_lengths.shape[0]
|
||||
padded_input_lengths = F.pad(
|
||||
input_lengths,
|
||||
(0, padded_bs - orig_bs),
|
||||
@ -455,7 +462,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
|
||||
)
|
||||
slots = F.pad(
|
||||
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
|
||||
slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
lm_head_indices = F.pad(
|
||||
@ -484,4 +491,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
|
||||
)
|
||||
if batch.pixel_values is not None:
|
||||
batch.pixel_values = None
|
||||
return logits, speculative_logits
|
||||
return logits[:orig_bs], (
|
||||
speculative_logits[:orig_bs] if speculative_logits is not None else None
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user