diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 26bce585..0978fc81 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1,5 +1,6 @@
 import math
 import itertools
+from text_generation_server.utils.tokens import get_top_tokens, batch_top_tokens
 import torch
 import torch.distributed
 
@@ -16,6 +17,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopToken,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -165,6 +167,7 @@ class FlashCausalLMBatch(Batch):
     # Generation helpers
     next_token_chooser: HeterogeneousNextTokenChooser
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
 
     # Number of blocks in this batch
     blocks: int
@@ -217,6 +220,7 @@ class FlashCausalLMBatch(Batch):
 
         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []
 
         # Cumulative length
         cumulative_length = 0
@@ -259,6 +263,7 @@ class FlashCausalLMBatch(Batch):
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
 
             # Paged attention
             # Remove one as the first token des not have a past
@@ -378,6 +383,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -417,6 +423,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
 
         stopping_criterias = []
+        top_n_tokens = []
 
         blocks = 0
         max_blocks = 0
@@ -443,6 +450,8 @@ class FlashCausalLMBatch(Batch):
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
 
+            top_n_tokens.append(self.top_n_tokens[idx])
+
             remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -518,6 +527,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -577,6 +587,7 @@ class FlashCausalLMBatch(Batch):
 
         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []
 
         # Cumulative length
         cumulative_batch_size = 0
@@ -624,6 +635,8 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
 
+            top_n_tokens.extend(batch.top_n_tokens)
+
             # Update
             cumulative_batch_size += len(batch)
             cumulative_slots += len(batch.slots)
@@ -666,6 +679,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -831,10 +845,14 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
 
-        next_input_ids, next_token_logprobs = batch.next_token_chooser(
+        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
         )
 
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, logprobs
+        )
+
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -933,6 +951,8 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.seeds,
             next_token_ids,
             next_token_logprobs,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )
 
         # For each member of the batch
@@ -947,7 +967,25 @@ class FlashCausalLM(Model):
             seed,
             next_token_id,
             next_token_logprob,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
+            top_tokens = []
+            for token_id, token_logprob in zip(top_token_ids, top_token_logprobs):
+                tok_itm = token_id
+                top_tokens.append(
+                    TopToken(
+                        token_id=token_id,
+                        token_logprob=token_logprob,
+                        token_text=self.decode_token(
+                            all_input_ids=all_input_ids + [tok_itm],
+                            prefix_offset=prefix_offset,
+                            read_offset=read_offset,
+                        )[0],
+                        token_is_special=tok_itm in self.all_special_ids,
+                    )
+                )
+
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
 
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index a50370de..0682959b 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -230,11 +230,10 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)
 
         next_ids = self.choice(scores)
-        next_logprobs = torch.gather(
-            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
-        ).view(-1)
+        logprobs = torch.log_softmax(scores, -1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
-        return next_ids, next_logprobs
+        return next_ids, next_logprobs, logprobs
 
     def filter(self, indices):
         if self.watermark_processor is not None:
@@ -342,6 +341,28 @@ class HeterogeneousSampling:
         return self
 
 
+def batch_top_tokens(top_n_tokens: List[int], logprobs: torch.Tensor):
+    """Find the top n most likely tokens for a batch of generations."""
+    top_n_tokens = torch.tensor(top_n_tokens)
+    if top_n_tokens.max() == 0:
+        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+
+    # Ensure top_n doesn't exceed vocab size
+    top_n_tokens = torch.clip(top_n_tokens, max=logprobs.size(-1))
+
+    # Take the topk using the highest requested top_n_tokens.
+    top_k = torch.topk(logprobs, k=int(top_n_tokens.max()), dim=1, sorted=True)
+
+    # Move all values into lists at once to prevent multiple GPU syncs
+    top_indices = top_k.indices.tolist()
+    top_values = top_k.values.tolist()
+
+    return (
+        [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
+        [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
+    )
+
+
 def get_top_tokens(
     requested_n: int,
     logprobs,
@@ -354,7 +375,8 @@
     if not requested_n:
         return []
 
-    flat_scores = logprobs[-1]
+    # Accept both 1-D (single-step) and 2-D (per-position) logprob tensors
+    flat_scores = logprobs if len(logprobs.shape) == 1 else logprobs[-1]
     # Ensure top_n doesn't exceed vocab size
     top_n = min(requested_n, flat_scores.size(-1))
     # Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
@@ -368,14 +390,16 @@
     for tid_tensor in top_n_indices:
         tid_item = tid_tensor[0].item()
         token_text, _, _ = decode_fn(
-            torch.cat([decoder_input_ids, tid_tensor]),
+            torch.cat([decoder_input_ids, tid_tensor])
+            if isinstance(decoder_input_ids, torch.Tensor)
+            else decoder_input_ids + [tid_item],
             prefix_offset,
             read_offset,
         )
         top_tokens.append(
             TopToken(
                 token_id=tid_item,
-                token_logprob=logprobs[-1, tid_tensor],
+                token_logprob=flat_scores[tid_tensor],
                 token_text=token_text,
                 token_is_special=tid_item in special_tokens,
             )
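
For context, here is a minimal, self-contained sketch of the batched selection that `batch_top_tokens` performs: one `topk` call sized by the largest per-request `top_n`, then a single `.tolist()` transfer so the per-request trimming happens on the host without extra GPU syncs. `batch_top_tokens_sketch` and the toy tensors are illustrative only, not part of the patch.

```python
# Illustrative sketch of the batched top-n selection used in this patch.
import torch


def batch_top_tokens_sketch(top_n_tokens, logprobs):
    max_top_n = max(top_n_tokens)
    if max_top_n == 0:
        # One empty list per request so downstream zips keep their length.
        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)

    # Single topk sized by the largest request, capped at the vocab size.
    top_k = torch.topk(
        logprobs, k=min(max_top_n, logprobs.size(-1)), dim=1, sorted=True
    )

    # One device-to-host transfer for the whole batch.
    top_indices = top_k.indices.tolist()
    top_values = top_k.values.tolist()

    # Trim each row down to the number of tokens that request asked for.
    return (
        [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)],
        [vals[:n] for vals, n in zip(top_values, top_n_tokens)],
    )


if __name__ == "__main__":
    # Two requests over a toy 5-token vocab: the first wants 2 top tokens, the second none.
    logprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)
    ids, vals = batch_top_tokens_sketch([2, 0], logprobs)
    print(ids)   # e.g. [[3, 1], []]
    print(vals)  # the matching logprobs, highest first
```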