Allocate top_n_token tensor in Batch

Vincent Brouwers 2023-07-31 13:09:45 +00:00
parent 95d0fba7de
commit d16298b8d4
5 changed files with 49 additions and 18 deletions

View File

@@ -44,6 +44,7 @@ class CausalLMBatch(Batch):
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
@@ -125,6 +126,7 @@ class CausalLMBatch(Batch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@@ -143,6 +145,7 @@ class CausalLMBatch(Batch):
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
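The pattern repeated across the batch classes is visible in the two hunks above: when a batch is first constructed, the Python top_n_tokens list is also materialized once as an int64 tensor on the model's device and stored on the batch, so later decoding steps can reuse it instead of rebuilding it. A minimal sketch of that pattern, using a hypothetical MiniBatch rather than the real CausalLMBatch:

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class MiniBatch:
    top_n_tokens: List[int]            # kept for cheap Python-side checks, e.g. max()
    top_n_tokens_tensor: torch.Tensor  # reused on every decoding step

    @classmethod
    def from_requests(cls, top_n_tokens: List[int], device: str = "cpu") -> "MiniBatch":
        # Allocate the tensor once here instead of on every generation step.
        tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
        return cls(top_n_tokens=top_n_tokens, top_n_tokens_tensor=tensor)


batch = MiniBatch.from_requests([0, 3, 5])
print(batch.top_n_tokens_tensor)  # tensor([0, 3, 5])

Keeping both the list and the tensor means Python-side checks stay cheap while the GPU-side thresholding avoids a host-to-device copy per step.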
@@ -230,6 +233,7 @@ class CausalLMBatch(Batch):
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
@@ -243,6 +247,7 @@ class CausalLMBatch(Batch):
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
@@ -278,6 +283,7 @@ class CausalLMBatch(Batch):
attention_mask = None
position_ids = None
past_key_values = []
top_n_tokens_tensor = None
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
@@ -320,6 +326,12 @@ class CausalLMBatch(Batch):
(total_batch_size, max_input_length + padding_right_offset),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
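When batches are merged, the combined tensor is allocated lazily from the first batch with new_zeros, so it inherits that tensor's dtype and device, and each batch's values are then copied into its [start_index:end_index] slice. A stand-alone sketch of that copy pattern with made-up per-batch data instead of real batch objects:

import torch

batch_tensors = [torch.tensor([2, 0]), torch.tensor([5]), torch.tensor([1, 4, 0])]
total_batch_size = sum(t.shape[0] for t in batch_tensors)

top_n_tokens_tensor = None
start_index = 0
for t in batch_tensors:
    end_index = start_index + t.shape[0]
    if top_n_tokens_tensor is None:
        # new_zeros keeps the dtype and device of the first batch's tensor.
        top_n_tokens_tensor = batch_tensors[0].new_zeros(total_batch_size)
    top_n_tokens_tensor[start_index:end_index] = t
    start_index = end_index

print(top_n_tokens_tensor)  # tensor([2, 0, 5, 1, 4, 0])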
@@ -449,6 +461,7 @@ class CausalLMBatch(Batch):
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -561,7 +574,7 @@ class CausalLM(Model):
stopped = True
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
)
# Zipped iterator
@@ -594,7 +607,7 @@ class CausalLM(Model):
top_token_logprobs,
) in enumerate(iterator):
top_tokens = self.decode_top_tokens(
input_ids=all_input_ids.view(1, -1).tolist(),
input_ids=all_input_ids.view(-1).tolist(),
top_n_tokens=top_n_tokens,
top_token_ids=top_token_ids,
top_token_logprobs=top_token_logprobs,
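Besides passing the preallocated tensor into batch_top_tokens, the last hunk changes view(1, -1) to view(-1) for the ids handed to decode_top_tokens. In this model all_input_ids has shape (seq_len, 1), so the old call produced a nested list while the new one yields the flat id list the decode step presumably expects. A quick illustration with stand-in token ids:

import torch

all_input_ids = torch.tensor([[101], [2023], [2003]])  # shape (seq_len, 1)

print(all_input_ids.view(1, -1).tolist())  # [[101, 2023, 2003]] -- nested list
print(all_input_ids.view(-1).tolist())     # [101, 2023, 2003]   -- flat list of ids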

View File

@@ -168,6 +168,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Number of blocks in this batch
blocks: int
@@ -357,6 +358,7 @@ class FlashCausalLMBatch(Batch):
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
return cls(
batch_id=pb.id,
@@ -384,6 +386,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@@ -496,6 +499,7 @@ class FlashCausalLMBatch(Batch):
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -528,6 +532,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@@ -576,6 +581,9 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length)
)
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
start_slots = []
block_tables = []
@@ -613,6 +621,7 @@ class FlashCausalLMBatch(Batch):
position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots
all_input_ids_tensor[
@@ -680,6 +689,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@@ -850,7 +860,7 @@ class FlashCausalLM(Model):
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, logprobs
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
if prefill:
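In FlashCausalLMBatch, the filter path applies the same index tensor that already selects the surviving requests for input_lengths_tensor, so top_n_tokens_tensor stays aligned with the remaining requests without a round trip through Python. A small sketch of that fancy-indexing step with hypothetical indices:

import torch

top_n_tokens = [0, 3, 5, 2]
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)

indices = torch.tensor([1, 3])  # hypothetical: only requests 1 and 3 are still running
filtered_list = [top_n_tokens[i] for i in indices.tolist()]
filtered_tensor = top_n_tokens_tensor[indices]

print(filtered_list)    # [3, 2]
print(filtered_tensor)  # tensor([3, 2])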

View File

@@ -108,14 +108,15 @@ class Model(ABC):
new_sequences, skip_special_tokens=False
)
prefix_len = len(prefix_text)
results = []
for new_text in new_texts:
if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
if len(new_text) > prefix_len and not new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
new_text = new_text[len(prefix_text) :]
new_text = new_text[prefix_len:]
results.append((new_text, read_offset, len(input_ids) + 1))
else:
results.append(("", prefix_offset, read_offset))

View File

@@ -50,6 +50,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
@@ -129,6 +130,7 @@ class Seq2SeqLMBatch(Batch):
prefix_offsets.append(0)
read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@@ -150,6 +152,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
@@ -245,6 +248,7 @@ class Seq2SeqLMBatch(Batch):
layer[2] = layer[2][keep_indices, :, -max_input_length:]
layer[3] = layer[3][keep_indices, :, -max_input_length:]
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = (
len(request_ids) * (max_input_length + max_decoder_input_length)
+ remaining_decode_tokens
@@ -261,6 +265,7 @@ class Seq2SeqLMBatch(Batch):
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.max_decoder_input_length = max_decoder_input_length
self.padding_right_offset = padding_right_offset
@@ -304,6 +309,7 @@ class Seq2SeqLMBatch(Batch):
decoder_input_ids = None
decoder_attention_mask = None
encoder_last_hidden_state = None
top_n_tokens_tensor = None
past_key_values = []
# Used for slicing correctly inside the tensors
@@ -393,6 +399,12 @@ class Seq2SeqLMBatch(Batch):
),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_input_length :, :
@@ -498,6 +510,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
@@ -624,7 +637,7 @@ class Seq2SeqLM(Model):
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
)
# Finished requests
@@ -663,7 +676,7 @@ class Seq2SeqLM(Model):
top_token_logprobs,
) in enumerate(iterator):
top_tokens = self.decode_top_tokens(
input_ids=all_decoder_input_ids.view(1, -1).tolist(),
input_ids=all_decoder_input_ids.view(-1).tolist(),
top_n_tokens=top_n_tokens,
top_token_ids=top_token_ids,
top_token_logprobs=top_token_logprobs,
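The Seq2SeqLM changes mirror the CausalLM ones: the tensor is allocated when the batch is built, filtered by keep_indices, merged across batches, and passed to batch_top_tokens together with the softmax over the last decoder position. A shape sketch of that call site with hypothetical sizes (the actual batch_top_tokens call is only indicated in the final comment):

import torch

batch_size, vocab_size = 3, 10
logits = torch.randn(batch_size, 1, vocab_size)   # (batch, seq, vocab)

top_n_tokens = [0, 2, 4]
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)

probs = torch.softmax(logits[:, -1], -1)          # (batch, vocab), last position only
print(probs.shape)                                 # torch.Size([3, 10])
# batch_top_tokens(top_n_tokens, top_n_tokens_tensor, probs) is called with these three.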

View File

@@ -337,20 +337,16 @@ class HeterogeneousSampling:
def batch_top_tokens(
top_n_tokens: list[int], logprobs: torch.Tensor
top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
) -> Tuple[List[List[int]], List[List[float]]]:
"""Find the top n most likely tokens for a batch of generations.
When multiple tokens have equal probabilities and they don't all fit, the
remaining tokens are also returned.
"""
# Do this as early as possible to mitigate copy latency
top_n_tensor = torch.tensor(top_n_tokens).to(
device=logprobs.device, non_blocking=True
)
max_top_n = max(top_n_tokens)
# Early exit when top_n_tokens is not used
if max(top_n_tokens) == 0:
if max_top_n == 0:
return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
# Ensure top_n doesn't exceed vocab size
@@ -358,11 +354,9 @@ def batch_top_tokens(
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset
sorted_top_k = torch.topk(
logprobs, k=max(top_n_tokens), dim=1, sorted=True
).values # .cpu()
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
nth_highest = torch.gather(
sorted_top_k, 1, (top_n_tensor - 1).clip(min=0).unsqueeze(1)
sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
)
nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
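With the tensor now carried on the batch, batch_top_tokens no longer builds top_n_tensor from the Python list on every call; what remains is the batched "n-th highest value" trick: topk sorts each row's top values and gather picks the value at index top_n - 1 as that row's cutoff, with clip(min=0) keeping rows that requested zero tokens in bounds. A self-contained sketch of that trick on toy data (the mask shown is only the thresholding step; assembling the returned id/logprob lists per request, where rows that asked for zero tokens presumably end up empty, is omitted):

import torch

logprobs = torch.log_softmax(
    torch.tensor([[2.0, 1.0, 0.5, 0.5], [3.0, 0.1, 0.1, 0.0]]), dim=-1
)
top_n_tokens = [2, 0]
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)

max_top_n = max(top_n_tokens)
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
nth_highest = torch.gather(
    sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
)
# Everything at least as likely as the per-row cutoff is kept, so tokens tied
# with the n-th value are included rather than arbitrarily dropped.
mask = logprobs >= nth_highest
print(mask)  # row 0 keeps its top 2; row 1 asked for 0, clip only keeps the index valid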