diff --git a/proto/generate.proto b/proto/generate.proto
index c873e661..a0fb48b0 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -135,43 +135,27 @@ message GeneratedText {
     optional uint64 seed = 4;
 }
 
-message PrefillTokens {
-    /// Prefill Token IDs
+message Tokens {
+    /// Token IDs
     repeated uint32 ids = 1;
-    /// Prefill Logprobs
+    /// Logprobs
     repeated float logprobs = 2;
-    /// Prefill tokens
+    /// tokens
     repeated string texts = 3;
-}
-
-message TopTokens {
-    /// Top Token IDs
-    repeated uint32 ids = 1;
-    /// Top Logprobs
-    repeated float logprobs = 2;
-    /// Top Token Texts
-    repeated string texts = 3;
-    /// If the tokens are special
-    repeated bool is_special = 6;
+    /// special
+    repeated bool is_special = 4;
 }
 
 message Generation {
     /// Request ID
     uint64 request_id = 1;
     /// Prefill tokens (optional)
-    PrefillTokens prefill_tokens = 2;
-    /// Token ID
-    uint32 token_id = 3;
-    /// Logprob
-    float token_logprob = 4;
-    /// Text
-    string token_text = 5;
-    /// Is it a special token
-    bool token_is_special = 6;
+    Tokens prefill_tokens = 2;
+    Tokens tokens = 3;
     /// Complete generated text
-    optional GeneratedText generated_text = 7;
+    optional GeneratedText generated_text = 4;
     /// Top tokens
-    TopTokens top_tokens = 8;
+    repeated Tokens top_tokens = 5;
 }
 
 message FilterBatchRequest {
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index f334be21..0bb61568 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -10,7 +10,7 @@ pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
-    PrefillTokens, Request, StoppingCriteriaParameters,
+    Tokens, Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
diff --git a/router/src/infer.rs b/router/src/infer.rs
index aa6dc664..050005f1 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -9,7 +9,7 @@ use std::sync::{
     Arc,
 };
 use text_generation_client::{
-    Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
+    Batch, CachedBatch, ClientError, GeneratedText, Generation, Tokens, ShardedClient,
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
@@ -167,21 +167,21 @@ impl Infer {
                         .collect();
                 }
                 // Push last token
-                InferStreamResponse::Intermediate { token, top_tokens } => {
-                    result_tokens.push(token);
-                    result_top_tokens.push(top_tokens);
+                InferStreamResponse::Intermediate { tokens, top_tokens } => {
+                    result_tokens.extend(tokens);
+                    result_top_tokens.extend(top_tokens);
                 }
                 // Final message
                 // Set return values
                 InferStreamResponse::End {
-                    token,
+                    tokens,
                     generated_text,
                     start,
                     queued,
                     top_tokens,
                 } => {
-                    result_tokens.push(token);
-                    result_top_tokens.push(top_tokens);
+                    result_tokens.extend(tokens);
+                    result_top_tokens.extend(top_tokens);
                     result_generated_text = Some(generated_text);
                     result_start = Some(start);
                     result_queued = Some(queued)
@@ -523,31 +523,41 @@ fn send_responses(
     }
 
     // Create last Token
-    let token = Token {
-        id: generation.token_id,
-        text: generation.token_text,
-        logprob: generation.token_logprob,
-        special: generation.token_is_special,
-    };
-
-    // generation.top_tokens
-
-    let mut top_tokens = Vec::new();
-    if let Some(top_tokens_) = generation.top_tokens {
-        top_tokens.extend(
-            top_tokens_
-                .ids
-                .into_iter()
-                .zip(top_tokens_.logprobs.into_iter())
-                .zip(top_tokens_.texts.into_iter())
-                .zip(top_tokens_.is_special.into_iter())
+    let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens {
+        tokens_.ids.into_iter()
+            .zip(tokens_.logprobs.into_iter())
+            .zip(tokens_.texts.into_iter())
+            .zip(tokens_.is_special.into_iter())
                 .map(|(((id, logprob), text), special)| Token {
                     id,
                     text,
                     logprob,
                     special,
-                }),
-        )
+                }).collect()
+    } else {
+        vec![]
+    };
+
+    // generation.top_tokens
+
+    let mut top_tokens = Vec::new();
+    for top_tokens_ in generation.top_tokens {
+        let mut local_top_tokens = Vec::new();
+        local_top_tokens.extend(
+            top_tokens_
+                .ids
+                .into_iter()
+                .zip(top_tokens_.logprobs.into_iter())
+                .zip(top_tokens_.texts.into_iter())
+                .zip(top_tokens_.is_special.into_iter())
+                .map(|(((id, logprob), text), special)| Token {
+                    id,
+                    text,
+                    logprob,
+                    special,
+                }),
+        );
+        top_tokens.push(local_top_tokens);
     }
 
     if let Some(generated_text) = generation.generated_text {
@@ -555,7 +565,7 @@
         stopped = true;
         // Send message
         entry.response_tx.send(Ok(InferStreamResponse::End {
-            token,
+            tokens,
             top_tokens,
             generated_text,
             queued: entry.queue_time,
@@ -565,7 +575,7 @@
         // Send message
         entry
             .response_tx
-            .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
+            .send(Ok(InferStreamResponse::Intermediate { tokens, top_tokens }))?;
     }
     Ok(stopped)
 }
@@ -591,16 +601,16 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
-    Prefill(PrefillTokens),
+    Prefill(Tokens),
     // Intermediate messages
     Intermediate {
-        token: Token,
-        top_tokens: Vec<Token>,
+        tokens: Vec<Token>,
+        top_tokens: Vec<Vec<Token>>,
     },
     // Last message
     End {
-        token: Token,
-        top_tokens: Vec<Token>,
+        tokens: Vec<Token>,
+        top_tokens: Vec<Vec<Token>>,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index b547dc15..cbc0b478 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -279,9 +279,10 @@ pub(crate) struct StreamDetails {
 
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
-    pub token: Token,
+    pub tokens: Vec<Token>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
-    pub top_tokens: Vec<Token>,
+    pub top_tokens: Vec<Vec<Token>>,
+    pub text: String,
     #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
     #[schema(nullable = true, default = "null")]
diff --git a/router/src/server.rs b/router/src/server.rs
index f254afd8..a26dafd1 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -388,14 +388,15 @@
                                 InferStreamResponse::Prefill(_) => {}
                                 // Yield event for every new token
                                 InferStreamResponse::Intermediate{
-                                    token,
+                                    tokens,
                                     top_tokens,
                                 } => {
-                                    tracing::debug!(parent: &span, "Token: {:?}", token);
+                                    tracing::debug!(parent: &span, "Tokens: {:?}", tokens);
 
                                     // StreamResponse
                                     let stream_token = StreamResponse {
-                                        token,
+                                        tokens,
+                                        text,
                                         top_tokens,
                                         generated_text: None,
                                         details: None,
@@ -405,7 +406,8 @@
                                 }
                                 // Yield event for last token and compute timings
                                 InferStreamResponse::End {
-                                    token,
+                                    tokens,
+                                    text,
                                     generated_text,
                                     start,
                                     queued,
@@ -457,8 +459,9 @@
                                     tracing::info!(parent: &span, "Success");
 
                                     let stream_token = StreamResponse {
-                                        token,
+                                        tokens,
                                         top_tokens,
+                                        text,
                                         generated_text: Some(output_text),
                                         details
                                     };
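Illustration (not part of the patch): with the proto change above, one Generation message now carries a single `Tokens` message for the token(s) accepted in a decode step plus a repeated `Tokens` field for the per-position top tokens, instead of the old scalar token_* fields. A rough Python sketch of the new wire shape, assuming the gRPC stubs have been regenerated from proto/generate.proto; the ids, logprobs and texts below are made-up values:

    # Sketch only: field names come from the Tokens/Generation messages above.
    from text_generation_server.pb import generate_pb2

    step_tokens = generate_pb2.Tokens(
        ids=[42, 7],                 # two tokens accepted in a single step
        logprobs=[-0.12, -1.05],
        texts=[" Hello", ","],
        is_special=[False, False],
    )
    top_tokens = [
        generate_pb2.Tokens(         # top candidates for the first position
            ids=[42, 13],
            logprobs=[-0.12, -2.30],
            texts=[" Hello", " Hi"],
            is_special=[False, False],
        ),
    ]
    generation = generate_pb2.Generation(
        request_id=0,
        tokens=step_tokens,
        top_tokens=top_tokens,       # repeated Tokens, one entry per position
        # prefill_tokens and generated_text stay unset here: prefill logprobs
        # were not requested and generation has not finished.
    )

The router-side loop in send_responses above then flattens `generation.tokens` into a `Vec<Token>` and `generation.top_tokens` into a `Vec<Vec<Token>>`.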
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 8056a8ec..c571a022 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -10,10 +10,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -676,8 +675,8 @@ class CausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -691,7 +690,7 @@
                     special_toptokens = [
                         token_id in self.all_special_ids for token_id in top_token_ids
                     ]
-                    top_tokens = TopTokens(
+                    top_tokens = Tokens(
                         top_token_ids,
                         top_token_logprobs,
                         toptoken_texts,
@@ -703,10 +702,12 @@
                 generation = Generation(
                     request.id,
                     prefill_tokens,
-                    next_token_id_squeezed,
-                    next_token_logprob,
-                    next_token_text,
-                    next_token_id_squeezed.item() in self.all_special_ids,
+                    Tokens(
+                        [next_token_id_squeezed],
+                        [next_token_logprob],
+                        [next_token_text],
+                        [next_token_id_squeezed.item() in self.all_special_ids],
+                    ),
                     generated_text,
                     top_tokens,
                 )
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 3b3aa400..52679be8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -14,10 +14,9 @@ from typing import Optional, Tuple, List, Type, Union, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
-    TopTokens,
 )
 from text_generation_server.models.cache_manager import (
     get_cache_manager,
@@ -952,14 +951,17 @@ class FlashCausalLM(Model):
 
             # Append next token to all tokens
             _next_token_ids = next_token_ids[index: index+n_accepted_ids]
             _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
-            all_input_ids.extend(_next_token_ids)
-            # Generated token
-            next_token_text, prefix_offset, read_offset = self.decode_token(
-                all_input_ids,
-                prefix_offset,
-                read_offset,
-            )
+            next_token_texts = []
+            for j in range(index, index + n_accepted_ids):
+                # Generated token
+                all_input_ids.append(next_token_ids[j])
+                next_token_text, prefix_offset, read_offset = self.decode_token(
+                    all_input_ids,
+                    prefix_offset,
+                    read_offset,
+                )
+                next_token_texts.append(next_token_text)
 
             # Evaluate stopping criteria
 
@@ -1013,8 +1015,8 @@
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, request_prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -1028,7 +1030,7 @@
                     special_toptokens = [
                         token_id in self.all_special_ids for token_id in top_token_ids
                     ]
-                    top_tokens = TopTokens(
+                    top_tokens = Tokens(
                         top_token_ids,
                         top_token_logprobs,
                         toptoken_texts,
@@ -1037,16 +1039,15 @@
                 else:
                     top_tokens = None
 
-                next_token_ids = _next_token_ids[0]
-                next_token_logprob = _next_token_logprobs[0]
-
                 generation = Generation(
                     request.id,
                     prefill_tokens,
-                    next_token_id,
-                    next_token_logprob,
-                    next_token_text,
-                    next_token_id in self.all_special_ids,
+                    Tokens(
+                        _next_token_ids,
+                        _next_token_logprobs,
+                        next_token_texts,
+                        [nid in self.all_special_ids for nid in _next_token_ids],
+                    ),
                     generated_text,
                     top_tokens,
                 )
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index dcad1fa9..2f4bb139 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -20,7 +20,7 @@ from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     Batch,
-    PrefillTokens,
+    Tokens,
     Generation,
     GeneratedText,
 )
@@ -791,8 +791,8 @@ class IdeficsCausalLM(Model):
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                prefill_tokens = Tokens(
+                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
                 )
             else:
                 prefill_tokens = None
@@ -802,10 +802,12 @@
                 generation = Generation(
                     request.id,
                     prefill_tokens,
-                    next_token_id_squeezed,
-                    next_token_logprob,
-                    next_token_text,
-                    next_token_id_squeezed.item() in self.all_special_ids,
+                    Tokens(
+                        [next_token_id_squeezed],
+                        [next_token_logprob],
+                        [next_token_text],
+                        [next_token_id_squeezed.item() in self.all_special_ids],
+                    ),
                     generated_text,
                     top_tokens,
                 )
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index d4d3cd19..30a0a20d 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -11,8 +11,7 @@ from text_generation_server.models.types import (
     GeneratedText,
     Batch,
     Generation,
-    PrefillTokens,
-    TopTokens,
+    Tokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -733,10 +732,11 @@ class Seq2SeqLM(Model):
 
             # Prefill
             if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
-                prefill_tokens = PrefillTokens(
+                prefill_tokens = Tokens(
                     [self.tokenizer.bos_token_id],
                     [float("nan")],
                     [self.tokenizer.bos_token],
+                    [False],
                 )
             else:
                 prefill_tokens = None
@@ -750,7 +750,7 @@
                     special_toptokens = [
                         token_id in self.all_special_ids for token_id in top_token_ids
                     ]
-                    top_tokens = TopTokens(
+                    top_tokens = Tokens(
                         top_token_ids,
                         top_token_logprobs,
                         toptoken_texts,
@@ -762,10 +762,12 @@
                 generation = Generation(
                     request.id,
                     prefill_tokens,
-                    next_token_id_squeezed,
-                    next_token_logprob,
-                    next_token_text,
-                    next_token_id_squeezed.item() in self.all_special_ids,
+                    Tokens(
+                        [next_token_id_squeezed],
+                        [next_token_logprob],
+                        [next_token_text],
+                        [next_token_id_squeezed.item() in self.all_special_ids],
+                    ),
                     generated_text,
                     top_tokens,
                 )
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 0e27680d..87c03d63 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -58,33 +58,15 @@ class GeneratedText:
 
 
 @dataclass
-class PrefillTokens:
-    token_ids: List[int]
-    logprobs: List[float]
-    texts: List[str]
-
-    def to_pb(self) -> generate_pb2.PrefillTokens:
-        return generate_pb2.PrefillTokens(
-            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
-        )
-
-    def __len__(self):
-        return len(self.token_ids)
-
-
-@dataclass
-class TopTokens:
+class Tokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
     is_special: List[bool]
 
-    def to_pb(self) -> generate_pb2.TopTokens:
-        return generate_pb2.TopTokens(
-            ids=self.token_ids,
-            logprobs=self.logprobs,
-            texts=self.texts,
-            is_special=self.is_special,
+    def to_pb(self) -> generate_pb2.Tokens:
+        return generate_pb2.Tokens(
+            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts, is_special=self.is_special
         )
 
     def __len__(self):
@@ -94,14 +76,11 @@
 @dataclass
 class Generation:
     request_id: int
-    prefill_tokens: Optional[PrefillTokens]
-    token_id: int
-    token_logprob: float
-    token_text: str
-    token_is_special: bool
+    prefill_tokens: Optional[Tokens]
+    tokens: Tokens
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
-    top_tokens: Optional[TopTokens]
+    top_tokens: Optional[List[Tokens]]
 
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -109,10 +88,7 @@
             prefill_tokens=self.prefill_tokens.to_pb()
             if self.prefill_tokens is not None
             else None,
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            token_is_special=self.token_is_special,
+            tokens=self.tokens.to_pb(),
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
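For completeness, a rough sketch of how a model's generate_token path is expected to use the consolidated `Tokens` dataclass after this change. Names mirror the call sites in causal_lm.py and flash_causal_lm.py above; the concrete ids, logprobs and texts are invented for illustration:

    # Sketch only, assuming the server package is installed with this patch applied.
    from text_generation_server.models.types import Generation, Tokens

    # Prefill tokens are sent with an empty is_special list, as in the hunks above.
    prefill_tokens = Tokens(
        [3, 1020, 318],                  # token_ids
        [float("nan"), -2.1, -0.4],      # logprobs (nan for the first prompt token)
        ["<s>", " The", " cat"],         # texts
        is_special=[],
    )

    # A regular decode step wraps its single token in one-element lists; the flash
    # model above can pass several accepted tokens at once via _next_token_ids.
    step_tokens = Tokens([257], [-0.8], [" a"], [False])
    assert len(step_tokens) == 1         # __len__ counts token_ids

    generation = Generation(
        request_id=0,
        prefill_tokens=prefill_tokens,
        tokens=step_tokens,
        generated_text=None,             # only set on the final step
        top_tokens=None,                 # Optional[List[Tokens]]
    )
    tokens_pb = step_tokens.to_pb()      # generate_pb2.Tokens(ids=..., logprobs=..., texts=..., is_special=...)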