From 8515999b1dd9c19155ee40ed1f45fcf3c799e727 Mon Sep 17 00:00:00 2001 From: Vincent Brouwers Date: Wed, 26 Jul 2023 15:12:57 +0000 Subject: [PATCH] Implement top-n-tokens for all models --- .../models/causal_lm.py | 30 ++++++++++++ .../models/flash_causal_lm.py | 29 ++++-------- server/text_generation_server/models/model.py | 44 +++++++++++++++-- .../models/seq2seq_lm.py | 42 ++++++++++++----- server/text_generation_server/utils/tokens.py | 47 ------------------- 5 files changed, 110 insertions(+), 82 deletions(-) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 17b8aa83..f3eee175 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -1,3 +1,4 @@ +from text_generation_server.utils.tokens import batch_top_tokens import torch import inspect @@ -42,6 +43,7 @@ class CausalLMBatch(Batch): # Generation helpers next_token_choosers: List[NextTokenChooser] stopping_criterias: List[StoppingCriteria] + top_n_tokens: List[int] # Metadata used for padding max_input_length: int @@ -72,6 +74,7 @@ class CausalLMBatch(Batch): inputs = [] next_token_choosers = [] stopping_criterias = [] + top_n_tokens = [] prefix_offsets = [] read_offsets = [] requests_idx_mapping = {} @@ -88,6 +91,7 @@ class CausalLMBatch(Batch): r.stopping_parameters, tokenizer ) stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) max_truncation = max(max_truncation, r.truncate) max_decode_tokens += stopping_criteria.max_new_tokens padding_right_offset = max( @@ -138,6 +142,7 @@ class CausalLMBatch(Batch): read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, max_input_length=max_input_length.item(), padding_right_offset=padding_right_offset, max_tokens=max_tokens, @@ -163,6 +168,7 @@ class CausalLMBatch(Batch): next_token_choosers = [] stopping_criterias = [] + top_n_tokens = [] total_remaining_decode_tokens = 0 new_padding_right_offset = 0 @@ -184,6 +190,7 @@ class CausalLMBatch(Batch): next_token_choosers.append(self.next_token_choosers[idx]) stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) + top_n_tokens.append(self.top_n_tokens[idx]) remaining_decode_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) @@ -235,6 +242,7 @@ class CausalLMBatch(Batch): self.read_offsets = read_offsets self.next_token_choosers = next_token_choosers self.stopping_criterias = stopping_criterias + self.top_n_tokens = top_n_tokens self.max_input_length = max_input_length self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens @@ -262,6 +270,7 @@ class CausalLMBatch(Batch): all_input_ids = [] next_token_choosers = [] stopping_criterias = [] + top_n_tokens = [] max_tokens = 0 # Batch tensors @@ -281,6 +290,7 @@ class CausalLMBatch(Batch): all_input_ids.extend(batch.all_input_ids) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) + top_n_tokens.extend(batch.top_n_tokens) if i == 0: requests_idx_mapping = batch.requests_idx_mapping @@ -438,6 +448,7 @@ class CausalLMBatch(Batch): read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, max_input_length=max_input_length, padding_right_offset=padding_right_offset, keys_head_dim_last=batches[0].keys_head_dim_last, @@ -549,6 +560,10 @@ 
class CausalLM(Model): generations: List[Generation] = [] stopped = True + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, torch.softmax(logits[:, -1], -1) + ) + # Zipped iterator iterator = zip( batch.requests, @@ -559,6 +574,9 @@ class CausalLM(Model): batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, + batch.top_n_tokens, + batch_top_token_ids, + batch_top_token_logprobs, ) # For each member of the batch @@ -571,7 +589,19 @@ class CausalLM(Model): next_token_chooser, stopping_criteria, all_input_ids, + top_n_tokens, + top_token_ids, + top_token_logprobs, ) in enumerate(iterator): + top_tokens = self.decode_top_tokens( + input_ids=all_input_ids.view(1, -1).tolist(), + top_n_tokens=top_n_tokens, + top_token_ids=top_token_ids, + top_token_logprobs=top_token_logprobs, + prefix_offset=prefix_offset, + read_offset=read_offset, + ) + # Select next token next_token_id, logprobs = next_token_chooser( all_input_ids.view(1, -1), logits[-1:, :] diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 57f95603..dc62955f 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,6 +1,6 @@ import math import itertools -from text_generation_server.utils.tokens import get_top_tokens, batch_top_tokens +from text_generation_server.utils.tokens import batch_top_tokens import torch import torch.distributed @@ -972,25 +972,14 @@ class FlashCausalLM(Model): top_token_ids, top_token_logprobs, ) in enumerate(iterator): - top_tokens = [] - - if top_n_tokens > 0: - top_token_texts = self.decode_tokens( - input_ids=all_input_ids, - new_input_ids=top_token_ids, - prefix_offset=prefix_offset, - read_offset=read_offset, - ) - for token_id, (top_token_text, _, _), token_logprob in zip(top_token_ids, top_token_texts, top_token_logprobs): - tok_itm = token_id - top_tokens.append( - TopToken( - token_id=token_id, - token_logprob=token_logprob, - token_text=top_token_text, - token_is_special=tok_itm in self.all_special_ids, - ) - ) + top_tokens = self.decode_top_tokens( + input_ids=all_input_ids, + top_n_tokens=top_n_tokens, + top_token_ids=top_token_ids, + top_token_logprobs=top_token_logprobs, + prefix_offset=prefix_offset, + read_offset=read_offset, + ) # Append next token to all tokens all_input_ids.append(next_token_id) diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 0ccb65ab..94d9306f 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from typing import List, Tuple, Optional, TypeVar, Type from transformers import PreTrainedTokenizerBase, PretrainedConfig -from text_generation_server.models.types import Batch, Generation +from text_generation_server.models.types import Batch, Generation, TopToken from text_generation_server.pb.generate_pb2 import InfoResponse B = TypeVar("B", bound=Batch) @@ -101,8 +101,12 @@ class Model(ABC): input_ids[prefix_offset:read_offset], skip_special_tokens=False ) - new_sequences = [input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids] - new_texts = self.tokenizer.batch_decode(new_sequences, skip_special_tokens=False) + new_sequences = [ + input_ids[prefix_offset:] + [new_id] for new_id in new_input_ids + ] + new_texts = self.tokenizer.batch_decode( + new_sequences, 
skip_special_tokens=False + ) results = [] for new_text in new_texts: @@ -117,6 +121,40 @@ class Model(ABC): results.append(("", prefix_offset, read_offset)) return results + def decode_top_tokens( + self, + input_ids, + top_n_tokens, + top_token_ids, + top_token_logprobs, + prefix_offset, + read_offset, + ): + if top_n_tokens == 0: + return [] + + top_token_texts = self.decode_tokens( + input_ids=input_ids, + new_input_ids=top_token_ids, + prefix_offset=prefix_offset, + read_offset=read_offset, + ) + + top_tokens = [] + for token_id, (top_token_text, _, _), token_logprob in zip( + top_token_ids, top_token_texts, top_token_logprobs + ): + tok_itm = token_id + top_tokens.append( + TopToken( + token_id=token_id, + token_logprob=token_logprob, + token_text=top_token_text, + token_is_special=tok_itm in self.all_special_ids, + ) + ) + return top_tokens + def check_initialized(self): uninitialized_parameters = [] for n, p in self.model.named_parameters(): diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index a9da647d..f215e632 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -1,4 +1,4 @@ -from text_generation_server.utils.tokens import get_top_tokens +from text_generation_server.utils.tokens import batch_top_tokens import torch from dataclasses import dataclass @@ -49,6 +49,7 @@ class Seq2SeqLMBatch(Batch): # Generation helpers next_token_choosers: List[NextTokenChooser] stopping_criterias: List[StoppingCriteria] + top_n_tokens: List[int] # Metadata used for padding max_input_length: int @@ -79,7 +80,7 @@ class Seq2SeqLMBatch(Batch): inputs = [] next_token_choosers = [] stopping_criterias = [] - + top_n_tokens = [] decoder_input_lengths = [] prefix_offsets = [] read_offsets = [] @@ -98,6 +99,7 @@ class Seq2SeqLMBatch(Batch): r.stopping_parameters, tokenizer ) stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) max_truncation = max(max_truncation, r.truncate) max_decode_tokens += stopping_criteria.max_new_tokens padding_right_offset = max( @@ -147,6 +149,7 @@ class Seq2SeqLMBatch(Batch): read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, max_input_length=max_input_length.item(), max_decoder_input_length=1, padding_right_offset=padding_right_offset, @@ -174,6 +177,7 @@ class Seq2SeqLMBatch(Batch): next_token_choosers = [] stopping_criterias = [] + top_n_tokens = [] max_input_length = 0 max_decoder_input_length = 0 @@ -205,6 +209,7 @@ class Seq2SeqLMBatch(Batch): next_token_choosers.append(self.next_token_choosers[idx]) stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) + top_n_tokens.append(self.top_n_tokens[idx]) remaining_decode_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) @@ -255,6 +260,7 @@ class Seq2SeqLMBatch(Batch): self.read_offsets = read_offsets self.next_token_choosers = next_token_choosers self.stopping_criterias = stopping_criterias + self.top_n_tokens = top_n_tokens self.max_input_length = max_input_length self.max_decoder_input_length = max_decoder_input_length self.padding_right_offset = padding_right_offset @@ -290,6 +296,7 @@ class Seq2SeqLMBatch(Batch): read_offsets = [] next_token_choosers = [] stopping_criterias = [] + top_n_tokens = [] max_tokens = 0 # Batch tensors @@ -313,6 +320,7 @@ class Seq2SeqLMBatch(Batch): 
read_offsets.extend(batch.read_offsets) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) + top_n_tokens.extend(batch.top_n_tokens) if i == 0: requests_idx_mapping = batch.requests_idx_mapping @@ -489,6 +497,7 @@ class Seq2SeqLMBatch(Batch): read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, max_input_length=max_input_length, max_decoder_input_length=max_decoder_input_length, padding_right_offset=padding_right_offset, @@ -614,6 +623,10 @@ class Seq2SeqLM(Model): batch.past_key_values, ) + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, torch.softmax(logits[:, -1], -1) + ) + # Finished requests generations: List[Generation] = [] stopped = True @@ -629,6 +642,9 @@ class Seq2SeqLM(Model): batch.next_token_choosers, batch.stopping_criterias, batch.all_decoder_input_ids, + batch.top_n_tokens, + batch_top_token_ids, + batch_top_token_logprobs, ) # For each member of the batch @@ -642,22 +658,24 @@ class Seq2SeqLM(Model): next_token_chooser, stopping_criteria, all_decoder_input_ids, + top_n_tokens, + top_token_ids, + top_token_logprobs, ) in enumerate(iterator): + top_tokens = self.decode_top_tokens( + input_ids=all_decoder_input_ids.view(1, -1).tolist(), + top_n_tokens=top_n_tokens, + top_token_ids=top_token_ids, + top_token_logprobs=top_token_logprobs, + prefix_offset=prefix_offset, + read_offset=read_offset, + ) + # Select next token next_token_id, logprobs = next_token_chooser( all_decoder_input_ids.view(1, -1), logits[-1:, :] ) - top_tokens = get_top_tokens( - request.top_n_tokens, - logprobs, - self.all_special_ids, - self.decode_token, - all_decoder_input_ids, - prefix_offset, - read_offset, - ) - # Append next token to decoder tokens all_decoder_input_ids = torch.cat( [all_decoder_input_ids, next_token_id.squeeze(1)] diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 499ff054..db7f9510 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -8,7 +8,6 @@ from transformers import ( ) from text_generation_server.pb import generate_pb2 -from text_generation_server.models.types import TopToken from text_generation_server.pb.generate_pb2 import FinishReason from text_generation_server.utils.watermark import WatermarkLogitsProcessor from text_generation_server.utils.logits_process import ( @@ -361,49 +360,3 @@ def batch_top_tokens(top_n_tokens: torch.Tensor, logprobs: torch.Tensor): [idxs[:n] for idxs, n in zip(top_indices, top_n_tokens)], [vals[:n] for vals, n in zip(top_values, top_n_tokens)], ) - - -def get_top_tokens( - requested_n: int, - logprobs, - special_tokens: List[int], - decode_fn: Callable[[List[int], int, int], str], - decoder_input_ids: List[int], - prefix_offset: int, - read_offset: int, -) -> List[TopToken]: - if not requested_n: - return [] - - # Dirty hack - flat_scores = logprobs if len(logprobs.shape) == 1 else logprobs[-1] - # Ensure top_n doesn't exceed vocab size - top_n = min(requested_n, flat_scores.size(-1)) - # Get nth highest value, ensure it's not -inf (for example if top_n > top_k) - nth_highest = torch.topk(flat_scores, top_n)[0][-1] - if nth_highest == -float("inf"): - nth_highest = torch.finfo(flat_scores.dtype).min - # Get indices (token ids) of all scores >= nth highest value, - # cap length at 4 * top_n as a precaution - top_n_indices = (flat_scores >= 
nth_highest).nonzero()[: (top_n * 4)] - top_tokens = [] - for tid_tensor in top_n_indices: - tid_item = tid_tensor[0].item() - token_text, _, _ = decode_fn( - torch.cat([decoder_input_ids, tid_tensor]) - if isinstance(decoder_input_ids, torch.Tensor) - else decoder_input_ids + [tid_item], - prefix_offset, - read_offset, - ) - top_tokens.append( - TopToken( - token_id=tid_item, - token_logprob=flat_scores[tid_tensor], - token_text=token_text, - token_is_special=tid_item in special_tokens, - ) - ) - - top_tokens.sort(reverse=True) - return top_tokens
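
The batched helper keeps one contract across all three models: given the per-request top_n_tokens list and the batch's last-position scores, it returns one list of token ids and one list of logprobs per request, with empty lists wherever top_n_tokens is 0. The stand-in below is only a sketch of that contract under simplified assumptions (a single topk, no handling of ties around the n-th value, which the retained tail of utils/tokens.py does handle); toy_batch_top_tokens and the example values are made up for illustration and are not the server's implementation.

    import torch
    from typing import List, Tuple

    def toy_batch_top_tokens(
        top_n_tokens: List[int], logprobs: torch.Tensor
    ) -> Tuple[List[List[int]], List[List[float]]]:
        # Simplified stand-in: one topk over the whole batch, then trim each
        # row to that request's own n (n == 0 means "no top tokens requested").
        if max(top_n_tokens) == 0:
            return [[] for _ in top_n_tokens], [[] for _ in top_n_tokens]
        k = min(max(top_n_tokens), logprobs.size(-1))
        values, indices = torch.topk(logprobs, k=k, dim=-1, sorted=True)
        return (
            [idx[:n].tolist() for idx, n in zip(indices, top_n_tokens)],
            [val[:n].tolist() for val, n in zip(values, top_n_tokens)],
        )

    # Three requests over a toy vocabulary of 8 tokens; request 1 asked for nothing.
    logits = torch.randn(3, 8)
    top_n = [2, 0, 3]
    ids, lps = toy_batch_top_tokens(top_n, torch.log_softmax(logits, -1))
    assert ids[1] == [] and len(ids[2]) == 3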
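
On the consumption side, causal_lm.py, flash_causal_lm.py and seq2seq_lm.py now all go through Model.decode_top_tokens(), which zips the batched ids and logprobs back against each request and wraps them in TopToken entries. The sketch below mirrors that assembly step with stand-ins: ToyTopToken, toy_decode_top_tokens, the lambda decoder and the special-id set are hypothetical and exist only to show the shape of the data; the real method additionally threads prefix_offset/read_offset through Model.decode_tokens() so the token texts stay consistent with streaming decoding.

    from dataclasses import dataclass
    from typing import Callable, List, Set

    @dataclass
    class ToyTopToken:
        # Mirrors the fields decode_top_tokens() fills in on the real TopToken.
        token_id: int
        token_logprob: float
        token_text: str
        token_is_special: bool

    def toy_decode_top_tokens(
        decode_fn: Callable[[List[int]], List[str]],  # stand-in for Model.decode_tokens()
        all_special_ids: Set[int],
        top_n_tokens: int,
        top_token_ids: List[int],
        top_token_logprobs: List[float],
    ) -> List[ToyTopToken]:
        # Requests that did not ask for top tokens skip decoding entirely.
        if top_n_tokens == 0:
            return []
        texts = decode_fn(top_token_ids)
        return [
            ToyTopToken(
                token_id=token_id,
                token_logprob=logprob,
                token_text=text,
                token_is_special=token_id in all_special_ids,
            )
            for token_id, text, logprob in zip(top_token_ids, texts, top_token_logprobs)
        ]

    # Example: pretend ids 0/1 decode to "Hello"/"world" and id 1 is a special token.
    decoded = toy_decode_top_tokens(
        decode_fn=lambda ids: ["Hello", "world"][: len(ids)],
        all_special_ids={1},
        top_n_tokens=2,
        top_token_ids=[0, 1],
        top_token_logprobs=[-0.1, -2.3],
    )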