Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-18 23:32:06 +00:00
remove unnecessary input_id pad
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 151d6638d3
commit 79ee5135e3
```diff
@@ -699,7 +699,9 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+    def concatenate(
+        cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0
+    ) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
```
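The new `padded_total_bs` keyword defaults to 0, so call sites that do not know the bucketed size keep working unchanged; 0 can never equal a real (positive) batch size, so those callers always take the allocate-and-copy path. A toy stand-in (not the real class) illustrating why the default preserves the old behaviour:

```python
from typing import List

# Hypothetical stand-in for the concatenate() signature change: a default
# of 0 never matches a real padded size, so omitting the argument always
# selects the old allocate-and-copy branch.
def concatenate(batches: List[str], padded_total_bs: int = 0) -> bool:
    # Returns whether the first batch's buffer could be reused as-is.
    return padded_total_bs == len(batches[0])

print(concatenate(["abcd", "ef"]))                     # False: old path
print(concatenate(["abcd", "ef"], padded_total_bs=4))  # True: reuse path
```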
```diff
@@ -748,7 +750,10 @@ class FlashCausalLMBatch(Batch):
             adapter_meta = None
             adapter_segment_builder = None
         else:
-            input_ids = batches[0].input_ids.new_empty(total_batch_size)
+            if padded_total_bs == batches[0].input_ids.shape[0]:
+                input_ids = batches[0].input_ids
+            else:
+                input_ids = batches[0].input_ids.new_empty(total_batch_size)
             if (
                 batches[0].position_ids is not None
                 and batches[0].position_ids.dim() == 2
```
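The point of the guard: `batches[0].input_ids` is already padded to an HPU bucket size, so when that size equals the padded total for the merged batch, the tensor can be reused as-is and both the `new_empty` allocation and the later copy for batch 0 are skipped. A minimal, self-contained sketch of the same decision (plain CPU tensors; the sizes are invented):

```python
import torch

def merged_input_ids(first_input_ids: torch.Tensor, total_batch_size: int,
                     padded_total_bs: int) -> torch.Tensor:
    # If the first batch's buffer already has the padded target size,
    # reuse it in place; its entries for batch 0 are already correct.
    if padded_total_bs == first_input_ids.shape[0]:
        return first_input_ids
    # Otherwise allocate a fresh buffer and let the caller copy each batch in.
    return first_input_ids.new_empty(total_batch_size)

first = torch.zeros(8, dtype=torch.long)  # already padded to bucket size 8
reused = merged_input_ids(first, total_batch_size=6, padded_total_bs=8)
assert reused is first                    # no allocation, no copy needed
```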
```diff
@@ -827,7 +832,9 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             if i > 0:
                 all_input_ids_tensor.index_copy_(
-                    0, index.to("hpu"), batch.all_input_ids_tensor[:valid_bsize, :]
+                    0,
+                    index.to(batch.all_input_ids_tensor.device),
+                    batch.all_input_ids_tensor[:valid_bsize, :],
                 )
 
             block_tables_tensor[
```
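Replacing the hard-coded `index.to("hpu")` with `index.to(batch.all_input_ids_tensor.device)` keeps the copy correct wherever the destination tensor actually lives. A small sketch of the pattern, with CPU tensors standing in for HPU ones:

```python
import torch

all_input_ids_tensor = torch.zeros(4, 10, dtype=torch.long)  # destination
batch_rows = torch.ones(2, 10, dtype=torch.long)             # rows to merge in
index = torch.tensor([2, 3])                                 # target row indices

# Moving the index to the destination's device (rather than a literal "hpu")
# lets the identical code run on CPU, CUDA, or HPU.
all_input_ids_tensor.index_copy_(
    0, index.to(all_input_ids_tensor.device), batch_rows
)
```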
```diff
@@ -848,9 +855,10 @@ class FlashCausalLMBatch(Batch):
             )
 
             if not prefilling:
-                input_ids.index_copy_(
-                    0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
-                )
+                if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:
+                    input_ids.index_copy_(
+                        0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
+                    )
                 position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
                 slot_indices.index_copy_(
                     0, index, batch.slot_indices + cumulative_slots
```
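This guard pairs with the reuse above: when `input_ids` *is* `batches[0].input_ids` (the sizes matched and `i == 0`), the `index_copy_` would copy the tensor onto itself, so it is skipped; every other combination still copies. A sketch of why the skip is safe (sizes invented):

```python
import torch

padded_total_bs = 8
first = torch.arange(8)   # batches[0].input_ids, padded to bucket size 8
input_ids = first         # reused buffer from the guard above

# i == 0 and sizes match: first's entries are already in place, so copying
# first[:valid_bsize] onto input_ids[:valid_bsize] would be a self-copy.
i, valid_bsize = 0, 6
if padded_total_bs != first.shape[0] or i > 0:
    index = torch.arange(valid_bsize)
    input_ids.index_copy_(0, index, first[:valid_bsize])  # not reached here
```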
```diff
@@ -2042,6 +2050,7 @@ class FlashCausalLM(Model):
             accepted_ids,
         )
         if batch.valid_indices is not None:
+            # TODO speculative decoding handling missing
             next_token_logprobs = next_token_logprobs.cpu()
             accepted_ids = accepted_ids.cpu()
             index = torch.arange(
```
```diff
@@ -2052,7 +2061,13 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor.index_copy_(
                 0, index, batch.all_input_ids_tensor[batch.valid_indices]
             )
-            next_input_ids = next_input_ids[batch.valid_indices]
+            padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                len(batch.valid_indices)
+            )
+            next_input_ids.index_copy_(
+                0, index, next_input_ids[batch.valid_indices]
+            )
+            next_input_ids = next_input_ids[:padded_total_bs]
             next_token_logprobs = next_token_logprobs[batch.valid_indices]
             accepted_ids = accepted_ids[batch.valid_indices]
             if speculative_ids is not None:
```
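After a filter, `valid_indices` selects the surviving requests. Instead of shrinking `next_input_ids` to exactly that count, the new code compacts the survivors to the front with `index_copy_` and then slices to the bucketed size, so the tensor keeps a shape the HPU graphs were compiled for. A standalone sketch of the compact-then-slice step (the bucket rounding here is a made-up stand-in for `self.bucketing_ctx.get_padded_decode_batch_size`):

```python
import torch

def pad_to_bucket(n: int) -> int:
    # Stand-in for bucketing_ctx.get_padded_decode_batch_size: round the
    # live batch size up to the next power-of-two bucket.
    bucket = 1
    while bucket < n:
        bucket *= 2
    return bucket

next_input_ids = torch.arange(8)         # one next token per padded slot
valid_indices = torch.tensor([0, 2, 5])  # requests that survived filtering

# Compact survivors to the front, then keep a bucket-sized prefix.
index = torch.arange(len(valid_indices))
next_input_ids.index_copy_(0, index, next_input_ids[valid_indices])
next_input_ids = next_input_ids[: pad_to_bucket(len(valid_indices))]
print(next_input_ids)  # tensor([0, 2, 5, 3]): 3 live tokens + 1 pad slot
```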
```diff
@@ -2122,10 +2137,13 @@ class FlashCausalLM(Model):
                 batch.slot_indices += accepted_ids[: len(batch)]
             else:
                 index = batch.cache_lengths_tensor + batch.input_lengths_tensor
+                index = F.pad(
+                    index, (0, next_input_ids.shape[0] - index.shape[0]), value=0
+                )
                 index = index.to(batch.all_input_ids_tensor.device)
                 batch_idx = torch.arange(
                     0,
-                    batch.all_input_ids_tensor.shape[0],
+                    index.shape[0],
                     dtype=torch.long,
                     device=batch.all_input_ids_tensor.device,
                 )
```
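`index` has one write position per live request, while `next_input_ids` is now bucket-padded, so the two shapes must agree before the scatter; `F.pad` right-pads the 1-D index with zeros up to the padded length. A small sketch (shapes invented):

```python
import torch
import torch.nn.functional as F

index = torch.tensor([5, 3, 7])  # write position per live request
padded_len = 4                   # next_input_ids.shape[0] in the diff

# (0, k) pads only the right side of the last dimension; the padded
# entries point at position 0, which the pad rows harmlessly overwrite.
index = F.pad(index, (0, padded_len - index.shape[0]), value=0)
print(index)  # tensor([5, 3, 7, 0])
```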
```diff
@@ -2213,7 +2231,18 @@ class FlashCausalLM(Model):
             htorch.core.mark_step()
         # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
-            batch = self.batch_type.concatenate(batches)
+            if self.bucketing_ctx is not None:
+                total_batch_size = 0
+                for b in batches:
+                    total_batch_size += len(b)
+                padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    total_batch_size
+                )
+                batch = self.batch_type.concatenate(
+                    batches, padded_total_bs=padded_total_bs
+                )
+            else:
+                batch = self.batch_type.concatenate(batches)
         else:
             batch = batches[0]
         prefill = batch.prefilling
```
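Putting the pieces together: when a bucketing context exists, the caller sums the live batch sizes, rounds up to a decode bucket, and hands that to `concatenate` so the merged `input_ids` can be reused when the sizes line up. A hedged end-to-end sketch of this call site; the `BucketingCtxStub` below only mimics the one method the diff relies on:

```python
# Hypothetical stub: only get_padded_decode_batch_size matters here, and its
# real bucketing policy lives in the HPU runtime, not in this sketch.
class BucketingCtxStub:
    def get_padded_decode_batch_size(self, n: int) -> int:
        bucket = 1
        while bucket < n:
            bucket *= 2
        return bucket

bucketing_ctx = BucketingCtxStub()
batches = [[0] * 3, [0] * 2]                     # stand-ins supporting len()
total_batch_size = sum(len(b) for b in batches)  # 5 live requests
padded_total_bs = bucketing_ctx.get_padded_decode_batch_size(total_batch_size)
assert padded_total_bs == 8                      # next bucket above 5
# batch = FlashCausalLMBatch.concatenate(batches, padded_total_bs=padded_total_bs)
```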