mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
remove unnecessary input_id pad
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent 151d6638d3
commit 79ee5135e3
@@ -699,7 +699,9 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+    def concatenate(
+        cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0
+    ) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -748,7 +750,10 @@ class FlashCausalLMBatch(Batch):
             adapter_meta = None
             adapter_segment_builder = None
         else:
-            input_ids = batches[0].input_ids.new_empty(total_batch_size)
+            if padded_total_bs == batches[0].input_ids.shape[0]:
+                input_ids = batches[0].input_ids
+            else:
+                input_ids = batches[0].input_ids.new_empty(total_batch_size)
             if (
                 batches[0].position_ids is not None
                 and batches[0].position_ids.dim() == 2
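A minimal, self-contained sketch of the allocation change above (shapes and values are made up): when the first incoming batch is already padded to the bucketed decode batch size, its input_ids tensor can serve directly as the merged buffer, so no fresh allocation and, as a later hunk shows, no re-copy of batch 0 is needed.

import torch

padded_total_bs = 8                                  # bucketed decode batch size
first_input_ids = torch.zeros(8, dtype=torch.long)   # batch 0, already padded to the bucket
total_batch_size = 5                                 # live requests across all batches

if padded_total_bs == first_input_ids.shape[0]:
    input_ids = first_input_ids                      # reuse the padded buffer as-is
else:
    # fall back to the previous behaviour: allocate a buffer for the merged batch
    input_ids = first_input_ids.new_empty(total_batch_size)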
@@ -827,7 +832,9 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             if i > 0:
                 all_input_ids_tensor.index_copy_(
-                    0, index.to("hpu"), batch.all_input_ids_tensor[:valid_bsize, :]
+                    0,
+                    index.to(batch.all_input_ids_tensor.device),
+                    batch.all_input_ids_tensor[:valid_bsize, :],
                 )
 
             block_tables_tensor[
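The hunk above also drops the hard-coded "hpu" target in favour of the destination tensor's own device. A minimal sketch of the same device-agnostic index_copy_ pattern, with made-up shapes:

import torch

all_input_ids_tensor = torch.zeros(4, 6, dtype=torch.long)     # merged destination rows
batch_rows = torch.arange(12, dtype=torch.long).reshape(2, 6)  # rows from a later batch
index = torch.tensor([2, 3])                                   # destination row indices

all_input_ids_tensor.index_copy_(
    0,
    index.to(all_input_ids_tensor.device),  # follow the destination's device, not a fixed "hpu"
    batch_rows,
)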
@@ -848,9 +855,10 @@ class FlashCausalLMBatch(Batch):
             )
 
             if not prefilling:
-                input_ids.index_copy_(
-                    0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
-                )
+                if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:
+                    input_ids.index_copy_(
+                        0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
+                    )
                 position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
                 slot_indices.index_copy_(
                     0, index, batch.slot_indices + cumulative_slots
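A minimal sketch of the guard added above: when batch 0's padded tensor was reused as the merged input_ids buffer (see the earlier hunk), its rows are already in place, so index_copy_ is only needed for the remaining batches (i > 0). Values below are made up:

import torch

padded_total_bs = 4
batches_input_ids = [torch.tensor([11, 12, 0, 0]), torch.tensor([21, 0])]
valid_sizes = [2, 1]                                      # real requests per batch
dest_indices = [torch.tensor([0, 1]), torch.tensor([2])]

input_ids = batches_input_ids[0]                          # reused buffer, already bucket-sized
for i, (ids, valid_bsize, index) in enumerate(
    zip(batches_input_ids, valid_sizes, dest_indices)
):
    if padded_total_bs != batches_input_ids[0].shape[0] or i > 0:
        input_ids.index_copy_(0, index, ids[:valid_bsize])

print(input_ids)  # tensor([11, 12, 21,  0]); batch 0 is never copied onto itself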
@@ -2042,6 +2050,7 @@ class FlashCausalLM(Model):
             accepted_ids,
         )
         if batch.valid_indices is not None:
+            # TODO speculative decoding handling missing
             next_token_logprobs = next_token_logprobs.cpu()
             accepted_ids = accepted_ids.cpu()
             index = torch.arange(
@@ -2052,7 +2061,13 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor.index_copy_(
                 0, index, batch.all_input_ids_tensor[batch.valid_indices]
             )
-            next_input_ids = next_input_ids[batch.valid_indices]
+            padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                len(batch.valid_indices)
+            )
+            next_input_ids.index_copy_(
+                0, index, next_input_ids[batch.valid_indices]
+            )
+            next_input_ids = next_input_ids[:padded_total_bs]
             next_token_logprobs = next_token_logprobs[batch.valid_indices]
             accepted_ids = accepted_ids[batch.valid_indices]
             if speculative_ids is not None:
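A minimal sketch of the new filtering path above: rather than shrinking next_input_ids to just the surviving rows (which would break the bucketed shape), the surviving rows are compacted to the front with index_copy_ and the tensor is then cut at the padded bucket size. The round-up helper below is a stand-in for bucketing_ctx.get_padded_decode_batch_size; the real bucket sizes come from the HPU bucketing context.

import torch


def padded_decode_batch_size(n: int, bucket: int = 4) -> int:
    # Stand-in for bucketing_ctx.get_padded_decode_batch_size: round up to a bucket multiple.
    return ((n + bucket - 1) // bucket) * bucket


next_input_ids = torch.tensor([10, 11, 12, 13, 14, 15, 0, 0])  # padded decode output
valid_indices = torch.tensor([0, 2, 5])                        # requests still alive
index = torch.arange(len(valid_indices))                       # target slots 0..2

# Compact the surviving rows to the front, then keep only the padded bucket.
next_input_ids.index_copy_(0, index, next_input_ids[valid_indices])
padded_total_bs = padded_decode_batch_size(len(valid_indices))  # -> 4
next_input_ids = next_input_ids[:padded_total_bs]               # tensor([10, 12, 15, 13])
# the trailing slot is padding; its stale value is never read as a real token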
@@ -2122,10 +2137,13 @@ class FlashCausalLM(Model):
                 batch.slot_indices += accepted_ids[: len(batch)]
             else:
                 index = batch.cache_lengths_tensor + batch.input_lengths_tensor
+                index = F.pad(
+                    index, (0, next_input_ids.shape[0] - index.shape[0]), value=0
+                )
                 index = index.to(batch.all_input_ids_tensor.device)
                 batch_idx = torch.arange(
                     0,
-                    batch.all_input_ids_tensor.shape[0],
+                    index.shape[0],
                     dtype=torch.long,
                     device=batch.all_input_ids_tensor.device,
                 )
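A minimal sketch of the F.pad change above: because next_input_ids now stays at the padded bucket length, the per-request write-position index has to be padded to the same length (the extra slots simply point at column 0), and batch_idx is sized from index rather than from the full all_input_ids_tensor. The final scatter line is an assumption about how these indices are consumed downstream, shown only for illustration.

import torch
import torch.nn.functional as F

next_input_ids = torch.tensor([10, 12, 15, 13])   # kept at the padded bucket length (4)
cache_lengths_tensor = torch.tensor([3, 7, 1])    # one entry per live request (3)
input_lengths_tensor = torch.tensor([1, 1, 1])

index = cache_lengths_tensor + input_lengths_tensor            # write column per request
index = F.pad(index, (0, next_input_ids.shape[0] - index.shape[0]), value=0)
batch_idx = torch.arange(0, index.shape[0], dtype=torch.long)  # sized from index

# Assumed downstream use: scatter new tokens into the running token history;
# the padding entry harmlessly writes into column 0 of an otherwise unused row.
all_input_ids_tensor = torch.zeros(4, 16, dtype=torch.long)
all_input_ids_tensor[batch_idx, index] = next_input_ids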
@@ -2213,7 +2231,18 @@ class FlashCausalLM(Model):
         htorch.core.mark_step()
         # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
-            batch = self.batch_type.concatenate(batches)
+            if self.bucketing_ctx is not None:
+                total_batch_size = 0
+                for b in batches:
+                    total_batch_size += len(b)
+                padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    total_batch_size
+                )
+                batch = self.batch_type.concatenate(
+                    batches, padded_total_bs=padded_total_bs
+                )
+            else:
+                batch = self.batch_type.concatenate(batches)
         else:
             batch = batches[0]
         prefill = batch.prefilling
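The driving change is in the hunk above: when more than one decode batch is merged and a bucketing context is available, the total number of live requests is rounded up to the decode bucket first and handed to concatenate as padded_total_bs. A self-contained sketch of that computation, with a made-up power-of-two round-up standing in for bucketing_ctx.get_padded_decode_batch_size (the real bucket sizes come from the HPU bucketing context):

import math


def get_padded_decode_batch_size(total_batch_size: int) -> int:
    # Stand-in for bucketing_ctx.get_padded_decode_batch_size: round up to a power of two.
    return 1 << max(0, math.ceil(math.log2(max(1, total_batch_size))))


batch_sizes = [3, 2]                 # len(b) for each batch being merged
total_batch_size = sum(batch_sizes)  # 5 live requests
padded_total_bs = get_padded_decode_batch_size(total_batch_size)  # -> 8

# concatenate(batches, padded_total_bs=8) can then reuse tensors that are already at the
# bucketed shape, which appears to be the "unnecessary input_id pad" the commit title removes.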