Allocate top_n_token tensor in Batch

Repository: https://github.com/huggingface/text-generation-inference.git
Commit: 1b5fdf7000
Parent: 65c0d9c19d
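Editorial summary (not part of the commit message): the diff below makes each batch dataclass (CausalLMBatch, FlashCausalLMBatch, Seq2SeqLMBatch) carry a preallocated top_n_tokens_tensor next to the existing top_n_tokens list. The tensor is built once in from_pb, sliced in filter, merged in concatenate, and passed to batch_top_tokens on every decoding step, so the helper no longer rebuilds it with torch.tensor(...) per step. A minimal, hypothetical sketch of the idea (simplified names, not the actual TGI classes):

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class ToyBatch:
    # Per-request "return the top-n alternative tokens" counts.
    top_n_tokens: List[int]
    # Same values kept as a device tensor, allocated once and reused on
    # every decode step instead of calling torch.tensor(...) again.
    top_n_tokens_tensor: torch.Tensor

    @classmethod
    def from_requests(cls, top_n_tokens: List[int], device: torch.device) -> "ToyBatch":
        tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
        return cls(top_n_tokens=top_n_tokens, top_n_tokens_tensor=tensor)

    def filter(self, keep_indices: List[int]) -> "ToyBatch":
        # Slicing the existing tensor avoids a fresh host-to-device copy.
        return ToyBatch(
            top_n_tokens=[self.top_n_tokens[i] for i in keep_indices],
            top_n_tokens_tensor=self.top_n_tokens_tensor[keep_indices],
        )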
@@ -44,6 +44,7 @@ class CausalLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Metadata used for padding
     max_input_length: int
@@ -125,6 +126,7 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
@@ -143,6 +145,7 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
@@ -230,6 +233,7 @@ class CausalLMBatch(Batch):
             layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
             del past_values
 
+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
 
         self.requests = requests
@@ -243,6 +247,7 @@ class CausalLMBatch(Batch):
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
@@ -278,6 +283,7 @@ class CausalLMBatch(Batch):
         attention_mask = None
         position_ids = None
         past_key_values = []
+        top_n_tokens_tensor = None
 
         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes
@@ -320,6 +326,12 @@ class CausalLMBatch(Batch):
                     (total_batch_size, max_input_length + padding_right_offset),
                 )
 
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
             # We need to slice the attention mask to remove padding from previous steps
             # and to remove unused allocated space
             left_offset = max_input_length - batch.max_input_length
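The concatenate hunk above follows a lazy allocate-then-copy pattern: the merged tensor is created from the first batch's tensor via new_zeros (so it inherits dtype and device) and each batch's values are written into its slice. A standalone, hypothetical sketch of just that pattern (helper name and example values are illustrative):

from typing import List, Optional

import torch


def merge_1d(tensors: List[torch.Tensor]) -> torch.Tensor:
    """Concatenate per-batch 1-D tensors into one preallocated buffer."""
    assert tensors, "expects at least one tensor"
    total = sum(t.shape[0] for t in tensors)
    merged: Optional[torch.Tensor] = None
    start = 0
    for t in tensors:
        end = start + t.shape[0]
        if merged is None:
            # new_zeros keeps the dtype and device of the source tensor.
            merged = t.new_zeros(total)
        merged[start:end] = t
        start = end
    return merged


# Example: merging the top-n counts of two batches.
a = torch.tensor([0, 5], dtype=torch.int64)
b = torch.tensor([3], dtype=torch.int64)
print(merge_1d([a, b]))  # tensor([0, 5, 3])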
@@ -449,6 +461,7 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -561,7 +574,7 @@ class CausalLM(Model):
         stopped = True
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
         )
 
         # Zipped iterator
@@ -594,7 +607,7 @@ class CausalLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = self.decode_top_tokens(
-                input_ids=all_input_ids.view(1, -1).tolist(),
+                input_ids=all_input_ids.view(-1).tolist(),
                 top_n_tokens=top_n_tokens,
                 top_token_ids=top_token_ids,
                 top_token_logprobs=top_token_logprobs,
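Besides wiring the tensor through, the hunk above also changes all_input_ids.view(1, -1).tolist() to all_input_ids.view(-1).tolist(), so decode_top_tokens receives a flat list of token ids rather than a singleton nested list. A tiny illustration of the difference:

import torch

# Hypothetical per-request ids stored as a column vector, as in all_input_ids.
all_input_ids = torch.tensor([[101], [2009], [2003]])

print(all_input_ids.view(1, -1).tolist())  # [[101, 2009, 2003]] -- nested list
print(all_input_ids.view(-1).tolist())     # [101, 2009, 2003]   -- flat list of ints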
@@ -168,6 +168,7 @@ class FlashCausalLMBatch(Batch):
     next_token_chooser: HeterogeneousNextTokenChooser
     stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Number of blocks in this batch
     blocks: int
@@ -357,6 +358,7 @@ class FlashCausalLMBatch(Batch):
         prefill_next_token_indices = torch.tensor(
             prefill_next_token_indices, dtype=torch.int64, device=device
         )
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
         return cls(
             batch_id=pb.id,
@@ -384,6 +386,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -496,6 +499,7 @@ class FlashCausalLMBatch(Batch):
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
+        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
@@ -528,6 +532,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -576,6 +581,9 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
+        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+            total_batch_size,
+        )
 
         start_slots = []
         block_tables = []
@@ -613,6 +621,7 @@ class FlashCausalLMBatch(Batch):
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots
 
             all_input_ids_tensor[
@@ -680,6 +689,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -850,7 +860,7 @@ class FlashCausalLM(Model):
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, logprobs
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
 
         if prefill:
@@ -108,14 +108,15 @@ class Model(ABC):
             new_sequences, skip_special_tokens=False
         )
 
+        prefix_len = len(prefix_text)
         results = []
         for new_text in new_texts:
-            if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+            if len(new_text) > prefix_len and not new_text.endswith("�"):
                 # utf-8 char at the end means it's a potential unfinished byte sequence
                 # from byte fallback tokenization.
                 # If it's in the middle, it's probably a real invalid id generated
                 # by the model
-                new_text = new_text[len(prefix_text) :]
+                new_text = new_text[prefix_len:]
                 results.append((new_text, read_offset, len(input_ids) + 1))
             else:
                 results.append(("", prefix_offset, read_offset))
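For context, the hunk above is part of the usual incremental-detokenization scheme: decode the window of ids seen so far (prefix_text) and the window extended by the new ids, emit only the suffix, and hold back output ending in the replacement character "�" because it may be an unfinished multi-byte sequence; hoisting len(prefix_text) into prefix_len simply avoids recomputing it per candidate. A simplified, hypothetical single-sequence sketch of that scheme (the tokenizer argument and offset convention are assumptions, not the exact TGI method):

from typing import List, Tuple


def decode_incremental(
    tokenizer, all_input_ids: List[int], prefix_offset: int, read_offset: int
) -> Tuple[str, int, int]:
    """Return newly generated text plus updated (prefix_offset, read_offset)."""
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_text = tokenizer.decode(
        all_input_ids[prefix_offset:], skip_special_tokens=False
    )
    prefix_len = len(prefix_text)
    if len(new_text) > prefix_len and not new_text.endswith("�"):
        # A trailing "�" usually means an incomplete UTF-8 byte sequence from
        # byte-fallback tokenization; wait for more tokens before emitting it.
        return new_text[prefix_len:], read_offset, len(all_input_ids)
    return "", prefix_offset, read_offset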
@@ -50,6 +50,7 @@ class Seq2SeqLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Metadata used for padding
     max_input_length: int
@@ -129,6 +130,7 @@ class Seq2SeqLMBatch(Batch):
             prefix_offsets.append(0)
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
@@ -150,6 +152,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
@@ -245,6 +248,7 @@ class Seq2SeqLMBatch(Batch):
             layer[2] = layer[2][keep_indices, :, -max_input_length:]
             layer[3] = layer[3][keep_indices, :, -max_input_length:]
 
+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = (
             len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
@@ -261,6 +265,7 @@ class Seq2SeqLMBatch(Batch):
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
@@ -304,6 +309,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = None
         decoder_attention_mask = None
         encoder_last_hidden_state = None
+        top_n_tokens_tensor = None
         past_key_values = []
 
         # Used for slicing correctly inside the tensors
@@ -393,6 +399,12 @@ class Seq2SeqLMBatch(Batch):
                     ),
                 )
 
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
             # Copy to correct indices
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_input_length :, :
@@ -498,6 +510,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
@@ -624,7 +637,7 @@ class Seq2SeqLM(Model):
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
        )
 
         # Finished requests
@@ -663,7 +676,7 @@ class Seq2SeqLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = self.decode_top_tokens(
-                input_ids=all_decoder_input_ids.view(1, -1).tolist(),
+                input_ids=all_decoder_input_ids.view(-1).tolist(),
                 top_n_tokens=top_n_tokens,
                 top_token_ids=top_token_ids,
                 top_token_logprobs=top_token_logprobs,
@@ -337,20 +337,16 @@ class HeterogeneousSampling:
 
 
 def batch_top_tokens(
-    top_n_tokens: list[int], logprobs: torch.Tensor
+    top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
 ) -> Tuple[List[List[int]], List[List[float]]]:
     """Find the top n most likely tokens for a batch of generations.
 
     When multiple tokens have equal probabilities and they don't all fit, the
     remaining tokens are also returned.
     """
-    # Do this as early as possible to mitigate copy latency
-    top_n_tensor = torch.tensor(top_n_tokens).to(
-        device=logprobs.device, non_blocking=True
-    )
-
+    max_top_n = max(top_n_tokens)
     # Early exit when top_n_tokens is not used
-    if max(top_n_tokens) == 0:
+    if max_top_n == 0:
         return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
 
     # Ensure top_n doesn't exceed vocab size
@@ -358,11 +354,9 @@ def batch_top_tokens(
 
     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
    # Sorted topk is faster than torch.sort() since we only need a small subset
-    sorted_top_k = torch.topk(
-        logprobs, k=max(top_n_tokens), dim=1, sorted=True
-    ).values  # .cpu()
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
     nth_highest = torch.gather(
-        sorted_top_k, 1, (top_n_tensor - 1).clip(min=0).unsqueeze(1)
+        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
     )
     nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
 
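Combining the two hunks above, the helper now receives the preallocated tensor from the batch instead of building one per step; call sites pass it as, e.g., batch_top_tokens(batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)). Below is a self-contained approximation of the resulting function, simplified for illustration (it omits details the real code handles in surrounding context, such as clamping n to the vocabulary size):

from typing import List, Tuple

import torch


def batch_top_tokens_sketch(
    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
) -> Tuple[List[List[int]], List[List[float]]]:
    """Simplified sketch: per row, return every token scoring >= the n-th highest value."""
    max_top_n = max(top_n_tokens)
    if max_top_n == 0:
        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)

    # One batched topk, then gather each row's n-th highest value as a threshold.
    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
    nth_highest = torch.gather(
        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
    )

    top_ids: List[List[int]] = []
    top_vals: List[List[float]] = []
    for row, n, threshold in zip(logprobs, top_n_tokens, nth_highest):
        if n == 0:
            top_ids.append([])
            top_vals.append([])
            continue
        # Keep ties: everything at or above the threshold is returned.
        ids = (row >= threshold).nonzero(as_tuple=True)[0]
        top_ids.append(ids.tolist())
        top_vals.append(row[ids].tolist())
    return top_ids, top_vals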