From 8a4d2076a6ddef2b9ac4b69aa006c802876cb29c Mon Sep 17 00:00:00 2001
From: Vincent Brouwers
Date: Fri, 14 Jul 2023 19:48:15 +0000
Subject: [PATCH] Add WIP support for returning top tokens

Initial support for returning the most probable tokens. Note that it is
currently only implemented for seq-to-seq models, and that it is always
enabled, regardless of whether it is used or not.
---
 benchmark/src/generation.rs                    |  7 +++
 clients/python/text_generation/types.py        |  9 ++++
 proto/generate.proto                           | 15 ++++++
 router/client/src/client.rs                    |  1 +
 router/src/health.rs                           |  1 +
 router/src/infer.rs                            | 37 +++++++++++++--
 router/src/lib.rs                              |  8 ++++
 router/src/queue.rs                            |  4 ++
 router/src/server.rs                           | 13 +++++-
 router/src/validation.rs                       |  4 ++
 .../models/causal_lm.py                        |  1 +
 .../models/flash_causal_lm.py                  |  1 +
 .../models/seq2seq_lm.py                       | 12 +++++
 server/text_generation_server/models/types.py  | 30 ++++++++++++
 server/text_generation_server/utils/tokens.py  | 46 ++++++++++++++++++-
 15 files changed, 183 insertions(+), 6 deletions(-)

diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index c72d31d3..ca844065 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -73,6 +73,9 @@ async fn generate_runs(
     // Create a dummy sequence
     let sequence = create_sequence(sequence_length, tokenizer);
 
+    // TODO: Implement top_n_tokens
+    let top_n_tokens = 0;
+
     for b in batch_size {
         // Warmups on batch size
         for _ in 0..warmups {
@@ -82,6 +85,7 @@ async fn generate_runs(
                 b,
                 decode_length,
                 parameters.clone(),
+                top_n_tokens,
                 &mut client,
             )
             .await?;
@@ -97,6 +101,7 @@ async fn generate_runs(
                 b,
                 decode_length,
                 parameters.clone(),
+                top_n_tokens,
                 &mut client,
             )
             .await?;
@@ -130,6 +135,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     parameters: NextTokenChooserParameters,
+    top_n_tokens: u32,
     client: &mut ShardedClient,
 ) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
@@ -145,6 +151,7 @@ async fn prefill(
                 stop_sequences: vec![],
                 ignore_eos_token: true, // Will not stop even if a eos token is generated
             }),
+            top_n_tokens,
         })
         .collect();
 
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 548f0b63..52fdaf53 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -179,6 +179,9 @@ class BestOfSequence(BaseModel):
     prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
+    # Most likely tokens
+    # TODO: Make this optional?
+    top_tokens: List[List[Token]]
 
 
 # `generate` details
@@ -193,6 +196,9 @@ class Details(BaseModel):
     prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
+    # Most likely tokens
+    # TODO: Make this optional?
+    top_tokens: List[List[Token]]
     # Additional sequences when using the `best_of` parameter
     best_of_sequences: Optional[List[BestOfSequence]]
 
@@ -219,6 +225,9 @@ class StreamDetails(BaseModel):
 class StreamResponse(BaseModel):
     # Generated token
     token: Token
+    # Most likely tokens
+    # TODO: Make this optional?
+    top_tokens: List[Token]
     # Complete generated text
     # Only available when the generation is finished
     generated_text: Optional[str]
diff --git a/proto/generate.proto b/proto/generate.proto
index 57d79bca..89ca6e08 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -91,6 +91,8 @@ message Request {
   StoppingCriteriaParameters stopping_parameters = 5;
   /// Return prefill logprobs
   bool prefill_logprobs = 6;
+  /// Return most likely n tokens
+  uint32 top_n_tokens = 7;
 }
 
 message Batch {
@@ -141,6 +143,17 @@ message PrefillTokens {
   repeated string texts = 3;
 }
 
+message TopToken {
+  /// Token ID
+  uint32 token_id = 3;
+  /// Logprob
+  float token_logprob = 4;
+  /// Text
+  string token_text = 5;
+  /// Is it a special token
+  bool token_is_special = 6;
+}
+
 message Generation {
   /// Request ID
   uint64 request_id = 1;
@@ -156,6 +169,8 @@ message Generation {
   bool token_is_special = 6;
   /// Complete generated text
   optional GeneratedText generated_text = 7;
+  /// Top tokens
+  repeated TopToken top_tokens = 8;
 }
 
 message FilterBatchRequest {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 7753f307..d427d3a4 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -131,6 +131,7 @@ impl Client {
                     ignore_eos_token: false,
                 }),
                 prefill_logprobs: true,
+                top_n_tokens: 20,
             });
             n_tokens += max_input_length;
         }
diff --git a/router/src/health.rs b/router/src/health.rs
index a3cacdcd..a6b6baa1 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -50,6 +50,7 @@ impl Health {
                 stop_sequences: vec![],
                 ignore_eos_token: false,
             }),
+            top_n_tokens: 0,
         };
         let batch = Batch {
             id: BATCH_ID,
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 188ddc64..b88dfb2a 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -138,12 +138,15 @@ impl Infer {
         &self,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
+        let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
+
         // Create stream and keep semaphore permit as long as generate lives
         let (_permit, mut stream) = self.generate_stream(request).await?;
 
         // Return values
         let mut result_prefill = Vec::new();
         let mut result_tokens = Vec::new();
+        let mut result_top_tokens = Vec::new();
         let mut result_generated_text = None;
         let mut result_start = None;
         let mut result_queued = None;
@@ -164,7 +167,13 @@ impl Infer {
                         .collect();
                 }
                 // Push last token
-                InferStreamResponse::Token(token) => result_tokens.push(token),
+                InferStreamResponse::Intermediate {
+                    token,
+                    top_tokens,
+                } => {
+                    result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
+                }
                 // Final message
                 // Set return values
                 InferStreamResponse::End {
@@ -172,8 +181,11 @@ impl Infer {
                     generated_text,
                     start,
                     queued,
+                    top_tokens,
+
                 } => {
                     result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
                     result_generated_text = Some(generated_text);
                     result_start = Some(start);
                     result_queued = Some(queued)
@@ -191,6 +203,7 @@ impl Infer {
                 generated_text,
                 queued,
                 start,
+                top_tokens: if use_top_tokens { result_top_tokens } else { Vec::new() },
             })
         } else {
             let err = InferError::IncompleteGeneration;
@@ -520,6 +533,18 @@ fn send_responses(
         special: generation.token_is_special,
     };
 
+
+    // generation.top_tokens
+    let mut top_tokens = Vec::new();
+    for top_token in generation.top_tokens {
+        top_tokens.push(Token {
+            id: top_token.token_id,
+            text: top_token.token_text,
+            logprob: top_token.token_logprob,
+            special: top_token.token_is_special,
+        })
+    }
+
     if let Some(generated_text) = generation.generated_text {
         // Generation has ended
         stopped = true;
@@ -527,6 +552,7 @@ fn send_responses(
         entry.response_tx.send_timeout(
             Ok(InferStreamResponse::End {
                 token,
+                top_tokens,
                 generated_text,
                 queued: entry.queue_time,
                 start: entry.batch_time.unwrap(),
@@ -536,7 +562,7 @@ fn send_responses(
     } else {
         // Send message
         entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Token(token)),
+            Ok(InferStreamResponse::Intermediate { token, top_tokens }),
             Duration::from_millis(10),
         )?;
     }
@@ -566,10 +592,14 @@ pub(crate) enum InferStreamResponse {
     // Optional first message
     Prefill(PrefillTokens),
     // Intermediate messages
-    Token(Token),
+    Intermediate {
+        token: Token,
+        top_tokens: Vec<Token>,
+    },
     // Last message
     End {
         token: Token,
+        top_tokens: Vec<Token>,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
@@ -583,6 +613,7 @@ pub(crate) struct InferResponse {
     pub(crate) generated_text: GeneratedText,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
+    pub(crate) top_tokens: Vec<Vec<Token>>,
 }
 
 #[derive(Debug, Error)]
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7dff7a11..6f1d4c8f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters {
         example = "null"
     )]
     pub seed: Option<u64>,
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
+    pub top_n_tokens: Option<u32>,
 }
 
 fn default_max_new_tokens() -> u32 {
@@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters {
         details: false,
         decoder_input_details: false,
         seed: None,
+        top_n_tokens: None,
     }
 }
 
@@ -235,6 +239,7 @@ pub(crate) struct BestOfSequence {
     pub seed: Option<u64>,
     pub prefill: Vec<PrefillToken>,
     pub tokens: Vec<Token>,
+    pub top_tokens: Vec<Vec<Token>>,
 }
 
 #[derive(Serialize, ToSchema)]
@@ -249,6 +254,7 @@ pub(crate) struct Details {
     pub tokens: Vec<Token>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
+    pub top_tokens: Vec<Vec<Token>>,
 }
 
 #[derive(Serialize, ToSchema)]
@@ -272,6 +278,8 @@ pub(crate) struct StreamDetails {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
     pub token: Token,
+    #[schema(nullable = true, default = "null")]
+    pub top_tokens: Option<Vec<Token>>,
     #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
     #[schema(nullable = true, default = "null")]
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 2d8d6d1c..aeb668e3 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -235,6 +235,9 @@ impl State {
                 truncate: entry.request.truncate,
                 parameters: Some(entry.request.parameters.clone()),
                 stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+                // TODO: Actually fill this from the request
+                top_n_tokens: entry.request.top_n_tokens,
+
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());
@@ -328,6 +331,7 @@ mod tests {
                 max_new_tokens: 1,
                 stop_sequences: vec![],
             },
+            top_n_tokens: 0,
         },
         response_tx,
         span: info_span!("entry"),
diff --git a/router/src/server.rs b/router/src/server.rs
index e609821c..84de8bd3 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -158,7 +158,7 @@ async fn generate(
         add_prompt = Some(req.inputs.clone());
     }
 
-    let details = req.parameters.details || req.parameters.decoder_input_details;
+    let details: bool = req.parameters.details || req.parameters.decoder_input_details;
 
     // Inference
     let (response, best_of_responses) = match req.parameters.best_of {
@@ -191,6 +191,7 @@ async fn generate(
                     generated_tokens: response.generated_text.generated_tokens,
                     prefill: response.prefill,
                     tokens: response.tokens,
+                    top_tokens: response.top_tokens,
                     seed: response.generated_text.seed,
                 }
             })
@@ -204,6 +205,7 @@ async fn generate(
                 tokens: response.tokens,
                 seed: response.generated_text.seed,
                 best_of_sequences,
+                top_tokens: response.top_tokens,
             })
         }
         false => None,
@@ -374,7 +376,8 @@ async fn generate_stream(
             tracing::error!("{err}");
             yield Ok(Event::from(err));
         } else {
+            let top_n_tokens = req.parameters.top_n_tokens;
             match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
                 Ok((_permit, mut response_stream)) => {
                     // Server-Sent Event stream
@@ -385,12 +388,16 @@ async fn generate_stream(
                             // Prefill is ignored
                             InferStreamResponse::Prefill(_) => {}
                             // Yield event for every new token
-                            InferStreamResponse::Token(token) => {
+                            InferStreamResponse::Intermediate {
+                                token,
+                                top_tokens,
+                            } => {
                                 tracing::debug!(parent: &span, "Token: {:?}", token);
 
                                 // StreamResponse
                                 let stream_token = StreamResponse {
                                     token,
+                                    top_tokens: top_n_tokens.and(Some(top_tokens)),
                                     generated_text: None,
                                     details: None,
                                 };
@@ -403,6 +410,7 @@ async fn generate_stream(
                                 generated_text,
                                 start,
                                 queued,
+                                top_tokens,
                             } => {
                                 // Token details
                                 let details = match details {
@@ -451,6 +459,7 @@ async fn generate_stream(
 
                                 let stream_token = StreamResponse {
                                     token,
+                                    top_tokens: top_n_tokens.and(Some(top_tokens)),
                                     generated_text: Some(output_text),
                                     details
                                 };
diff --git a/router/src/validation.rs b/router/src/validation.rs
index f967361f..5999e37b 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -142,6 +142,8 @@ impl Validation {
             seed,
             watermark,
             decoder_input_details,
+            // TODO: Validate top_n_tokens
+            top_n_tokens,
             ..
         } = request.parameters;
 
@@ -263,6 +265,7 @@ impl Validation {
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
             parameters,
             stopping_parameters,
+            top_n_tokens: top_n_tokens.unwrap_or(0),
         })
     }
 
@@ -336,6 +339,7 @@ pub(crate) struct ValidGenerateRequest {
     pub decoder_input_details: bool,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
+    pub top_n_tokens: u32,
 }
 
 #[derive(Error, Debug)]
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index cbdf4808..17b8aa83 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -645,6 +645,7 @@ class CausalLM(Model):
                     next_token_text,
                     next_token_id_squeezed.item() in self.all_special_ids,
                     generated_text,
+                    top_tokens,
                 )
 
                 generations.append(generation)
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7de51358..26bce585 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1013,6 +1013,7 @@ class FlashCausalLM(Model):
                     next_token_text,
                     next_token_id in self.all_special_ids,
                     generated_text,
+                    top_tokens,
                 )
 
                 generations.append(generation)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 9e5c21d1..a9da647d 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -1,3 +1,4 @@
+from text_generation_server.utils.tokens import get_top_tokens
 import torch
 
 from dataclasses import dataclass
@@ -647,6 +648,16 @@ class Seq2SeqLM(Model):
                 all_decoder_input_ids.view(1, -1), logits[-1:, :]
             )
 
+            top_tokens = get_top_tokens(
+                request.top_n_tokens,
+                logprobs,
+                self.all_special_ids,
+                self.decode_token,
+                all_decoder_input_ids,
+                prefix_offset,
+                read_offset,
+            )
+
             # Append next token to decoder tokens
             all_decoder_input_ids = torch.cat(
                 [all_decoder_input_ids, next_token_id.squeeze(1)]
@@ -706,6 +717,7 @@ class Seq2SeqLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )
 
             generations.append(generation)
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 28ca8147..822335ff 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -1,3 +1,4 @@
+from functools import total_ordering
 import torch
 
 from abc import ABC, abstractmethod
@@ -71,6 +72,30 @@ class PrefillTokens:
         return len(self.token_ids)
 
 
+@dataclass(eq=True)
+@total_ordering
+class TopToken:
+    token_id: int
+    token_logprob: float
+    token_text: str
+    token_is_special: bool
+
+    def __gt__(self, other):
+        # We tiebreak equal logprobs with the _lower_ token_id to align with
+        # greedy ordering (torch.argmax)
+        return self.token_logprob > other.token_logprob or (
+            self.token_logprob == other.token_logprob and self.token_id < other.token_id
+        )
+
+    def to_pb(self) -> generate_pb2.TopToken:
+        return generate_pb2.TopToken(
+            token_id=self.token_id,
+            token_logprob=self.token_logprob,
+            token_text=self.token_text,
+            token_is_special=self.token_is_special,
+        )
+
+
 @dataclass
 class Generation:
     request_id: int
@@ -80,6 +105,8 @@ class Generation:
     token_text: str
     token_is_special: bool
     generated_text: Optional[GeneratedText]
+    # Optional for now, since it's not yet supported for every model.
+    top_tokens: Optional[List[TopToken]]
 
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -94,4 +121,7 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
            else None,
+            top_tokens=[toptoken.to_pb() for toptoken in self.top_tokens]
+            if self.top_tokens
+            else None,
         )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index b83af591..a50370de 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,13 +1,14 @@
 import re
+from typing import Callable, List, Tuple, Optional
 
 import torch
 from transformers import (
     RepetitionPenaltyLogitsProcessor,
     PreTrainedTokenizerBase,
 )
-from typing import List, Tuple, Optional
 
 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.types import TopToken
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
@@ -339,3 +340,46 @@ class HeterogeneousSampling:
         self.greedy_indices = new_greedy_indices
         self.sampling_mapping = new_sampling_mapping
         return self
+
+
+def get_top_tokens(
+    requested_n: int,
+    logprobs,
+    special_tokens: List[int],
+    decode_fn: Callable[[List[int], int, int], str],
+    decoder_input_ids: List[int],
+    prefix_offset: int,
+    read_offset: int,
+) -> List[TopToken]:
+    if not requested_n:
+        return []
+
+    flat_scores = logprobs[-1]
+    # Ensure top_n doesn't exceed vocab size
+    top_n = min(requested_n, flat_scores.size(-1))
+    # Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
+    nth_highest = torch.topk(flat_scores, top_n)[0][-1]
+    if nth_highest == -float("inf"):
+        nth_highest = torch.finfo(flat_scores.dtype).min
+    # Get indices (token ids) of all scores >= nth highest value,
+    # cap length at 4 * top_n as a precaution
+    top_n_indices = (flat_scores >= nth_highest).nonzero()[: (top_n * 4)]
+    top_tokens = []
+    for tid_tensor in top_n_indices:
+        tid_item = tid_tensor[0].item()
+        token_text, _, _ = decode_fn(
+            torch.cat([decoder_input_ids, tid_tensor]),
+            prefix_offset,
+            read_offset,
+        )
+        top_tokens.append(
+            TopToken(
+                token_id=tid_item,
+                token_logprob=logprobs[-1, tid_tensor],
+                token_text=token_text,
+                token_is_special=tid_item in special_tokens,
+            )
+        )
+
+    top_tokens.sort(reverse=True)
+    return top_tokens
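
To try the new parameter by hand once a router is built from this patch, something like the sketch below should work. It assumes a router listening on http://localhost:8080 and serving a seq-to-seq model (the only kind wired up so far); the host, port, prompt and the use of the plain `requests` library are assumptions for illustration, not part of this change.

# Minimal sketch: ask for the 5 most probable tokens per step and print them.
import requests

resp = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "Translate to German: Hello, world!",
        "parameters": {
            "max_new_tokens": 8,
            "details": True,
            "top_n_tokens": 5,  # new field added by this patch
        },
    },
    timeout=60,
)
resp.raise_for_status()
details = resp.json()["details"]

# `details["top_tokens"]` holds one candidate list per generated token.
for step, candidates in enumerate(details["top_tokens"]):
    ranked = ", ".join(f"{c['text']!r}: {c['logprob']:.2f}" for c in candidates)
    print(f"step {step}: {ranked}")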
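
The selection logic in get_top_tokens is a threshold trick: take the n-th highest logprob from torch.topk and keep every score greater than or equal to it, so ties at the cutoff are all returned, capped at 4 * top_n. A standalone sketch of just that step, on a made-up logprob vector:

# Toy example of the candidate-selection step (n = 2).
import torch

logprobs = torch.tensor([-0.1, -2.3, -0.1, -5.0, -1.2])
requested_n = 2

top_n = min(requested_n, logprobs.size(-1))
# n-th highest value; guard against -inf (e.g. when top_n exceeds a top_k filter)
nth_highest = torch.topk(logprobs, top_n)[0][-1]
if nth_highest == -float("inf"):
    nth_highest = torch.finfo(logprobs.dtype).min

# Keep every token whose score reaches the threshold, ties included,
# capped at 4 * top_n as in the patch.
candidate_ids = (logprobs >= nth_highest).nonzero()[: top_n * 4]
print(candidate_ids.squeeze(-1).tolist())  # -> [0, 2]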
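
TopToken.__gt__ orders candidates by logprob and breaks ties in favour of the lower token id, matching the greedy (torch.argmax) ordering mentioned in its comment. The snippet below re-declares a stripped-down copy of the class purely to illustrate that ordering outside the server package; it is not part of the patch.

# Illustration only: a stripped-down copy of TopToken's ordering.
from dataclasses import dataclass
from functools import total_ordering


@dataclass(eq=True)
@total_ordering
class _TopToken:
    token_id: int
    token_logprob: float

    def __gt__(self, other):
        return self.token_logprob > other.token_logprob or (
            self.token_logprob == other.token_logprob
            and self.token_id < other.token_id
        )


candidates = [_TopToken(7, -1.5), _TopToken(3, -0.2), _TopToken(1, -0.2)]
candidates.sort(reverse=True)  # best first, lower id wins the tie
assert [t.token_id for t in candidates] == [1, 3, 7]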