Allocate top_n_token tensor in Batch

Vincent Brouwers 2023-07-31 13:09:45 +00:00 committed by Nicolas Patry
parent 65c0d9c19d
commit 1b5fdf7000
5 changed files with 49 additions and 18 deletions


@@ -44,6 +44,7 @@ class CausalLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor

     # Metadata used for padding
     max_input_length: int

@@ -125,6 +126,7 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)

         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

@@ -143,6 +145,7 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,

@@ -230,6 +233,7 @@ class CausalLMBatch(Batch):
             layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
             del past_values

+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

         self.requests = requests

@@ -243,6 +247,7 @@ class CausalLMBatch(Batch):
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens

@@ -278,6 +283,7 @@ class CausalLMBatch(Batch):
         attention_mask = None
         position_ids = None
         past_key_values = []
+        top_n_tokens_tensor = None

         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes

@@ -320,6 +326,12 @@ class CausalLMBatch(Batch):
                     (total_batch_size, max_input_length + padding_right_offset),
                 )

+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

             # We need to slice the attention mask to remove padding from previous steps
             # and to remove unused allocated space
             left_offset = max_input_length - batch.max_input_length

@@ -449,6 +461,7 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,

@@ -561,7 +574,7 @@ class CausalLM(Model):
         stopped = True

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
         )

         # Zipped iterator

@@ -594,7 +607,7 @@ class CausalLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = self.decode_top_tokens(
-                input_ids=all_input_ids.view(1, -1).tolist(),
+                input_ids=all_input_ids.view(-1).tolist(),
                 top_n_tokens=top_n_tokens,
                 top_token_ids=top_token_ids,
                 top_token_logprobs=top_token_logprobs,


@@ -168,6 +168,7 @@ class FlashCausalLMBatch(Batch):
     next_token_chooser: HeterogeneousNextTokenChooser
     stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor

     # Number of blocks in this batch
     blocks: int

@@ -357,6 +358,7 @@ class FlashCausalLMBatch(Batch):
             prefill_next_token_indices = torch.tensor(
                 prefill_next_token_indices, dtype=torch.int64, device=device
             )
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)

         return cls(
             batch_id=pb.id,

@@ -384,6 +386,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )

@@ -496,6 +499,7 @@ class FlashCausalLMBatch(Batch):
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
+        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]

         start_slots = torch.tensor(start_slots, dtype=torch.int64)

@@ -528,6 +532,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )

@@ -576,6 +581,9 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
+        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+            total_batch_size,
+        )

         start_slots = []
         block_tables = []

@@ -613,6 +621,7 @@ class FlashCausalLMBatch(Batch):
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots

             all_input_ids_tensor[

@@ -680,6 +689,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )

@@ -850,7 +860,7 @@ class FlashCausalLM(Model):
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, logprobs
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )

         if prefill:
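
FlashCausalLMBatch.concatenate knows total_batch_size before its per-batch loop, so the merged tensor is allocated up front with new_zeros; the causal and seq2seq batches instead allocate it lazily on the first iteration (the "if top_n_tokens_tensor is None:" guard). In both cases new_zeros inherits the dtype and device of the source tensor, which is what keeps the merged buffer on the model's device. A small illustration with plain CPU tensors (not the real batch objects):

    import torch

    # Two per-batch tensors, as built in from_pb (int64, one entry per request).
    a = torch.tensor([0, 5, 0], dtype=torch.int64)
    b = torch.tensor([3, 0], dtype=torch.int64)
    total_batch_size = a.numel() + b.numel()

    # new_zeros gives a zero tensor with the same dtype/device as `a`.
    merged = a.new_zeros(total_batch_size)
    merged[: a.numel()] = a
    merged[a.numel() :] = b
    print(merged)  # tensor([0, 5, 0, 3, 0])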


@@ -108,14 +108,15 @@ class Model(ABC):
             new_sequences, skip_special_tokens=False
         )
+        prefix_len = len(prefix_text)
         results = []
         for new_text in new_texts:
-            if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+            if len(new_text) > prefix_len and not new_text.endswith("�"):
                 # utf-8 char at the end means it's a potential unfinished byte sequence
                 # from byte fallback tokenization.
                 # If it's in the middle, it's probably a real invalid id generated
                 # by the model
-                new_text = new_text[len(prefix_text) :]
+                new_text = new_text[prefix_len:]
                 results.append((new_text, read_offset, len(input_ids) + 1))
             else:
                 results.append(("", prefix_offset, read_offset))


@@ -50,6 +50,7 @@ class Seq2SeqLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor

     # Metadata used for padding
     max_input_length: int

@@ -129,6 +130,7 @@ class Seq2SeqLMBatch(Batch):
             prefix_offsets.append(0)
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)

         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

@@ -150,6 +152,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,

@@ -245,6 +248,7 @@ class Seq2SeqLMBatch(Batch):
             layer[2] = layer[2][keep_indices, :, -max_input_length:]
             layer[3] = layer[3][keep_indices, :, -max_input_length:]

+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = (
             len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens

@@ -261,6 +265,7 @@ class Seq2SeqLMBatch(Batch):
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset

@@ -304,6 +309,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = None
         decoder_attention_mask = None
         encoder_last_hidden_state = None
+        top_n_tokens_tensor = None
         past_key_values = []

         # Used for slicing correctly inside the tensors

@@ -393,6 +399,12 @@ class Seq2SeqLMBatch(Batch):
                     ),
                 )

+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

             # Copy to correct indices
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_input_length :, :

@@ -498,6 +510,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,

@@ -624,7 +637,7 @@ class Seq2SeqLM(Model):
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
         )

         # Finished requests

@@ -663,7 +676,7 @@ class Seq2SeqLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = self.decode_top_tokens(
-                input_ids=all_decoder_input_ids.view(1, -1).tolist(),
+                input_ids=all_decoder_input_ids.view(-1).tolist(),
                 top_n_tokens=top_n_tokens,
                 top_token_ids=top_token_ids,
                 top_token_logprobs=top_token_logprobs,


@@ -337,20 +337,16 @@ class HeterogeneousSampling:
 def batch_top_tokens(
-    top_n_tokens: list[int], logprobs: torch.Tensor
+    top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
 ) -> Tuple[List[List[int]], List[List[float]]]:
     """Find the top n most likely tokens for a batch of generations.

     When multiple tokens have equal probabilities and they don't all fit, the
     remaining tokens are also returned.
     """
-    # Do this as early as possible to mitigate copy latency
-    top_n_tensor = torch.tensor(top_n_tokens).to(
-        device=logprobs.device, non_blocking=True
-    )
+    max_top_n = max(top_n_tokens)
     # Early exit when top_n_tokens is not used
-    if max(top_n_tokens) == 0:
+    if max_top_n == 0:
         return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)

     # Ensure top_n doesn't exceed vocab size

@@ -358,11 +354,9 @@ def batch_top_tokens(
     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
-    sorted_top_k = torch.topk(
-        logprobs, k=max(top_n_tokens), dim=1, sorted=True
-    ).values  # .cpu()
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
     nth_highest = torch.gather(
-        sorted_top_k, 1, (top_n_tensor - 1).clip(min=0).unsqueeze(1)
+        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
     )
     nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
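
With the tensor supplied by the caller, batch_top_tokens no longer materialises top_n_tensor on every call: it only needs max(top_n_tokens) on the host and the pre-allocated device tensor for the gather. A hedged usage sketch (standalone; it assumes the batch_top_tokens defined in this file is in scope, and uses random CPU tensors in place of real model output):

    import torch

    # Hypothetical import path for this sketch; adjust to wherever batch_top_tokens lives.
    # from text_generation_server.utils.tokens import batch_top_tokens

    top_n_tokens = [2, 0, 3]                                             # one value per request
    top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)  # built once in from_pb
    logprobs = torch.log_softmax(torch.randn(3, 32000), dim=-1)          # (batch_size, vocab_size)

    top_ids, top_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, logprobs)

    # top_ids[i] contains at least top_n_tokens[i] candidate token ids for request i
    # (ties at the cut-off may add a few more); top_logprobs[i] holds their scores.
    print([len(ids) for ids in top_ids])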