Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-01 15:02:09 +00:00)
refine log and fix some issue

This change routes the per-bucket warmup messages through log_master, renames the SKIP_WARMUP_GRAPH environment variable to VLLM_SKIP_WARMUP, passes the live batch into warmup_prefill/warmup_decode so the dummy tensors use its dtypes, pads slots with 0 instead of -1, and trims the returned logits (and speculative logits) back to the original batch size after padding.
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in: parent a84da5b698, commit 8591687561
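For reference, the renamed warmup-skip switch is read as a plain environment variable. A minimal sketch of that gate, assuming only what the diff below shows (the helper name should_skip_warmup is illustrative, not part of this commit):

import os

def should_skip_warmup() -> bool:
    # VLLM_SKIP_WARMUP=true disables the HPU graph warmup; any other value keeps it on.
    return os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"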
@@ -1508,7 +1508,7 @@ class FlashCausalLM(Model):
 num_blocks * BLOCK_SIZE,
 )
 self.bucketing_ctx.num_hpu_blocks = num_blocks
-if os.getenv("SKIP_WARMUP_GRAPH", "false").lower() == "true":
+if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
 logger.info("skip warmup hpu graph, not recommmended")
 del _batch, batch
 return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
@@ -1524,23 +1524,26 @@ class FlashCausalLM(Model):
 for i, (batch_size, seq_len) in enumerate(
 reversed(self.bucketing_ctx.prompt_buckets)
 ):
+log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
 for index in range(warmup_times):
-self.warmup_prefill(seq_len, batch_size)
+self.warmup_prefill(seq_len, batch_size, batch)
 self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
 for i, (batch_size, block_num) in enumerate(
 reversed(self.bucketing_ctx.decode_buckets)
 ):
+log_master(
+logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+)
 for index in range(warmup_times):
-self.warmup_decode(batch_size, block_num)
+self.warmup_decode(batch_size, block_num, batch)
 synchronize(self.device)

-def warmup_prefill(self, prompt_len: int, bs: int):
-logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
+def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashCausalLMBatch):
 input_ids = torch.zeros(
-prompt_len, dtype=torch.int64, device=self.device
+prompt_len, dtype=batch.input_ids.dtype, device=self.device
 ).repeat(bs)
 position_ids = torch.arange(
-prompt_len, dtype=torch.int32, device=self.device
+prompt_len, dtype=batch.position_ids.dtype, device=self.device
 ).repeat(bs)
 max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
 block_tables = torch.arange(
@@ -1552,7 +1555,7 @@ class FlashCausalLM(Model):
 for b in block_tables[i]:
 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
 slot_acc.extend(slots[:prompt_len])
-slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
+slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)

 input_lengths = (
 torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
@@ -1581,10 +1584,13 @@ class FlashCausalLM(Model):
 hpu_attention_meta=None,
 )

-def warmup_decode(self, batch_size: int, block_num: int):
-logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
-input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
-position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
+def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch):
+input_ids = torch.zeros(
+batch_size, dtype=batch.input_ids.dtype, device=self.device
+)
+position_ids = torch.arange(
+batch_size, dtype=batch.position_ids.dtype, device=self.device
+)
 blocks = [block_num // batch_size for _ in range(batch_size)]
 blocks[0] += block_num % batch_size
 past_len = []
@@ -1599,7 +1605,7 @@ class FlashCausalLM(Model):
 block_tables.append(block_array)
 past_len.append(blocks[i] * BLOCK_SIZE - 1)
 start_idx += blocks[i]
-slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
 input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
 cache_lengths_tensor = torch.tensor(
 past_len, dtype=torch.int32, device=self.device
@@ -1725,7 +1731,6 @@ class FlashCausalLM(Model):
 padded_bs = input_lengths.shape[0]
 orig_bs = input_lengths.shape[0]
 if padded_bs != input_lengths.shape[0]:
-orig_bs = input_lengths.shape[0]
 padded_input_lengths = F.pad(
 input_lengths,
 (0, padded_bs - orig_bs),
@@ -1754,7 +1759,7 @@ class FlashCausalLM(Model):
 position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
 )
 slots = F.pad(
-slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
+slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
 )
 if lm_head_indices is not None:
 lm_head_indices = F.pad(
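The recurring change above is that warmup_prefill and warmup_decode now receive the live batch and build their dummy tensors with its dtypes rather than hard-coded torch.int64/torch.int32. A self-contained sketch of that pattern, with illustrative names rather than the exact TGI code:

import torch

def make_warmup_inputs(prompt_len: int, bs: int, batch, device):
    # Dummy prefill inputs sized to the bucket, but typed like the real batch,
    # so the warmed-up graphs match the shapes and dtypes inference will use.
    input_ids = torch.zeros(
        prompt_len, dtype=batch.input_ids.dtype, device=device
    ).repeat(bs)
    position_ids = torch.arange(
        prompt_len, dtype=batch.position_ids.dtype, device=device
    ).repeat(bs)
    return input_ids, position_ids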
@@ -383,9 +383,12 @@ class FlashVlmCausalLM(FlashCausalLM):
 def warmup_decode(
 self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
 ):
-logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
-input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
-position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
+input_ids = torch.zeros(
+batch_size, dtype=batch.input_ids.dtype, device=self.device
+)
+position_ids = torch.arange(
+batch_size, dtype=batch.position_ids.dtype, device=self.device
+)
 if batch.position_ids is not None and batch.position_ids.dim() == 2:
 # qwen2_vl and qwen2_5_vl case
 position_ids = position_ids.unsqueeze(-1).repeat(
@@ -405,7 +408,7 @@ class FlashVlmCausalLM(FlashCausalLM):
 block_tables.append(block_array)
 past_len.append(blocks[i] * BLOCK_SIZE - 1)
 start_idx += blocks[i]
-slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
 input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
 cache_lengths_tensor = torch.tensor(
 past_len, dtype=torch.int32, device=self.device
@@ -438,9 +441,12 @@ class FlashVlmCausalLM(FlashCausalLM):
 kv_cache=self.kv_cache,
 slots=slots,
 seqlen=trim_seqlen_metadata(seqlen),
-lm_head_indices=None,
-adapter_data=None,
 hpu_attention_meta=hpu_attention_meta,
+lm_head_indices=None,
+pixel_values=None,
+pixel_attention_mask=None,
+image_sizes=None,
+image_grid_thw=None,
 )

 def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
@@ -450,6 +456,9 @@ class FlashVlmCausalLM(FlashCausalLM):
 for i, (batch_size, block_num) in enumerate(
 reversed(self.bucketing_ctx.decode_buckets)
 ):
+log_master(
+logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+)
 for index in range(warmup_times):
 self.warmup_decode(batch_size, block_num, batch)
 synchronize(self.device)
@@ -546,8 +555,8 @@ class FlashVlmCausalLM(FlashCausalLM):
 )
 else:
 padded_bs = input_lengths.shape[0]
+orig_bs = input_lengths.shape[0]
 if padded_bs != input_lengths.shape[0]:
-orig_bs = input_lengths.shape[0]
 padded_input_lengths = F.pad(
 input_lengths,
 (0, padded_bs - orig_bs),
@@ -586,7 +595,7 @@ class FlashVlmCausalLM(FlashCausalLM):
 value=1,
 )
 slots = F.pad(
-slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
+slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
 )
 if lm_head_indices is not None:
 lm_head_indices = F.pad(
@@ -621,4 +630,6 @@ class FlashVlmCausalLM(FlashCausalLM):
 batch.image_sizes = None
 if batch.image_grid_thw is not None:
 batch.image_grid_thw = None
-return logits, speculative_logits
+return logits[:orig_bs], (
+speculative_logits[:orig_bs] if speculative_logits is not None else None
+)
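The other fix visible here is that, once the batch has been padded up to a bucket size, only the rows belonging to real requests are returned (logits[:orig_bs]). A minimal sketch of that pad-then-trim pattern, assuming a 1-D input_ids batch and a generic model_fn rather than the actual forward pass:

import torch
import torch.nn.functional as F

def run_padded(model_fn, input_ids: torch.Tensor, padded_bs: int) -> torch.Tensor:
    orig_bs = input_ids.shape[0]
    if padded_bs != orig_bs:
        # Pad the batch dimension up to the bucket size so the graph shape matches.
        input_ids = F.pad(input_ids, (0, padded_bs - orig_bs), value=0)
    logits = model_fn(input_ids)
    # Trim the padding rows off again before handing results back to callers.
    return logits[:orig_bs]

Tracking orig_bs before the padding branch is what lets the return statement slice correctly even when no padding was applied.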
@@ -27,6 +27,7 @@ from text_generation_server.utils.import_utils import (
 synchronize,
 )
 import torch.nn.functional as F
+from text_generation_server.utils.log import log_master

 tracer = trace.get_tracer(__name__)

@@ -208,9 +209,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 def warmup_decode(
 self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
 ):
-logger.info(f"warmup decode bs {batch_size} block_num {block_num}")
-input_ids = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
-position_ids = torch.arange(batch_size, dtype=torch.int32, device=self.device)
+input_ids = torch.zeros(
+batch_size, dtype=batch.input_ids.dtype, device=self.device
+)
+position_ids = torch.arange(
+batch_size, dtype=batch.position_ids.dtype, device=self.device
+)
 blocks = [block_num // batch_size for _ in range(batch_size)]
 blocks[0] += block_num % batch_size
 past_len = []
@@ -225,7 +229,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 block_tables.append(block_array)
 past_len.append(blocks[i] * BLOCK_SIZE - 1)
 start_idx += blocks[i]
-slots = torch.tensor(slots, dtype=torch.int64, device=self.device)
+slots = torch.tensor(slots, dtype=batch.slots.dtype, device=self.device)
 input_lengths = torch.ones(batch_size, dtype=torch.int32, device=self.device)
 cache_lengths_tensor = torch.tensor(
 past_len, dtype=torch.int32, device=self.device
@@ -266,12 +270,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 )

 def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch):
-logger.info(f"warmup prefill seq {prompt_len} bs {bs}")
 input_ids = torch.zeros(
-prompt_len, dtype=torch.int64, device=self.device
+prompt_len, dtype=batch.input_ids.dtype, device=self.device
 ).repeat(bs)
 position_ids = torch.arange(
-prompt_len, dtype=torch.int32, device=self.device
+prompt_len, dtype=batch.position_ids.dtype, device=self.device
 ).repeat(bs)
 max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
 block_tables = torch.arange(
@@ -283,7 +286,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 for b in block_tables[i]:
 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
 slot_acc.extend(slots[:prompt_len])
-slots = torch.tensor(slot_acc, dtype=torch.int64, device=self.device)
+slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)

 input_lengths = (
 torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
@@ -320,12 +323,16 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 for i, (batch_size, seq_len) in enumerate(
 reversed(self.bucketing_ctx.prompt_buckets)
 ):
+log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
 for index in range(warmup_times):
 self.warmup_prefill(seq_len, batch_size, batch)
 self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
 for i, (batch_size, block_num) in enumerate(
 reversed(self.bucketing_ctx.decode_buckets)
 ):
+log_master(
+logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+)
 for index in range(warmup_times):
 self.warmup_decode(batch_size, block_num, batch)
 synchronize(self.device)
@@ -425,8 +432,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 )
 else:
 padded_bs = input_lengths.shape[0]
+orig_bs = input_lengths.shape[0]
 if padded_bs != input_lengths.shape[0]:
-orig_bs = input_lengths.shape[0]
 padded_input_lengths = F.pad(
 input_lengths,
 (0, padded_bs - orig_bs),
@@ -455,7 +462,7 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
 )
 slots = F.pad(
-slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=-1
+slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
 )
 if lm_head_indices is not None:
 lm_head_indices = F.pad(
@@ -484,4 +491,6 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
 )
 if batch.pixel_values is not None:
 batch.pixel_values = None
-return logits, speculative_logits
+return logits[:orig_bs], (
+speculative_logits[:orig_bs] if speculative_logits is not None else None
+)
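The per-bucket warmup messages are routed through log_master, imported from text_generation_server.utils.log in the hunk above, so that only one process prints them when the model is sharded. A rough stand-in for what such a wrapper does (my own sketch under the assumption that rank is carried in a RANK environment variable; not the library's actual implementation):

import os
from loguru import logger

def log_master(log, msg: str):
    # Emit the message only from the rank-0 worker to keep multi-shard logs readable.
    if int(os.getenv("RANK", "0")) == 0:
        log(msg)

# e.g. inside the warmup loop:
# log_master(logger.info, f"warmup decode bs {batch_size} block_num {block_num}")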