diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index f3eee175..929361e6 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -44,6 +44,7 @@ class CausalLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Metadata used for padding
     max_input_length: int
@@ -125,6 +126,7 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
@@ -143,6 +145,7 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
@@ -230,6 +233,7 @@ class CausalLMBatch(Batch):
             layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
             del past_values
 
+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
 
         self.requests = requests
@@ -243,6 +247,7 @@ class CausalLMBatch(Batch):
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
@@ -278,6 +283,7 @@ class CausalLMBatch(Batch):
         attention_mask = None
         position_ids = None
         past_key_values = []
+        top_n_tokens_tensor = None
 
         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes
@@ -320,6 +326,12 @@ class CausalLMBatch(Batch):
                     (total_batch_size, max_input_length + padding_right_offset),
                 )
 
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
             # We need to slice the attention mask to remove padding from previous steps
             # and to remove unused allocated space
             left_offset = max_input_length - batch.max_input_length
@@ -449,6 +461,7 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -561,7 +574,7 @@ class CausalLM(Model):
         stopped = True
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
         )
 
         # Zipped iterator
@@ -594,7 +607,7 @@ class CausalLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = self.decode_top_tokens(
-                input_ids=all_input_ids.view(1, -1).tolist(),
+                input_ids=all_input_ids.view(-1).tolist(),
                 top_n_tokens=top_n_tokens,
                 top_token_ids=top_token_ids,
                 top_token_logprobs=top_token_logprobs,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index dc62955f..672688c8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -168,6 +168,7 @@ class FlashCausalLMBatch(Batch):
     next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Number of blocks in this batch
     blocks: int
@@ -357,6 +358,7 @@ class FlashCausalLMBatch(Batch):
             prefill_next_token_indices = torch.tensor(
                 prefill_next_token_indices, dtype=torch.int64, device=device
             )
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
         return cls(
             batch_id=pb.id,
@@ -384,6 +386,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -496,6 +499,7 @@ class FlashCausalLMBatch(Batch):
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
+        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
@@ -528,6 +532,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -576,6 +581,9 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
+        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+            total_batch_size,
+        )
 
         start_slots = []
         block_tables = []
@@ -613,6 +621,7 @@ class FlashCausalLMBatch(Batch):
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots
 
             all_input_ids_tensor[
@@ -680,6 +689,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -850,7 +860,7 @@ class FlashCausalLM(Model):
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, logprobs
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
 
         if prefill:
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 94d9306f..a0b7e96d 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -108,14 +108,15 @@ class Model(ABC):
             new_sequences, skip_special_tokens=False
         )
 
+        prefix_len = len(prefix_text)
         results = []
         for new_text in new_texts:
-            if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
+            if len(new_text) > prefix_len and not new_text.endswith("�"):
                 # utf-8 char at the end means it's a potential unfinished byte sequence
                 # from byte fallback tokenization.
                 # If it's in the middle, it's probably a real invalid id generated
                 # by the model
-                new_text = new_text[len(prefix_text) :]
+                new_text = new_text[prefix_len:]
                 results.append((new_text, read_offset, len(input_ids) + 1))
             else:
                 results.append(("", prefix_offset, read_offset))
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index f215e632..f5703ceb 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -50,6 +50,7 @@ class Seq2SeqLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
     top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Metadata used for padding
     max_input_length: int
@@ -129,6 +130,7 @@ class Seq2SeqLMBatch(Batch):
             prefix_offsets.append(0)
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
+        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
@@ -150,6 +152,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
@@ -245,6 +248,7 @@ class Seq2SeqLMBatch(Batch):
             layer[2] = layer[2][keep_indices, :, -max_input_length:]
             layer[3] = layer[3][keep_indices, :, -max_input_length:]
 
+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = (
             len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
@@ -261,6 +265,7 @@ class Seq2SeqLMBatch(Batch):
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
@@ -304,6 +309,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = None
         decoder_attention_mask = None
         encoder_last_hidden_state = None
+        top_n_tokens_tensor = None
         past_key_values = []
 
         # Used for slicing correctly inside the tensors
@@ -393,6 +399,12 @@ class Seq2SeqLMBatch(Batch):
                 ),
             )
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
             # Copy to correct indices
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_input_length :, :
             ] = batch.encoder_last_hidden_state[
@@ -498,6 +510,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
@@ -624,7 +637,7 @@ class Seq2SeqLM(Model):
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, torch.softmax(logits[:, -1], -1)
+            batch.top_n_tokens, batch.top_n_tokens_tensor, torch.softmax(logits[:, -1], -1)
         )
 
         # Finished requests
@@ -663,7 +676,7 @@ class Seq2SeqLM(Model):
             top_token_logprobs,
         ) in enumerate(iterator):
             top_tokens = self.decode_top_tokens(
-                input_ids=all_decoder_input_ids.view(1, -1).tolist(),
+                input_ids=all_decoder_input_ids.view(-1).tolist(),
                 top_n_tokens=top_n_tokens,
                 top_token_ids=top_token_ids,
                 top_token_logprobs=top_token_logprobs,
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index fe40b338..69177d56 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -337,20 +337,16 @@ class HeterogeneousSampling:
 
 
 def batch_top_tokens(
-    top_n_tokens: list[int], logprobs: torch.Tensor
+    top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
 ) -> Tuple[List[List[int]], List[List[float]]]:
     """Find the top n most likely tokens for a batch of generations.
 
     When multiple tokens have equal probabilities and they don't all fit, the
     remaining tokens are also returned.
     """
-    # Do this as early as possible to mitigate copy latency
-    top_n_tensor = torch.tensor(top_n_tokens).to(
-        device=logprobs.device, non_blocking=True
-    )
-
+    max_top_n = max(top_n_tokens)
     # Early exit when top_n_tokens is not used
-    if max(top_n_tokens) == 0:
+    if max_top_n == 0:
         return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
 
     # Ensure top_n doesn't exceed vocab size
@@ -358,11 +354,9 @@ def batch_top_tokens(
 
     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
-    sorted_top_k = torch.topk(
-        logprobs, k=max(top_n_tokens), dim=1, sorted=True
-    ).values  # .cpu()
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
     nth_highest = torch.gather(
-        sorted_top_k, 1, (top_n_tensor - 1).clip(min=0).unsqueeze(1)
+        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
    )
    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
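
A minimal usage sketch of the reworked batch_top_tokens signature, mirroring the CausalLM call site in this patch: the batch values and vocab size below are made up for illustration, the tensors live on CPU rather than the model device, and the example assumes the patched text_generation_server package is importable.

# Illustrative only: dummy batch, CPU tensors, random logits.
import torch

from text_generation_server.utils.tokens import batch_top_tokens

# One entry per request in the batch; 0 disables top-n output for that request.
top_n_tokens = [0, 2, 3]
# The batch now carries this tensor, built once at batch creation, so
# batch_top_tokens no longer copies the Python list to the device on every step.
top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)

# Fake next-token distribution of shape (batch_size, vocab_size), as produced
# by torch.softmax(logits[:, -1], -1) in the CausalLM path.
scores = torch.softmax(torch.randn(3, 32), dim=-1)

batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
    top_n_tokens, top_n_tokens_tensor, scores
)
# One list per request is expected back: empty for the request with top_n == 0,
# and with ties at the n-th value included rather than truncated.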