From b9dffbd5121e0d3600b5bd42f0307404815eda70 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 3 Jun 2024 15:50:37 +0200
Subject: [PATCH] python now uses v3

---
 benchmark/src/generation.rs                 |   7 +-
 benchmark/src/lib.rs                        |   2 +-
 benchmark/src/main.rs                       |   2 +-
 router/client/src/lib.rs                    |   4 +-
 router/src/infer/mod.rs                     |  28 +-
 router/src/infer/v2/mod.rs                  |   2 +-
 router/src/infer/v2/queue.rs                |  14 +-
 router/src/infer/v2/scheduler.rs            |  36 +-
 router/src/infer/v3/mod.rs                  |   5 +-
 router/src/infer/v3/queue.rs                |  13 +-
 .../src/infer/v3/{infer.rs => scheduler.rs} | 548 ++----------------
 router/src/lib.rs                           |   1 -
 router/src/main.rs                          |   2 -
 router/src/server.rs                        | 242 +++++---
 server/Makefile                             |   4 +-
 15 files changed, 239 insertions(+), 671 deletions(-)
 rename router/src/infer/v3/{infer.rs => scheduler.rs} (76%)

diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index 8c07e62b..f49d786a 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -1,8 +1,9 @@
 use std::time::{Duration, Instant};
-use text_generation_client::{
-    Batch, CachedBatch, Chunk, ClientError, Input, NextTokenChooserParameters, Request,
-    ShardedClient, StoppingCriteriaParameters,
+use text_generation_client::v2::{
+    Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,
+    StoppingCriteriaParameters,
 };
+use text_generation_client::ClientError;
 use tokenizers::{Tokenizer, TruncationDirection};
 use tokio::sync::{broadcast, mpsc};
 
diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs
index 638c6514..048c86af 100644
--- a/benchmark/src/lib.rs
+++ b/benchmark/src/lib.rs
@@ -8,7 +8,7 @@ use crate::app::App;
 use crate::event::Event;
 use crossterm::ExecutableCommand;
 use std::io;
-use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
+use text_generation_client::v2::{GrammarType, NextTokenChooserParameters, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::sync::{broadcast, mpsc};
 use tui::backend::CrosstermBackend;
diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs
index 2d89e045..1e45c1dd 100644
--- a/benchmark/src/main.rs
+++ b/benchmark/src/main.rs
@@ -4,7 +4,7 @@
 /// and: https://github.com/orhun/rust-tui-template
 use clap::Parser;
 use std::path::Path;
-use text_generation_client::ShardedClient;
+use text_generation_client::v2::ShardedClient;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index c0c1274a..0663e301 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -47,9 +47,7 @@ impl From<Status> for ClientError {
 
 impl From<transport::Error> for ClientError {
     fn from(err: transport::Error) -> Self {
-        let err = Self::Connection(err.to_string());
-        tracing::error!("{err}");
-        err
+        Self::Connection(err.to_string())
     }
 }
 
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index db5f4943..20630c1b 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -1,36 +1,35 @@
 mod health;
 pub(crate) mod v2;
-// pub(crate) mod v3;
+pub(crate) mod v3;
 
 pub(crate) use health::HealthCheck;
 
-use crate::validation::{Validation, ValidationError, ValidGenerateRequest};
+use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
 use crate::{
-    ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest,
-    HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk,
-    PrefillToken, Text, TextMessage, Token,
+    ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
+    HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
 };
 use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
 use futures::future::try_join_all;
 use minijinja::{Environment, ErrorKind, Template};
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;
-use std::sync::{
-    Arc,
-};
+use std::sync::Arc;
 use thiserror::Error;
 use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
-use tracing::{instrument};
-
+use tracing::instrument;
 
 pub(crate) trait Scheduler {
-    fn schedule(&self, request: ValidGenerateRequest, permit: OwnedSemaphorePermit) -> Result<GenerateStreamResponse, InferError>;
+    fn schedule(
+        &self,
+        request: ValidGenerateRequest,
+        permit: OwnedSemaphorePermit,
+    ) -> Result<GenerateStreamResponse, InferError>;
 }
-
 /// Inference struct
 #[derive(Clone)]
 pub struct Infer {
@@ -44,8 +43,6 @@ pub struct Infer {
     limit_concurrent_requests: Arc<Semaphore>,
 }
 
-
-
 impl Infer {
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
@@ -462,8 +459,6 @@ pub(crate) struct GeneratedText {
     pub(crate) seed: Option<u64>,
 }
 
-
-
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
@@ -525,4 +520,3 @@ impl InferError {
         }
     }
 }
-
diff --git a/router/src/infer/v2/mod.rs b/router/src/infer/v2/mod.rs
index f87d863c..8b4f6bab 100644
--- a/router/src/infer/v2/mod.rs
+++ b/router/src/infer/v2/mod.rs
@@ -1,4 +1,4 @@
-mod scheduler;
 mod queue;
+mod scheduler;
 
 pub(crate) use scheduler::SchedulerV2;
diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs
index 057b1804..4a041ea7 100644
--- a/router/src/infer/v2/queue.rs
+++ b/router/src/infer/v2/queue.rs
@@ -1,5 +1,7 @@
 use crate::infer::{InferError, InferStreamResponse};
-use crate::validation::{ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters};
+use crate::validation::{
+    ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
+};
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
@@ -400,9 +402,6 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use text_generation_client::{
-        GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
-    };
     use tracing::info_span;
 
     fn default_entry() -> (
@@ -417,7 +416,7 @@ mod tests {
             input_length: 0,
             truncate: 0,
             decoder_input_details: false,
-            parameters: NextTokenChooserParameters {
+            parameters: ValidParameters {
                 temperature: 0.0,
                 top_k: 0,
                 top_p: 0.0,
@@ -427,10 +426,9 @@ mod tests {
                 repetition_penalty: 0.0,
                 frequency_penalty: 0.0,
                 watermark: false,
-                grammar: String::new(),
-                grammar_type: ProtoGrammarType::None as i32,
+                grammar: None,
             },
-            stopping_parameters: StoppingCriteriaParameters {
+            stopping_parameters: ValidStoppingParameters {
                 ignore_eos_token: false,
                 max_new_tokens: 1,
                 stop_sequences: vec![],
diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs
index f0815168..ba6f520d 100644
--- a/router/src/infer/v2/scheduler.rs
+++ b/router/src/infer/v2/scheduler.rs
@@ -1,6 +1,9 @@
 /// Batching and inference logic
-
-use crate::infer::v2::queue::{Queue, Entry};
+use crate::infer::v2::queue::{Entry, Queue};
+use crate::infer::{
+    GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
+};
+use crate::validation::ValidGenerateRequest;
 use crate::{FinishReason, PrefillToken, Token};
 use nohash_hasher::IntMap;
 use std::sync::{
@@ -8,14 +11,12 @@ use std::sync::{
     Arc,
 };
 use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient};
-use text_generation_client::{ClientError};
+use text_generation_client::ClientError;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{info_span, instrument, Instrument, Span};
-use crate::infer::{GeneratedText, GenerateStreamResponse, InferError, InferStreamResponse, Scheduler};
-use crate::validation::ValidGenerateRequest;
 
 pub(crate) struct SchedulerV2 {
     /// Request queue
     queue: Queue,
@@ -25,6 +26,7 @@ pub(crate) struct SchedulerV2 {
 }
 
 impl SchedulerV2 {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         client: ShardedClient,
         waiting_served_ratio: f32,
@@ -55,14 +57,18 @@ impl SchedulerV2 {
 
         Self {
             queue,
-            batching_task_notifier
+            batching_task_notifier,
         }
     }
 }
 
 impl Scheduler for SchedulerV2 {
     #[instrument(skip_all)]
-    fn schedule(&self, request: ValidGenerateRequest, permit: OwnedSemaphorePermit) -> Result<GenerateStreamResponse, InferError> {
+    fn schedule(
+        &self,
+        request: ValidGenerateRequest,
+        permit: OwnedSemaphorePermit,
+    ) -> Result<GenerateStreamResponse, InferError> {
         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();
         let input_length = request.input_length;
@@ -90,7 +96,6 @@ impl Scheduler for SchedulerV2 {
     }
 }
 
-
 /// Batching logic
 /// Will be launched in a background Tokio task
 ///
@@ -381,8 +386,8 @@ fn send_responses(
         let prefill_tokens = prefill_tokens
             .ids
             .into_iter()
-            .zip(prefill_tokens.logprobs.into_iter())
-            .zip(prefill_tokens.texts.into_iter())
+            .zip(prefill_tokens.logprobs)
+            .zip(prefill_tokens.texts)
             .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
             .collect();
 
@@ -473,7 +478,8 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 
 impl From<text_generation_client::v2::GeneratedText> for GeneratedText {
     fn from(value: text_generation_client::v2::GeneratedText) -> Self {
-        let v2_finish_reason = text_generation_client::v2::FinishReason::try_from(value.finish_reason).unwrap();
+        let v2_finish_reason =
+            text_generation_client::v2::FinishReason::try_from(value.finish_reason).unwrap();
         let finish_reason = match v2_finish_reason {
             text_generation_client::v2::FinishReason::Length => FinishReason::Length,
             text_generation_client::v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
@@ -767,10 +773,10 @@ mod tests {
             content: "You are a friendly chatbot who always responds in the style of a pirate"
                 .to_string(),
         }]
-            .iter()
-            .chain(&example_chat)
-            .cloned()
-            .collect::<Vec<_>>();
+        .iter()
+        .chain(&example_chat)
+        .cloned()
+        .collect::<Vec<_>>();
 
         let test_default_templates = vec![
             ChatTemplateTestItem {
diff --git a/router/src/infer/v3/mod.rs b/router/src/infer/v3/mod.rs
index 101f7b60..4299baf3 100644
--- a/router/src/infer/v3/mod.rs
+++ b/router/src/infer/v3/mod.rs
@@ -1,5 +1,4 @@
-mod infer;
 mod queue;
+mod scheduler;
 
-pub(crate) use infer::{Infer, InferError, InferStreamResponse, InferResponse, ToolGrammar};
-pub(crate) use queue::{Entry, Queue};
+pub(crate) use scheduler::SchedulerV3;
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs
index a12bf0ff..f13cf936 100644
--- a/router/src/infer/v3/queue.rs
+++ b/router/src/infer/v3/queue.rs
@@ -1,4 +1,4 @@
-use crate::infer::v3::{InferError, InferStreamResponse};
+use crate::infer::{InferError, InferStreamResponse};
 use crate::validation::{
     ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
@@ -58,7 +58,6 @@ impl Queue {
         Self { queue_sender }
     }
 
-    /// Append an entry to the queue
     #[instrument(skip_all)]
     pub(crate) fn append(&self, entry: Entry) {
         // Send append command to the background task managing the state
@@ -397,9 +396,6 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use text_generation_client::{
-        GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
-    };
     use tracing::info_span;
 
     fn default_entry() -> (
@@ -414,7 +410,7 @@ mod tests {
             input_length: 0,
             truncate: 0,
             decoder_input_details: false,
-            parameters: NextTokenChooserParameters {
+            parameters: ValidParameters {
                 temperature: 0.0,
                 top_k: 0,
                 top_p: 0.0,
@@ -424,10 +420,9 @@ mod tests {
                 repetition_penalty: 0.0,
                 frequency_penalty: 0.0,
                 watermark: false,
-                grammar: String::new(),
-                grammar_type: ProtoGrammarType::None as i32,
+                grammar: None,
             },
-            stopping_parameters: StoppingCriteriaParameters {
+            stopping_parameters: ValidStoppingParameters {
                 ignore_eos_token: false,
                 max_new_tokens: 1,
                 stop_sequences: vec![],
diff --git a/router/src/infer/v3/infer.rs b/router/src/infer/v3/scheduler.rs
similarity index 76%
rename from router/src/infer/v3/infer.rs
rename to router/src/infer/v3/scheduler.rs
index c0522e9d..257d191f 100644
--- a/router/src/infer/v3/infer.rs
+++ b/router/src/infer/v3/scheduler.rs
@@ -1,80 +1,46 @@
 /// Batching and inference logic
-
-use crate::infer::v3::{Queue, Entry};
-use crate::validation::{Validation, ValidationError};
-use crate::{
-    ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest,
-    HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk,
-    PrefillToken, Text, TextMessage, Token,
+use crate::infer::v3::queue::{Entry, Queue};
+use crate::infer::{
+    GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
 };
-use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
-use futures::future::try_join_all;
-use minijinja::{Environment, ErrorKind, Template};
+use crate::validation::ValidGenerateRequest;
+use crate::{FinishReason, PrefillToken, Token};
 use nohash_hasher::IntMap;
-use serde_json::{json, Map, Value};
-use std::collections::HashMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
 use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
-use text_generation_client::{v3, ClientError};
-use thiserror::Error;
+use text_generation_client::ClientError;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tokio_stream::StreamExt;
 use tracing::{info_span, instrument, Instrument, Span};
 
-/// Inference struct
-#[derive(Clone)]
-pub struct Infer {
-    /// Validation
-    validation: Validation,
+pub(crate) struct SchedulerV3 {
     /// Request queue
     queue: Queue,
-    /// Shared state
-    shared: Arc<Shared>,
-    /// Chat template
-    chat_template: Option<ChatTemplate>,
-    /// Inference limit
-    limit_concurrent_requests: Arc<Semaphore>,
+    /// Notify batcher on queue appends
+    batching_task_notifier: Arc<Notify>,
 }
 
-/// Infer shared state
-struct Shared {
-    /// Batching background Tokio task notifier
-    batching_task: Notify,
-}
-
-/// Raise a exception (custom function) used in the chat templates
-fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
-    Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
-}
-
-impl Infer {
+impl SchedulerV3 {
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         client: ShardedClient,
-        validation: Validation,
         waiting_served_ratio: f32,
         max_batch_prefill_tokens: u32,
         max_batch_total_tokens: u32,
         max_waiting_tokens: usize,
         max_batch_size: Option<usize>,
-        max_concurrent_requests: usize,
         requires_padding: bool,
         window_size: Option<u32>,
         speculate: u32,
         generation_health: Arc<AtomicBool>,
-        tokenizer_config: HubTokenizerConfig,
-        processor_config: HubProcessorConfig,
     ) -> Self {
         let queue = Queue::new(requires_padding, 16, window_size, speculate);
-        let shared = Arc::new(Shared {
-            batching_task: Notify::new(),
-        });
+        let batching_task_notifier = Arc::new(Notify::new());
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
@@ -85,72 +51,31 @@ impl Infer {
             max_waiting_tokens,
             max_batch_size,
             queue.clone(),
-            shared.clone(),
+            batching_task_notifier.clone(),
             generation_health,
         ));
 
-        let chat_template = tokenizer_config
-            .chat_template
-            .or(processor_config.chat_template)
-            .and_then(|t| match t {
-                ChatTemplateVersions::Single(template) => Some(template),
-                ChatTemplateVersions::Multiple(templates) => templates
-                    .into_iter()
-                    .find(|t| t.name == "default")
-                    .map(|t| t.template),
-            })
-            .map(|t| {
-                // .strip() is not supported in minijinja
-                // .capitalize() is not supported in minijinja but we can use | capitalize
-                let t = t
-                    .replace(".strip()", " | trim")
-                    .replace(".capitalize()", " | capitalize");
-                ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
-            });
-
-        // Inference limit with a semaphore
-        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
-
         Self {
-            validation,
             queue,
-            shared,
-            chat_template,
-            limit_concurrent_requests: semaphore,
+            batching_task_notifier,
         }
     }
+}
 
-    /// Add a new request to the queue and return a stream of InferStreamResponse
+impl Scheduler for SchedulerV3 {
     #[instrument(skip_all)]
-    pub(crate) async fn generate_stream(
+    fn schedule(
         &self,
-        request: GenerateRequest,
+        request: ValidGenerateRequest,
+        permit: OwnedSemaphorePermit,
     ) -> Result<GenerateStreamResponse, InferError> {
-        // Limit concurrent requests by acquiring a permit from the semaphore
-        let permit = self
-            .clone()
-            .limit_concurrent_requests
-            .try_acquire_owned()
-            .map_err(|err| {
-                metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
-                tracing::error!("{err}");
-                err
-            })?;
-
-        // Validate request
-        let valid_request = self.validation.validate(request).await.map_err(|err| {
-            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
-            tracing::error!("{err}");
-            err
-        })?;
-
         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();
-        let input_length = valid_request.input_length;
+        let input_length = request.input_length;
 
         // Append the request to the queue
         self.queue.append(Entry {
-            request: valid_request,
+            request,
             response_tx,
             span: Span::current(),
             temp_span: None,
@@ -160,7 +85,7 @@ impl Infer {
 
         // Notify the background task that we have a new entry in the queue that needs
         // to be batched
-        self.shared.batching_task.notify_one();
+        self.batching_task_notifier.notify_one();
 
         // Return stream
         Ok((
@@ -169,335 +94,6 @@ impl Infer {
             UnboundedReceiverStream::new(response_rx),
         ))
     }
-
-    /// Tokenizer the input
-    #[instrument(skip_all)]
-    pub(crate) async fn tokenize(
-        &self,
-        request: GenerateRequest,
-    ) -> Result<Option<tokenizers::Encoding>, InferError> {
-        // Tokenize request
-        let inputs = request.inputs;
-        let truncate = request.parameters.truncate;
-        let encoding = self
-            .validation
-            .tokenize(inputs, truncate)
-            .await
-            .map_err(|err| {
-                tracing::error!("Tokenization {err}");
-                err
-            })?;
-
-        // Return Encoding
-        Ok(encoding.map(|(encoding, _)| encoding))
-    }
-
-    /// Apply the chat template to the chat request
-    #[instrument(skip_all)]
-    pub(crate) fn apply_chat_template(
-        &self,
-        messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
-    ) -> Result<String, InferError> {
-        self.chat_template
-            .as_ref()
-            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(messages, grammar_with_prompt)
-            .map_err(|e| {
-                metrics::increment_counter!("tgi_request_failure", "err" => "template");
-                tracing::error!("{e}");
-                e
-            })
-    }
-
-    /// Add a new request to the queue and return a InferResponse
-    #[instrument(skip_all)]
-    pub(crate) async fn generate(
-        &self,
-        request: GenerateRequest,
-    ) -> Result<InferResponse, InferError> {
-        let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
-
-        // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
-
-        // Return values
-        let mut result_prefill = Vec::new();
-        let mut result_tokens = Vec::new();
-        let mut result_top_tokens = Vec::new();
-        let mut result_generated_text = None;
-        let mut result_start = None;
-        let mut result_queued = None;
-
-        // Iterate on stream
-        while let Some(response) = stream.next().await {
-            match response? {
-                // Add prefill tokens
-                InferStreamResponse::Prefill(prefill_tokens) => {
-                    result_prefill = prefill_tokens;
-                }
-                // Push last token
-                InferStreamResponse::Intermediate { token, top_tokens } => {
-                    result_tokens.push(token);
-                    result_top_tokens.push(top_tokens);
-                }
-                // Final message
-                // Set return values
-                InferStreamResponse::End {
-                    token,
-                    generated_text,
-                    start,
-                    queued,
-                    top_tokens,
-                } => {
-                    result_tokens.push(token);
-                    result_top_tokens.push(top_tokens);
-                    result_generated_text = Some(generated_text);
-                    result_start = Some(start);
-                    result_queued = Some(queued)
-                }
-            }
-        }
-
-        // Check that we received a `InferStreamResponse::End` message
-        if let (Some(generated_text), Some(queued), Some(start)) =
-            (result_generated_text, result_queued, result_start)
-        {
-            Ok(InferResponse {
-                prefill: result_prefill,
-                _input_length,
-                tokens: result_tokens,
-                generated_text,
-                queued,
-                start,
-                top_tokens: if use_top_tokens {
-                    result_top_tokens
-                } else {
-                    Vec::new()
-                },
-            })
-        } else {
-            let err = InferError::IncompleteGeneration;
-            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
-            tracing::error!("{err}");
-            Err(err)
-        }
-    }
-    /// Add best_of new requests to the queue and return a InferResponse of the sequence with
-    /// the highest log probability per token
-    #[instrument(skip(self, request))]
-    pub(crate) async fn generate_best_of(
-        &self,
-        request: GenerateRequest,
-        best_of: usize,
-    ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
-        // validate best_of parameter separately
-        let best_of = self.validation.validate_best_of(best_of)?;
-
-        // create multiple generate requests
-        let mut infer_responses: Vec<InferResponse> =
-            try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
-
-        // get the sequence with the highest log probability per token
-        let mut max_index = 0;
-        let mut max_logprob: f32 = f32::MIN;
-
-        for (i, response) in infer_responses.iter().enumerate() {
-            // mean logprobs of the generated tokens
-            let sequence_logprob = response
-                .tokens
-                .iter()
-                .map(|token| token.logprob)
-                .sum::<f32>()
-                / response.tokens.len() as f32;
-
-            // set best sequence
-            if sequence_logprob > max_logprob {
-                max_index = i;
-                max_logprob = sequence_logprob;
-            }
-        }
-        let best_response = infer_responses.remove(max_index);
-        Ok((best_response, infer_responses))
-    }
-}
-
-#[derive(Clone)]
-struct ChatTemplate {
-    template: Template<'static, 'static>,
-    bos_token: Option<String>,
-    eos_token: Option<String>,
-    use_default_tool_template: bool,
-}
-
-impl ChatTemplate {
-    fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
-        let mut env = Box::new(Environment::new());
-        let template_str = template.into_boxed_str();
-        env.add_function("raise_exception", raise_exception);
-
-        // check if contains the tools variable within the template
-        let use_default_tool_template =
-            !template_str.as_ref().replace(' ', "").contains("{{tools}}");
-        // leaking env and template_str as read-only, static resources for performance.
-        let template = Box::leak(env)
-            .template_from_str(Box::leak(template_str))
-            .unwrap();
-
-        Self {
-            template,
-            bos_token,
-            eos_token,
-            use_default_tool_template,
-        }
-    }
-
-    fn apply(
-        &self,
-        mut messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
-    ) -> Result<String, InferError> {
-        if self.use_default_tool_template {
-            if let Some(last_message) = messages.last_mut() {
-                if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
-                    last_message.content.push(MessageChunk::Text(Text {
-                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
-                    }));
-                }
-            }
-        }
-
-        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
-
-        self.template
-            .render(ChatTemplateInputs {
-                messages,
-                bos_token: self.bos_token.as_deref(),
-                eos_token: self.eos_token.as_deref(),
-                add_generation_prompt: true,
-                tools: None,
-                tools_prompt: None,
-            })
-            .map_err(InferError::TemplateError)
-    }
-}
-
-pub struct ToolGrammar {}
-
-impl ToolGrammar {
-    pub fn apply(
-        tools: Option<Vec<Tool>>,
-        tool_choice: Option<ToolType>,
-    ) -> Result<Option<Tools>, InferError> {
-        if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
-            // let tool_prompt = tool_prompt.unwrap_or_default();
-            let tools_to_use = match tool_choice {
-                ToolType::FunctionName(name) => {
-                    vec![req_tools
-                        .iter()
-                        .find(|tool| tool.function.name == *name)
-                        .unwrap_or_else(|| panic!("Tool with name {} not found", name))
-                        .clone()]
-                }
-                ToolType::OneOf => req_tools.to_owned(),
-            };
-
-            // adds the error notification function for LLM feedback if required
-            let mut text_response_properties = Map::new();
-            text_response_properties.insert(
-                "error".to_string(),
-                serde_json::json!({
-                    "type": "string",
-                    "description": "The error or issue to notify"
-                }),
-            );
-            text_response_properties.insert(
-                "_name".to_string(),
-                serde_json::json!({
-                    "type": "string",
-                    "const": "notify_error"
-                }),
-            );
-
-            let functions: HashMap<String, Value> = tools_to_use
-                .iter()
-                .map(|tool| {
-                    let func = tool.function.clone();
-
-                    // Clone the existing parameters, which are expected to be a JSON object
-                    let mut params = if let Value::Object(params) = &func.arguments {
-                        params.clone()
-                    } else {
-                        Map::new()
-                    };
-
-                    // Insert the function's description at the top level, outside of properties
-                    params.insert(
-                        "description".to_string(),
-                        Value::String(func.description.clone().unwrap_or_default()),
-                    );
-
-                    // Ensure 'properties' exists and is an object
-                    let properties = params
-                        .entry("properties".to_string())
-                        .or_insert_with(|| json!({}))
-                        .as_object_mut()
-                        .unwrap();
-
-                    // Insert the constant for the function name inside 'properties'
-                    properties.insert(
-                        "_name".to_string(),
-                        json!({
-                            "type": "string",
-                            "const": func.name.clone(),
-                            // "description": "The name of the function"
-                        }),
-                    );
-
-                    // Check if 'required' exists, and it is an array. If not, create an empty array.
-                    let required = params
-                        .entry("required".to_string())
-                        .or_insert_with(|| json!([]))
-                        .as_array_mut()
-                        .unwrap();
-
-                    // Add 'name' to the 'required' array if it is not already present
-                    if !required.iter().any(|r| r == "_name") {
-                        required.push(json!("_name"));
-                    }
-
-                    (func.name, Value::Object(params))
-                })
-                .chain([(
-                    "notify_error".to_string(),
-                    serde_json::json!({
-                        "properties": text_response_properties,
-                        "required": ["error", "_name"],
-                        "type": "object"
-                    }),
-                )])
-                .collect();
-
-            let tools = Tools {
-                functions_map: FunctionsMap { functions },
-                properties: Properties {
-                    function: tools_to_use
-                        .iter()
-                        .map(|tool| FunctionRef {
-                            ref_path: format!("#/$functions/{}", tool.function.name.clone()),
-                        })
-                        .chain(std::iter::once(FunctionRef {
-                            ref_path: "#/$functions/notify_error".to_string(),
-                        }))
-                        .collect(),
-                },
-            };
-
-            return Ok(Some(tools));
-        }
-        // Err(InferError::ToolError("No tools provided".to_string()))
-        Ok(None)
-    }
-}
 }
 
 /// Batching logic
@@ -505,7 +101,7 @@ impl ToolGrammar {
 ///
 /// Batches requests and sends them to the inference server
 #[allow(clippy::too_many_arguments)]
-async fn batching_task(
+pub(crate) async fn batching_task(
     mut client: ShardedClient,
     waiting_served_ratio: f32,
     max_batch_prefill_tokens: u32,
@@ -513,13 +109,13 @@ async fn batching_task(
     max_waiting_tokens: usize,
     max_batch_size: Option<usize>,
     queue: Queue,
-    shared: Arc<Shared>,
+    notifier: Arc<Notify>,
     generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
         // Wait for a notification from the Infer struct
-        shared.batching_task.notified().await;
+        notifier.notified().await;
 
         // Get the next batch from the queue
         // This batch might be smaller than the maximum batch size if there are not enough requests
@@ -767,8 +386,8 @@ fn send_responses(
         let prefill_tokens = prefill_tokens
             .ids
             .into_iter()
-            .zip(prefill_tokens.logprobs.into_iter())
-            .zip(prefill_tokens.texts.into_iter())
+            .zip(prefill_tokens.logprobs)
+            .zip(prefill_tokens.texts)
             .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
            .collect();
 
@@ -880,28 +476,14 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
     });
 }
 
-/// Type alias for generation responses
-pub(crate) type GenerateStreamResponse = (
-    OwnedSemaphorePermit,
-    u32, // input_length
-    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-);
-
-#[derive(Debug)]
-pub(crate) struct GeneratedText {
-    pub(crate) text: String,
-    pub(crate) generated_tokens: u32,
-    pub(crate) finish_reason: FinishReason,
-    pub(crate) seed: Option<u64>,
-}
-
-impl From<v3::GeneratedText> for GeneratedText {
-    fn from(value: v3::GeneratedText) -> Self {
-        let v3_finish_reason = v3::FinishReason::try_from(value.finish_reason).unwrap();
+impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
+    fn from(value: text_generation_client::v3::GeneratedText) -> Self {
+        let v3_finish_reason =
+            text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
         let finish_reason = match v3_finish_reason {
-            v3::FinishReason::Length => FinishReason::Length,
-            v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
-            v3::FinishReason::StopSequence => FinishReason::StopSequence,
+            text_generation_client::v3::FinishReason::Length => FinishReason::Length,
+            text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
+            text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
         };
 
         Self {
@@ -913,68 +495,6 @@ impl From<v3::GeneratedText> for GeneratedText {
     }
 }
 
-#[derive(Debug)]
-pub(crate) enum InferStreamResponse {
-    // Optional first message
-    Prefill(Vec<PrefillToken>),
-    // Intermediate messages
-    Intermediate {
-        token: Token,
-        top_tokens: Vec<Token>,
-    },
-    // Last message
-    End {
-        token: Token,
-        top_tokens: Vec<Token>,
-        generated_text: GeneratedText,
-        start: Instant,
-        queued: Instant,
-    },
-}
-
-#[derive(Debug)]
-pub(crate) struct InferResponse {
-    /// input_length is the input as perceived by the rust tokenizer in the
-    /// validation pathway. It is redundant with prefill.len() but prefill
-    /// has data only if the user asked for it. This will always be filled.
-    pub(crate) _input_length: u32,
-    pub(crate) prefill: Vec<PrefillToken>,
-    pub(crate) tokens: Vec<Token>,
-    pub(crate) generated_text: GeneratedText,
-    pub(crate) queued: Instant,
-    pub(crate) start: Instant,
-    pub(crate) top_tokens: Vec<Vec<Token>>,
-}
-
-#[derive(Debug, Error)]
-pub enum InferError {
-    #[error("Request failed during generation: {0}")]
-    GenerationError(String),
-    #[error("Model is overloaded")]
-    Overloaded(#[from] TryAcquireError),
-    #[error("Input validation error: {0}")]
-    ValidationError(#[from] ValidationError),
-    #[error("Incomplete generation")]
-    IncompleteGeneration,
-    #[error("Template error: {0}")]
-    TemplateError(#[from] minijinja::Error),
-    #[error("Tool error: {0}")]
-    ToolError(String),
-}
-
-impl InferError {
-    pub(crate) fn error_type(&self) -> &str {
-        match self {
-            InferError::GenerationError(_) => "generation",
-            InferError::Overloaded(_) => "overloaded",
-            InferError::ValidationError(_) => "validation",
-            InferError::IncompleteGeneration => "incomplete_generation",
-            InferError::TemplateError(_) => "template_error",
-            InferError::ToolError(_) => "tool_error",
-        }
-    }
-}
-
 // tests
 #[cfg(test)]
 mod tests {
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 4a1eefb8..b6902c49 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1,5 +1,4 @@
 /// Text Generation Inference Webserver
-
 pub mod config;
 mod infer;
 pub mod server;
diff --git a/router/src/main.rs b/router/src/main.rs
index 277ad8b0..c4203dbc 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -314,8 +314,6 @@ async fn main() -> Result<(), RouterError> {
         Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
     };
 
-
-
     // Determine the server port based on the feature and environment variable.
     let port = if cfg!(feature = "google") {
         std::env::var("AIP_HTTP_PORT")
diff --git a/router/src/server.rs b/router/src/server.rs
index 2bfb9aa8..30479b0e 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,14 +1,15 @@
 /// HTTP Server logic
-
 use crate::config::Config;
-use crate::infer::HealthCheck;
+use crate::infer::v2::SchedulerV2;
+use crate::infer::v3::SchedulerV3;
+use crate::infer::{HealthCheck, Scheduler};
 use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-    GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig,
-    Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token,
-    TokenizeResponse, Usage, Validation,
+    GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
+    Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
+    Usage, Validation,
 };
 use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -35,7 +36,8 @@ use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
-use text_generation_client::{v2::ShardedClient, ClientError};
+use text_generation_client::{v2, v3, ClientError, ShardInfo};
+use thiserror::Error;
 use tokenizers::Tokenizer;
 use tokio::select;
 use tokio::signal;
@@ -45,8 +47,6 @@ use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
-use thiserror::Error;
-use crate::infer::v2::SchedulerV2;
 
 /// Generate tokens if `stream == false` or a stream of token if `stream == true`
 #[utoipa::path(
@@ -218,9 +218,7 @@ async fn generate_internal(
                     BestOfSequence {
                         generated_text: output_text,
-                        finish_reason: FinishReason::from(
-                            response.generated_text.finish_reason,
-                        ),
+                        finish_reason: response.generated_text.finish_reason,
                         generated_tokens: response.generated_text.generated_tokens,
                         prefill: response.prefill,
                         tokens: response.tokens,
@@ -232,7 +230,7 @@ async fn generate_internal(
                 });
 
             Some(Details {
-                finish_reason: FinishReason::from(response.generated_text.finish_reason),
+                finish_reason: response.generated_text.finish_reason,
                 generated_tokens: response.generated_text.generated_tokens,
                 prefill: response.prefill,
                 tokens: response.tokens,
@@ -374,7 +372,7 @@ async fn generate_stream(
     Json(req): Json<GenerateRequest>,
 ) -> (
     HeaderMap,
-    Sse>>,
+    Sse>>,
 ) {
     let span = tracing::Span::current();
     let on_message_callback = |stream_token: StreamResponse| {
@@ -393,7 +391,7 @@ async fn generate_stream_internal(
     Json(req): Json<GenerateRequest>,
     on_message_callback: impl Fn(StreamResponse) -> Event,
     span: tracing::Span,
-) -> (HeaderMap, impl Stream>) {
+) -> (HeaderMap, impl Stream>) {
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
 
@@ -473,7 +471,7 @@ async fn generate_stream_internal(
                         // Token details
                         let details = match details {
                             true => Some(StreamDetails {
-                                finish_reason: FinishReason::from(generated_text.finish_reason),
+                                finish_reason: generated_text.finish_reason,
                                 generated_tokens: generated_text.generated_tokens,
                                 seed: generated_text.seed,
                             }),
@@ -728,7 +726,7 @@ async fn completions(
                         on_message_callback,
                         span_clone.clone(),
                     )
-                        .await;
+                    .await;
 
                    // send and dont wait for response
                    let _ = header_tx.send(header_map);
@@ -835,7 +833,7 @@ async fn completions(
                     Json(generate_request),
                     span_clone,
                 )
-                    .await;
+                .await;
                 result.map(|(headers, generation)| (index, headers, generation))
             };
             responses.push(response_future);
@@ -1152,7 +1150,7 @@ async fn chat_completions(
             on_message_callback,
             span,
         )
-            .await;
+        .await;
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
@@ -1239,15 +1237,15 @@ example = json ! ({"error": "Incomplete generation"})),
 )
 )]
 #[instrument(
-skip_all,
-fields(
-total_time,
-validation_time,
-queue_time,
-inference_time,
-time_per_token,
-seed,
-)
+    skip_all,
+    fields(
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
 )]
 async fn vertex_compatibility(
     Extension(infer): Extension<Infer>,
@@ -1292,17 +1290,17 @@ async fn vertex_compatibility(
                     Json(generate_request),
                     span.clone(),
                 )
-                    .await
-                    .map(|(_, Json(generation))| generation.generated_text)
-                    .map_err(|_| {
-                        (
-                            StatusCode::INTERNAL_SERVER_ERROR,
-                            Json(ErrorResponse {
-                                error: "Incomplete generation".into(),
-                                error_type: "Incomplete generation".into(),
-                            }),
-                        )
-                    })
+                .await
+                .map(|(_, Json(generation))| generation.generated_text)
+                .map_err(|_| {
+                    (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(ErrorResponse {
+                            error: "Incomplete generation".into(),
+                            error_type: "Incomplete generation".into(),
+                        }),
+                    )
+                })
             }
         })
        .collect::>()
@@ -1476,14 +1474,20 @@ pub async fn run(
     // Create state
 
     // Open connection, get model info and warmup
-    let (scheduler, health_ext, shard_info, max_batch_total_tokens) = {
+    let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
+        Arc<dyn Scheduler + Send + Sync>,
+        HealthCheck,
+        ShardInfo,
+        u32,
+    ) = {
         // Helper function to check both v2 and v3
         let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
             match max_supported_batch_total_tokens {
                 // Older models do not support automatic max-batch-total-tokens
                 None => {
-                    let max_batch_total_tokens = max_batch_total_tokens
-                        .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
+                    let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
+                        16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
+                    );
                     tracing::warn!("Model does not support automatic max batch total tokens");
                     Ok(max_batch_total_tokens)
                 }
@@ -1492,12 +1496,12 @@ pub async fn run(
                     // Warn if user added his own max-batch-total-tokens as we will ignore it
                     if max_batch_total_tokens.is_some() {
                         tracing::warn!(
-                            "`--max-batch-total-tokens` is deprecated for Flash \
                         Attention models."
-                            );
+                        );
                         tracing::warn!(
-                            "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
-                            );
+                            "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
+                        );
                     }
                     if max_total_tokens as u32 > max_supported_batch_total_tokens {
                         return Err(WebServerError::NotEnoughMemory(max_total_tokens));
                     }
@@ -1508,51 +1512,100 @@ pub async fn run(
                 }
             };
 
-        let generation_health = Arc::new(AtomicBool::new(false));
-        // Try to open a v3 client
-        // Instantiate sharded client from the master unix socket
-        let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
-            .await
-            .map_err(WebServerError::Connection)?;
-        // Clear the cache; useful if the webserver rebooted
-        sharded_client
-            .clear_cache(None)
-            .await
-            .map_err(WebServerError::Cache)?;
-        // Get info from the shard
-        let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
+        match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await {
+            Ok(mut sharded_client) => {
+                // server is running on v3
+                // Clear the cache; useful if the webserver rebooted
+                sharded_client
+                    .clear_cache(None)
+                    .await
+                    .map_err(WebServerError::Cache)?;
+                // Get info from the shard
+                let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
 
-        // Warmup model
-        tracing::info!("Warming up model");
-        let max_batch_total_tokens = check_max_batch_total_tokens(sharded_client
-            .warmup(
-                max_input_tokens as u32,
-                max_batch_prefill_tokens,
-                max_total_tokens as u32,
-                max_batch_size,
-            )
-            .await
-            .map_err(WebServerError::Warmup)?)?;
-        tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
+                // Warmup model
+                tracing::info!("Warming up model");
+                let max_batch_total_tokens = check_max_batch_total_tokens(
+                    sharded_client
+                        .warmup(
+                            max_input_tokens as u32,
+                            max_batch_prefill_tokens,
+                            max_total_tokens as u32,
+                            max_batch_size,
+                        )
+                        .await
+                        .map_err(WebServerError::Warmup)?,
+                )?;
 
-        let health_ext = HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
-        let scheduler = SchedulerV2::new(
-            sharded_client,
-            waiting_served_ratio,
-            max_batch_prefill_tokens,
-            max_batch_total_tokens,
-            max_waiting_tokens,
-            max_batch_size,
-            shard_info.requires_padding,
-            shard_info.window_size,
-            shard_info.speculate,
-            generation_health,
-        );
+                let health_ext =
+                    HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
+                let scheduler = Arc::new(SchedulerV3::new(
+                    sharded_client,
+                    waiting_served_ratio,
+                    max_batch_prefill_tokens,
+                    max_batch_total_tokens,
+                    max_waiting_tokens,
+                    max_batch_size,
+                    shard_info.requires_padding,
+                    shard_info.window_size,
+                    shard_info.speculate,
+                    generation_health,
+                ));
+                tracing::info!("Using scheduler V3");
 
-        (scheduler, health_ext, shard_info, max_batch_total_tokens)
+                (scheduler, health_ext, shard_info, max_batch_total_tokens)
+            }
+            Err(_) => {
+                let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path)
+                    .await
+                    .map_err(WebServerError::Connection)?;
+
+                // server is running on v2
+                // Clear the cache; useful if the webserver rebooted
+                sharded_client
+                    .clear_cache(None)
+                    .await
+                    .map_err(WebServerError::Cache)?;
+                // Get info from the shard
+                let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
+
+                // Warmup model
+                tracing::info!("Warming up model");
+                let max_batch_total_tokens = check_max_batch_total_tokens(
+                    sharded_client
+                        .warmup(
+                            max_input_tokens as u32,
+                            max_batch_prefill_tokens,
+                            max_total_tokens as u32,
+                            max_batch_size,
+                        )
+                        .await
+                        .map_err(WebServerError::Warmup)?,
+                )?;
+
+                let health_ext =
+                    HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
+                let scheduler = Arc::new(SchedulerV2::new(
+                    sharded_client,
+                    waiting_served_ratio,
+                    max_batch_prefill_tokens,
+                    max_batch_total_tokens,
+                    max_waiting_tokens,
+                    max_batch_size,
+                    shard_info.requires_padding,
+                    shard_info.window_size,
+                    shard_info.speculate,
+                    generation_health,
+                ));
+                tracing::info!("Using scheduler V2");
+
+                (scheduler, health_ext, shard_info, max_batch_total_tokens)
+            }
+        }
     };
+    tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
 
     let validation = Validation::new(
         validation_workers,
@@ -1566,7 +1619,13 @@ pub async fn run(
         grammar_support,
     );
 
-    let infer = Infer::new(Arc::new(scheduler), validation, max_concurrent_requests, tokenizer_config, processor_config);
+    let infer = Infer::new(
+        scheduler,
+        validation,
+        max_concurrent_requests,
+        tokenizer_config,
+        processor_config,
+    );
 
     // Duration buckets
     let duration_matcher = Matcher::Suffix(String::from("duration"));
@@ -1659,8 +1718,8 @@ pub async fn run(
         #[derive(OpenApi)]
         #[openapi(
-        paths(vertex_compatibility),
-        components(schemas(VertexInstance, VertexRequest, VertexResponse))
+            paths(vertex_compatibility),
+            components(schemas(VertexInstance, VertexRequest, VertexResponse))
         )]
         struct VertextApiDoc;
 
@@ -1756,7 +1815,8 @@ pub async fn run(
         let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
         axum::serve(listener, app)
             .with_graceful_shutdown(shutdown_signal())
-            .await.map_err(|err| WebServerError::Axum(Box::new(err)))?;
+            .await
+            .map_err(|err| WebServerError::Axum(Box::new(err)))?;
     }
     Ok(())
 }
@@ -1770,7 +1830,7 @@ async fn shutdown_signal() {
     };
 
     #[cfg(unix)]
-        let terminate = async {
+    let terminate = async {
         signal::unix::signal(signal::unix::SignalKind::terminate())
             .expect("failed to install signal handler")
             .recv()
             .await;
     };
 
     #[cfg(not(unix))]
-        let terminate = std::future::pending::<()>();
+    let terminate = std::future::pending::<()>();
 
     tokio::select! {
         _ = ctrl_c => {},
@@ -1836,4 +1896,4 @@ pub enum WebServerError {
     NotEnoughMemory(usize),
     #[error("Axum error: {0}")]
     Axum(#[from] axum::BoxError),
-}
\ No newline at end of file
+}
diff --git a/server/Makefile b/server/Makefile
index 32d01709..312f14df 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -12,8 +12,8 @@ gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
 	mkdir text_generation_server/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb \
-		--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/generate.proto
+	python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \
+		--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto
 	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
 	touch text_generation_server/pb/__init__.py
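
Note (appended, not part of the patch): the server.rs hunk above makes the router negotiate the shard protocol at startup, trying a v3 connection first and falling back to v2. The sketch below is only an illustration of that pattern under assumed, simplified names (the `AnyShardedClient` enum and `connect_any` helper are hypothetical; only `v2::ShardedClient::connect_uds` / `v3::ShardedClient::connect_uds` and `ClientError` come from the diff itself):

    use text_generation_client::{v2, v3, ClientError};

    // Hypothetical wrapper over the two protocol versions.
    enum AnyShardedClient {
        V3(v3::ShardedClient),
        V2(v2::ShardedClient),
    }

    // Probe the shard with the v3 client first; if that connection fails,
    // assume the shard only speaks v2, mirroring the match in server.rs above.
    async fn connect_any(uds_path: String) -> Result<AnyShardedClient, ClientError> {
        match v3::ShardedClient::connect_uds(uds_path.clone()).await {
            Ok(client) => Ok(AnyShardedClient::V3(client)),
            Err(_) => v2::ShardedClient::connect_uds(uds_path)
                .await
                .map(AnyShardedClient::V2),
        }
    }

In the actual patch each branch goes on to warm up the model and build the matching scheduler (SchedulerV3 or SchedulerV2) behind the shared `Scheduler` trait, so the rest of the router is version-agnostic.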