From be481a47998a16d67549c966c8b2be45a405192f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 5 Dec 2023 15:21:42 +0000 Subject: [PATCH] Address comments. --- proto/generate.proto | 3 +- router/client/src/client.rs | 4 +- router/client/src/lib.rs | 6 +- router/src/infer.rs | 110 +++++++++--------- router/src/server.rs | 83 ++++++------- .../text_generation_server/models/__init__.py | 26 +++-- .../models/flash_causal_lm.py | 16 +-- .../text_generation_server/utils/speculate.py | 12 ++ 8 files changed, 128 insertions(+), 132 deletions(-) create mode 100644 server/text_generation_server/utils/speculate.py diff --git a/proto/generate.proto b/proto/generate.proto index a0fb48b0..659c62ff 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -1,6 +1,6 @@ syntax = "proto3"; -package generate.v1; +package generate.v2; service TextGenerationService { /// Model Info @@ -32,6 +32,7 @@ message InfoResponse { string dtype = 2; string device_type = 3; optional uint32 window_size = 4; + optional uint32 speculate = 5; } /// Empty request diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 341e70fd..1560f19c 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -1,6 +1,6 @@ /// Single shard Client -use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient; -use crate::pb::generate::v1::*; +use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; +use crate::pb::generate::v2::*; use crate::Result; use grpc_metadata::InjectTelemetryContext; use std::cmp::min; diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 1ea5e365..c38b931b 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -6,9 +6,9 @@ mod pb; mod sharded_client; pub use client::Client; -pub use pb::generate::v1::HealthResponse; -pub use pb::generate::v1::InfoResponse as ShardInfo; -pub use pb::generate::v1::{ +pub use pb::generate::v2::HealthResponse; +pub use pb::generate::v2::InfoResponse as ShardInfo; +pub use pb::generate::v2::{ Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, }; diff --git a/router/src/infer.rs b/router/src/infer.rs index 53d71f89..224c8ae8 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -167,21 +167,21 @@ impl Infer { .collect(); } // Push last token - InferStreamResponse::Intermediate { tokens, top_tokens } => { - result_tokens.extend(tokens); - result_top_tokens.extend(top_tokens); + InferStreamResponse::Intermediate { token, top_tokens } => { + result_tokens.push(token); + result_top_tokens.push(top_tokens); } // Final message // Set return values InferStreamResponse::End { - tokens, + token, generated_text, start, queued, top_tokens, } => { - result_tokens.extend(tokens); - result_top_tokens.extend(top_tokens); + result_tokens.push(token); + result_top_tokens.push(top_tokens); result_generated_text = Some(generated_text); result_start = Some(start); result_queued = Some(queued) @@ -515,7 +515,6 @@ fn send_responses( let mut stopped = false; - tracing::info!("Generation: {:?}", generation); if let Some(prefill_tokens) = generation.prefill_tokens { // Send message entry @@ -524,68 +523,63 @@ fn send_responses( } // Create last Token - let tokens: Vec = if let Some(tokens_) = generation.tokens { - tokens_ + let tokens_ = generation.tokens.expect("Non empty tokens in generation"); + let n = tokens_.ids.len(); + metrics::histogram!( + "tgi_request_skipped_tokens", 
+ (n - 1) as f64 + ); + for (i, (((id, logprob), text), special)) in tokens_ .ids .into_iter() .zip(tokens_.logprobs.into_iter()) .zip(tokens_.texts.into_iter()) - .zip(tokens_.is_special.into_iter()) - .map(|(((id, logprob), text), special)| Token { + .zip(tokens_.is_special.into_iter()).enumerate() { + let token = Token { id, text, logprob, special, - }) - .collect() - } else { - vec![] - }; - - // generation.top_tokens - - let mut top_tokens = Vec::new(); - for top_tokens_ in generation.top_tokens { - let mut local_top_tokens = Vec::new(); - local_top_tokens.extend( + }; + let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i){ top_tokens_ .ids - .into_iter() - .zip(top_tokens_.logprobs.into_iter()) - .zip(top_tokens_.texts.into_iter()) - .zip(top_tokens_.is_special.into_iter()) - .map(|(((id, logprob), text), special)| Token { + .iter() + .zip(top_tokens_.logprobs.iter()) + .zip(top_tokens_.texts.iter()) + .zip(top_tokens_.is_special.iter()) + .map(|(((&id, &logprob), text), &special)| Token { id, - text, + text: text.to_string(), logprob, special, - }), - ); - top_tokens.push(local_top_tokens); - } - // Force top_tokens to be the same size as tokens, both are going to be - // zipped later - if top_tokens.len() != tokens.len() { - top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect(); + }).collect() + }else{ + vec![] + + }; + match (&generation.generated_text, i){ + (Some(generated_text), i) if i == n - 1 => { + // Generation has ended + stopped = true; + // Send message + entry.response_tx.send(Ok(InferStreamResponse::End { + token, + top_tokens, + generated_text: generated_text.clone(), + queued: entry.queue_time, + start: entry.batch_time.unwrap(), + }))?; + } + _ => { + // Send message + entry + .response_tx + .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?; + } + } } - if let Some(generated_text) = generation.generated_text { - // Generation has ended - stopped = true; - // Send message - entry.response_tx.send(Ok(InferStreamResponse::End { - tokens, - top_tokens, - generated_text, - queued: entry.queue_time, - start: entry.batch_time.unwrap(), - }))?; - } else { - // Send message - entry - .response_tx - .send(Ok(InferStreamResponse::Intermediate { tokens, top_tokens }))?; - } Ok(stopped) } @@ -613,13 +607,13 @@ pub(crate) enum InferStreamResponse { Prefill(Tokens), // Intermediate messages Intermediate { - tokens: Vec, - top_tokens: Vec>, + token: Token, + top_tokens: Vec, }, // Last message End { - tokens: Vec, - top_tokens: Vec>, + token: Token, + top_tokens: Vec, generated_text: GeneratedText, start: Instant, queued: Instant, diff --git a/router/src/server.rs b/router/src/server.rs index 789b47e4..f254afd8 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -388,31 +388,39 @@ async fn generate_stream( InferStreamResponse::Prefill(_) => {} // Yield event for every new token InferStreamResponse::Intermediate{ - tokens, + token, top_tokens, } => { + tracing::debug!(parent: &span, "Token: {:?}", token); - for (token, top_tokens) in tokens.into_iter().zip(top_tokens.into_iter()) { - // StreamResponse - let stream_token = StreamResponse { - token, - top_tokens, - generated_text: None, - details: None, - }; + // StreamResponse + let stream_token = StreamResponse { + token, + top_tokens, + generated_text: None, + details: None, + }; - yield Ok(Event::default().json_data(stream_token).unwrap()); - } + yield Ok(Event::default().json_data(stream_token).unwrap()) } // Yield event for last token and compute timings 
InferStreamResponse::End { - tokens, + token, generated_text, start, queued, top_tokens, } => { // Token details + let details = match details { + true => Some(StreamDetails { + finish_reason: FinishReason::from(generated_text.finish_reason), + generated_tokens: generated_text.generated_tokens, + seed: generated_text.seed, + }), + false => None, + }; + // Timings let total_time = start_time.elapsed(); let validation_time = queued - start_time; @@ -440,45 +448,22 @@ async fn generate_stream( // StreamResponse end_reached = true; - let n_tokens = tokens.len(); - for (i, (token, top_tokens)) in tokens.into_iter().zip(top_tokens.into_iter()).enumerate() { - // StreamResponse - let stream_token = if i < n_tokens - 1 { - StreamResponse { - token, - top_tokens, - generated_text: None, - details: None, - } - - }else{ - let details = match details { - true => Some(StreamDetails { - finish_reason: FinishReason::from(generated_text.finish_reason), - generated_tokens: generated_text.generated_tokens, - seed: generated_text.seed, - }), - false => None, - }; - let output_text = if let Some(prompt) = &add_prompt { - prompt.to_owned() + &generated_text.text - }else{ - generated_text.text.to_owned() - }; - - tracing::debug!(parent: &span, "Output: {}", output_text); - tracing::info!(parent: &span, "Success"); - - StreamResponse { - token, - top_tokens, - generated_text: Some(output_text), - details - } - }; - yield Ok(Event::default().json_data(stream_token).unwrap()); + let mut output_text = generated_text.text; + if let Some(prompt) = add_prompt { + output_text = prompt + &output_text; } + tracing::debug!(parent: &span, "Output: {}", output_text); + tracing::info!(parent: &span, "Success"); + + let stream_token = StreamResponse { + token, + top_tokens, + generated_text: Some(output_text), + details + }; + + yield Ok(Event::default().json_data(stream_token).unwrap()); break; } } diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 5000025c..4b1e2e54 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -6,6 +6,7 @@ from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import modeling_auto from typing import Optional +from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM @@ -77,9 +78,6 @@ except ImportError as e: if MISTRAL: __all__.append(FlashMistral) -SPECULATE = None - - def get_model( model_id: str, revision: Optional[str], @@ -89,7 +87,6 @@ def get_model( dtype: Optional[str], trust_remote_code: bool, ) -> Model: - global SPECULATE if dtype is None: # Keep it as default for now and let # every model resolve their own default dtype. 
@@ -101,7 +98,10 @@ def get_model( else: raise RuntimeError(f"Unknown dtype {dtype}") - SPECULATE = 2 + if speculate is not None: + set_speculate(speculate) + else: + set_speculate(2) if "facebook/galactica" in model_id: return GalacticaSharded( @@ -144,18 +144,22 @@ def get_model( medusa_config = config_dict model_id = config_dict["base_model_name_or_path"] revision = "main" - SPECULATE = config_dict["medusa_num_heads"] + speculate_medusa = config_dict["medusa_num_heads"] + if speculate is not None: + if speculate > speculate_medusa: + raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match") + else: + set_speculate(speculate) + else: + set_speculate(speculate_medusa) + config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) method = "medusa" else: - if speculate is not None: - SPECULATE = speculate - else: - SPECULATE = 2 method = "n-gram" - logger.info(f"Using speculation {method} with {SPECULATE} input ids.") + logger.info(f"Using speculation {method} with {get_speculate()} input ids.") model_type = config_dict["model_type"] diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index de946d21..4aa3637e 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -11,7 +11,8 @@ from opentelemetry import trace from transformers import PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Union, Dict -from text_generation_server.models import Model +from text_generation_server.models import Model +from text_generation_server.utils.speculate import get_speculate from text_generation_server.models.types import ( Batch, Tokens, @@ -192,8 +193,7 @@ class FlashCausalLMBatch(Batch): # Paged attention # Remove one as the first token des not have a past - from text_generation_server.models import SPECULATE - speculative_length = SPECULATE + speculative_length = get_speculate() total_tokens = input_length + max_new_tokens - 1 + speculative_length needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) blocks += needed_blocks @@ -483,7 +483,7 @@ class FlashCausalLMBatch(Batch): total_batch_size += len(b) total_slots += len(b.slots) blocks += b.blocks - speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[1] + speculative_length = b.speculative_ids.shape[1] max_blocks = max(max_blocks, b.max_blocks) max_seqlen = max(max_seqlen, b.max_seqlen) max_length = max( @@ -589,7 +589,7 @@ class FlashCausalLMBatch(Batch): device=batches[0].next_token_chooser.device, ) - speculative_ids = None if batches[0].speculative_ids is None else torch.cat([b.speculative_ids for b in batches], dim=0) + speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) # Needed to avoid dropping blocks when the batches will go out of scope for b in batches: @@ -825,16 +825,15 @@ class FlashCausalLM(Model): next_token_logits = out - from text_generation_server.models import SPECULATE next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, SPECULATE, batch.speculative_ids, speculative_logits + batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits ) batch_top_token_ids, batch_top_token_logprobs = 
batch_top_tokens( batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs ) - speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1] + speculative_length = speculative_ids.shape[1] if prefill: if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs @@ -1038,6 +1037,7 @@ class FlashCausalLM(Model): clean_up_tokenization_spaces=False, skip_special_tokens=False, ) + prefill_tokens = Tokens( prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = [] ) diff --git a/server/text_generation_server/utils/speculate.py b/server/text_generation_server/utils/speculate.py new file mode 100644 index 00000000..229f5b8f --- /dev/null +++ b/server/text_generation_server/utils/speculate.py @@ -0,0 +1,12 @@ + +SPECULATE = None + +def get_speculate(): + global SPECULATE + return SPECULATE + +def set_speculate(speculate: int): + global SPECULATE + SPECULATE = speculate + +
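For reference, the value configured through the new server/text_generation_server/utils/speculate.py module is a process-wide speculation budget: get_model() calls set_speculate() once at load time, and flash_causal_lm.py reads it back through get_speculate() instead of importing the old SPECULATE global from models/__init__.py. A minimal usage sketch (illustration only, not part of the patch):

# Illustration only: how the new speculate helpers from this patch are meant to be used.
from text_generation_server.utils.speculate import get_speculate, set_speculate

set_speculate(2)          # done once in get_model(), e.g. the default when --speculate is unset
budget = get_speculate()  # read in flash_causal_lm.py in place of the removed SPECULATE global
assert budget == 2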
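The precedence between an explicit --speculate value, a medusa checkpoint's medusa_num_heads, and the built-in default of 2 is spread across two hunks of get_model(). The helper below is hypothetical (resolve_speculate is not a name used by the patch); it only condenses that precedence into one place as a sketch of the intended behaviour:

# Hypothetical helper summarizing the value get_model() ends up passing to set_speculate():
# an explicit speculate wins but may not exceed the number of medusa heads; otherwise the
# medusa head count is used; without a medusa config the default is 2.
from typing import Optional

def resolve_speculate(speculate: Optional[int], medusa_num_heads: Optional[int]) -> int:
    if medusa_num_heads is not None:
        if speculate is not None:
            if speculate > medusa_num_heads:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has "
                    f"`{medusa_num_heads}` heads, please make them match"
                )
            return speculate
        return medusa_num_heads
    return speculate if speculate is not None else 2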
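In FlashCausalLMBatch, the per-request paged-attention reservation now has to leave room for the speculative tokens, hence total_tokens = input_length + max_new_tokens - 1 + speculative_length. A worked example of that arithmetic, assuming a BLOCK_SIZE of 16 (the value this sketch takes for TGI's paged-attention block size):

# Worked example (illustration only) of the block-sizing change in FlashCausalLMBatch.
# BLOCK_SIZE = 16 is an assumption made for this sketch.
import math

BLOCK_SIZE = 16
input_length = 20          # prompt tokens
max_new_tokens = 50
speculative_length = 2     # get_speculate()

# "Remove one as the first token does not have a past"
total_tokens = input_length + max_new_tokens - 1 + speculative_length
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
print(total_tokens, needed_blocks)  # 71 tokens -> 5 blocks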
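On the router side, send_responses() no longer carries a Vec<Token> per message: each accepted token becomes its own InferStreamResponse::Intermediate event, the last one becomes InferStreamResponse::End when generated_text is set, and the tgi_request_skipped_tokens histogram records the n - 1 extra tokens accepted in that decoding step. A rough Python model of that fan-out (illustration only; the real code lives in router/src/infer.rs):

# Rough model of the per-token fan-out in send_responses(): every accepted token is emitted
# as an "Intermediate" event except the last, which becomes the "End" event once the request
# finished on this step. The return value mirrors
# metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64).
def fan_out(tokens, generated_text, emit):
    n = len(tokens)
    skipped = n - 1
    for i, token in enumerate(tokens):
        if generated_text is not None and i == n - 1:
            emit(("End", token, generated_text))
        else:
            emit(("Intermediate", token))
    return skipped

events = []
fan_out(["The", " quick", " fox"], "The quick fox", events.append)
# -> two Intermediate events followed by one End event; skipped == 2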