remove unnecessary input_id pad

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-06-02 23:47:23 -07:00
parent 151d6638d3
commit 79ee5135e3


@@ -699,7 +699,9 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+    def concatenate(
+        cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0
+    ) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -748,7 +750,10 @@ class FlashCausalLMBatch(Batch):
             adapter_meta = None
             adapter_segment_builder = None
         else:
-            input_ids = batches[0].input_ids.new_empty(total_batch_size)
+            if padded_total_bs == batches[0].input_ids.shape[0]:
+                input_ids = batches[0].input_ids
+            else:
+                input_ids = batches[0].input_ids.new_empty(total_batch_size)
             if (
                 batches[0].position_ids is not None
                 and batches[0].position_ids.dim() == 2
@@ -827,7 +832,9 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             if i > 0:
                 all_input_ids_tensor.index_copy_(
-                    0, index.to("hpu"), batch.all_input_ids_tensor[:valid_bsize, :]
+                    0,
+                    index.to(batch.all_input_ids_tensor.device),
+                    batch.all_input_ids_tensor[:valid_bsize, :],
                 )
             block_tables_tensor[
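
The hunk above replaces the hard-coded "hpu" target with the destination tensor's own device. A minimal sketch of that index_copy_ pattern, using plain CPU tensors so it runs anywhere; the variable names are illustrative, not code from this repository:

    import torch

    # Destination buffer padded to the bucketed batch size, and one batch's rows to merge in.
    dest = torch.zeros(8, dtype=torch.long)
    src = torch.tensor([11, 12, 13], dtype=torch.long)
    index = torch.arange(3, 6)  # where this batch's rows land in dest

    # index_copy_ expects the index on the destination's device, hence index.to(dest.device)
    # rather than a hard-coded device string.
    dest.index_copy_(0, index.to(dest.device), src)
    print(dest)  # tensor([ 0,  0,  0, 11, 12, 13,  0,  0])
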
@@ -848,9 +855,10 @@ class FlashCausalLMBatch(Batch):
             )
             if not prefilling:
-                input_ids.index_copy_(
-                    0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
-                )
+                if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:
+                    input_ids.index_copy_(
+                        0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
+                    )
                 position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
                 slot_indices.index_copy_(
                     0, index, batch.slot_indices + cumulative_slots
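
Taken together with the new padded_total_bs parameter, the hunks above avoid an unnecessary allocate-and-copy: when the first batch's input_ids is already padded to the bucketed total size, that buffer is reused as-is and batch 0's rows are not copied back into it (copying a tensor into itself would be a no-op on the device, which is what the i > 0 / size-mismatch guard skips). A standalone sketch of that decision; pick_input_ids_buffer is an illustrative helper, not part of the repository:

    import torch


    def pick_input_ids_buffer(
        first_input_ids: torch.Tensor, padded_total_bs: int, total_batch_size: int
    ) -> torch.Tensor:
        # Reuse batch 0's already padded buffer when it matches the bucketed size;
        # otherwise fall back to allocating a fresh buffer of the real total size.
        if padded_total_bs == first_input_ids.shape[0]:
            return first_input_ids
        return first_input_ids.new_empty(total_batch_size)


    ids = torch.zeros(8, dtype=torch.long)  # batch 0 padded to bucket size 8
    assert pick_input_ids_buffer(ids, padded_total_bs=8, total_batch_size=6) is ids
    assert pick_input_ids_buffer(ids, padded_total_bs=16, total_batch_size=10).shape[0] == 10
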
@@ -2042,6 +2050,7 @@ class FlashCausalLM(Model):
                 accepted_ids,
             )
             if batch.valid_indices is not None:
+                # TODO speculative decoding handling missing
                 next_token_logprobs = next_token_logprobs.cpu()
                 accepted_ids = accepted_ids.cpu()
                 index = torch.arange(
@@ -2052,7 +2061,13 @@ class FlashCausalLM(Model):
                 batch.all_input_ids_tensor.index_copy_(
                     0, index, batch.all_input_ids_tensor[batch.valid_indices]
                 )
-                next_input_ids = next_input_ids[batch.valid_indices]
+                padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    len(batch.valid_indices)
+                )
+                next_input_ids.index_copy_(
+                    0, index, next_input_ids[batch.valid_indices]
+                )
+                next_input_ids = next_input_ids[:padded_total_bs]
                 next_token_logprobs = next_token_logprobs[batch.valid_indices]
                 accepted_ids = accepted_ids[batch.valid_indices]
                 if speculative_ids is not None:
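
Instead of shrinking next_input_ids to the exact number of surviving requests, the hunk above compacts the surviving rows to the front in place and then slices to the bucketed decode batch size, so the tensor keeps a padded, bucket-friendly length. A small sketch with made-up values (the bucket size of 4 is an assumption standing in for get_padded_decode_batch_size):

    import torch

    next_input_ids = torch.tensor([5, 6, 7, 8, 0, 0, 0, 0])  # padded to bucket size 8
    valid_indices = torch.tensor([0, 2, 3])                   # requests still running
    padded_total_bs = 4                                       # assumed bucket for 3 requests

    # Move the surviving rows to the front, then keep the padded length.
    index = torch.arange(len(valid_indices))
    next_input_ids.index_copy_(0, index, next_input_ids[valid_indices])
    next_input_ids = next_input_ids[:padded_total_bs]
    print(next_input_ids)  # tensor([5, 7, 8, 8]) — first 3 entries valid, rest padding
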
@@ -2122,10 +2137,13 @@ class FlashCausalLM(Model):
                 batch.slot_indices += accepted_ids[: len(batch)]
             else:
                 index = batch.cache_lengths_tensor + batch.input_lengths_tensor
+                index = F.pad(
+                    index, (0, next_input_ids.shape[0] - index.shape[0]), value=0
+                )
                 index = index.to(batch.all_input_ids_tensor.device)
                 batch_idx = torch.arange(
                     0,
-                    batch.all_input_ids_tensor.shape[0],
+                    index.shape[0],
                     dtype=torch.long,
                     device=batch.all_input_ids_tensor.device,
                 )
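
Since next_input_ids now keeps the bucketed length, the per-request index (one position per live request) is right-padded with zeros before the two are used together, which is what the added F.pad call does. A minimal illustration:

    import torch
    import torch.nn.functional as F

    next_input_ids = torch.tensor([5, 7, 8, 8])  # bucketed length 4
    index = torch.tensor([10, 12, 9])            # one write position per live request

    # Right-pad index with zeros so it matches next_input_ids' padded length.
    index = F.pad(index, (0, next_input_ids.shape[0] - index.shape[0]), value=0)
    print(index)  # tensor([10, 12,  9,  0])
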
@@ -2213,7 +2231,18 @@ class FlashCausalLM(Model):
             htorch.core.mark_step()
         # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
-            batch = self.batch_type.concatenate(batches)
+            if self.bucketing_ctx is not None:
+                total_batch_size = 0
+                for b in batches:
+                    total_batch_size += len(b)
+                padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    total_batch_size
+                )
+                batch = self.batch_type.concatenate(
+                    batches, padded_total_bs=padded_total_bs
+                )
+            else:
+                batch = self.batch_type.concatenate(batches)
         else:
             batch = batches[0]
         prefill = batch.prefilling
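
On the caller side, the padded decode batch size is derived from the combined request count before concatenation. A standalone sketch of that flow; TinyBucketingCtx is a stand-in with made-up power-of-two buckets that only mimics the get_padded_decode_batch_size method the diff relies on:

    class TinyBucketingCtx:
        # Stand-in for the HPU bucketing context; real bucket boundaries differ.
        def get_padded_decode_batch_size(self, batch_size: int) -> int:
            bucket = 1
            while bucket < batch_size:
                bucket *= 2
            return bucket


    def padded_total_batch_size(bucketing_ctx, batches) -> int:
        # Sum the live requests across all batches, then round up to the next bucket.
        total_batch_size = sum(len(b) for b in batches)
        return bucketing_ctx.get_padded_decode_batch_size(total_batch_size)


    ctx = TinyBucketingCtx()
    print(padded_total_batch_size(ctx, [[0] * 3, [0] * 2]))  # 5 requests -> bucket 8
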