diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index adedeb58..42e33ac9 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -1,5 +1,599 @@
 mod health;
 pub(crate) mod v2;
-pub(crate) mod v3;
+// pub(crate) mod v3;
 
 pub(crate) use health::HealthCheck;
+
+use crate::validation::{Validation, ValidationError, ValidGenerateRequest};
+use crate::{
+    ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest,
+    HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk,
+    PrefillToken, Text, TextMessage, Token,
+};
+use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
+use futures::future::try_join_all;
+use minijinja::{Environment, ErrorKind, Template};
+use serde_json::{json, Map, Value};
+use std::collections::HashMap;
+use std::sync::{
+    atomic::{AtomicBool},
+    Arc,
+};
+use text_generation_client::v2::{ShardedClient};
+use thiserror::Error;
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tokio_stream::StreamExt;
+use tracing::{instrument, Span};
+
+/// Queue entry
+#[derive(Debug)]
+pub(crate) struct Entry {
+    /// Request
+    pub request: ValidGenerateRequest,
+    /// Response sender to communicate between the Infer struct and the batching_task
+    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
+    /// Span that will live as long as entry
+    pub span: Span,
+    /// Temporary span used as a guard when logging inference, wait times...
+    pub temp_span: Option<Span>,
+    /// Instant when this entry was queued
+    pub queue_time: Instant,
+    /// Instant when this entry was added to a batch
+    pub batch_time: Option<Instant>,
+}
+
+pub(crate) trait InferQueue {
+    /// Append an entry to the queue
+    #[instrument(skip_all)]
+    fn append(&self, entry: Entry);
+}
+
+
+/// Inference struct
+#[derive(Clone)]
+pub struct Infer {
+    /// Validation
+    validation: Validation,
+    /// Request queue
+    queue: Arc<dyn InferQueue + Send + Sync>,
+    /// Notify batcher on queue appends
+    batching_task_notifier: Arc<Notify>,
+    /// Chat template
+    chat_template: Option<ChatTemplate>,
+    /// Inference limit
+    limit_concurrent_requests: Arc<Semaphore>,
+}
+
+
+
+impl Infer {
+    #[allow(clippy::too_many_arguments)]
+    pub(crate) fn new(
+        client: ShardedClient,
+        validation: Validation,
+        waiting_served_ratio: f32,
+        max_batch_prefill_tokens: u32,
+        max_batch_total_tokens: u32,
+        max_waiting_tokens: usize,
+        max_batch_size: Option<usize>,
+        max_concurrent_requests: usize,
+        requires_padding: bool,
+        window_size: Option<u32>,
+        speculate: u32,
+        generation_health: Arc<AtomicBool>,
+        tokenizer_config: HubTokenizerConfig,
+        processor_config: HubProcessorConfig,
+    ) -> Self {
+        let queue = v2::Queue::new(requires_padding, 16, window_size, speculate);
+        let batching_task_notifier = Arc::new(Notify::new());
+
+        // Spawn batching background task that contains all the inference logic
+        tokio::spawn(v2::batching_task(
+            client,
+            waiting_served_ratio,
+            max_batch_prefill_tokens,
+            max_batch_total_tokens,
+            max_waiting_tokens,
+            max_batch_size,
+            queue.clone(),
+            batching_task_notifier.clone(),
+            generation_health,
+        ));
+
+        let chat_template = tokenizer_config
+            .chat_template
+            .or(processor_config.chat_template)
+            .and_then(|t| match t {
+                ChatTemplateVersions::Single(template) => Some(template),
+                ChatTemplateVersions::Multiple(templates) => templates
+                    .into_iter()
+                    .find(|t| t.name == "default")
+                    .map(|t| t.template),
+            })
+            .map(|t| {
+                // .strip() is not supported in minijinja
+                // .capitalize() is not supported in minijinja but we can use | capitalize
+                let t = t
+                    .replace(".strip()", " | trim")
+                    .replace(".capitalize()", " | capitalize");
+                ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
+            });
+
+        // Inference limit with a semaphore
+        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
+
+        Self {
+            validation,
+            queue: Arc::new(queue),
+            batching_task_notifier,
+            chat_template,
+            limit_concurrent_requests: semaphore,
+        }
+    }
+
+    /// Add a new request to the queue and return a stream of InferStreamResponse
+    #[instrument(skip_all)]
+    pub(crate) async fn generate_stream(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<GenerateStreamResponse, InferError> {
+        // Limit concurrent requests by acquiring a permit from the semaphore
+        let permit = self
+            .clone()
+            .limit_concurrent_requests
+            .try_acquire_owned()
+            .map_err(|err| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
+                tracing::error!("{err}");
+                err
+            })?;
+
+        // Validate request
+        let valid_request = self.validation.validate(request).await.map_err(|err| {
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            err
+        })?;
+
+        // MPSC channel to communicate with the background batching task
+        let (response_tx, response_rx) = mpsc::unbounded_channel();
+        let input_length = valid_request.input_length;
+
+        // Append the request to the queue
+        self.queue.append(Entry {
+            request: valid_request,
+            response_tx,
+            span: Span::current(),
+            temp_span: None,
+            queue_time: Instant::now(),
+            batch_time: None,
+        });
+
+        // Notify the background task that we have a new entry in the queue that needs
+        // to be batched
+        self.batching_task_notifier.notify_one();
+
+        // Return stream
+        Ok((
+            permit,
+            input_length,
+            UnboundedReceiverStream::new(response_rx),
+        ))
+    }
+
+    /// Tokenizer the input
+    #[instrument(skip_all)]
+    pub(crate) async fn tokenize(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<Option<tokenizers::Encoding>, InferError> {
+        // Tokenize request
+        let inputs = request.inputs;
+        let truncate = request.parameters.truncate;
+        let encoding = self
+            .validation
+            .tokenize(inputs, truncate)
+            .await
+            .map_err(|err| {
+                tracing::error!("Tokenization {err}");
+                err
+            })?;
+
+        // Return Encoding
+        Ok(encoding.map(|(encoding, _)| encoding))
+    }
+
+    /// Apply the chat template to the chat request
+    #[instrument(skip_all)]
+    pub(crate) fn apply_chat_template(
+        &self,
+        messages: Vec<Message>,
+        grammar_with_prompt: Option<(GrammarType, String)>,
+    ) -> Result<String, InferError> {
+        self.chat_template
+            .as_ref()
+            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
+            .apply(messages, grammar_with_prompt)
+            .map_err(|e| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "template");
+                tracing::error!("{e}");
+                e
+            })
+    }
+
+    /// Add a new request to the queue and return a InferResponse
+    #[instrument(skip_all)]
+    pub(crate) async fn generate(
+        &self,
+        request: GenerateRequest,
+    ) -> Result<InferResponse, InferError> {
+        let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
+
+        // Create stream and keep semaphore permit as long as generate lives
+        let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
+
+        // Return values
+        let mut result_prefill = Vec::new();
+        let mut result_tokens = Vec::new();
+        let mut result_top_tokens = Vec::new();
+        let mut result_generated_text = None;
+        let mut result_start = None;
+        let mut result_queued = None;
+
+        // Iterate on stream
+        while let Some(response) = stream.next().await {
+            match response? {
+                // Add prefill tokens
+                InferStreamResponse::Prefill(prefill_tokens) => {
+                    result_prefill = prefill_tokens;
+                }
+                // Push last token
+                InferStreamResponse::Intermediate { token, top_tokens } => {
+                    result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
+                }
+                // Final message
+                // Set return values
+                InferStreamResponse::End {
+                    token,
+                    generated_text,
+                    start,
+                    queued,
+                    top_tokens,
+                } => {
+                    result_tokens.push(token);
+                    result_top_tokens.push(top_tokens);
+                    result_generated_text = Some(generated_text);
+                    result_start = Some(start);
+                    result_queued = Some(queued)
+                }
+            }
+        }
+
+        // Check that we received a `InferStreamResponse::End` message
+        if let (Some(generated_text), Some(queued), Some(start)) =
+            (result_generated_text, result_queued, result_start)
+        {
+            Ok(InferResponse {
+                prefill: result_prefill,
+                _input_length,
+                tokens: result_tokens,
+                generated_text,
+                queued,
+                start,
+                top_tokens: if use_top_tokens {
+                    result_top_tokens
+                } else {
+                    Vec::new()
+                },
+            })
+        } else {
+            let err = InferError::IncompleteGeneration;
+            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
+            tracing::error!("{err}");
+            Err(err)
+        }
+    }
+    /// Add best_of new requests to the queue and return a InferResponse of the sequence with
+    /// the highest log probability per token
+    #[instrument(skip(self, request))]
+    pub(crate) async fn generate_best_of(
+        &self,
+        request: GenerateRequest,
+        best_of: usize,
+    ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
+        // validate best_of parameter separately
+        let best_of = self.validation.validate_best_of(best_of)?;
+
+        // create multiple generate requests
+        let mut infer_responses: Vec<InferResponse> =
+            try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
+
+        // get the sequence with the highest log probability per token
+        let mut max_index = 0;
+        let mut max_logprob: f32 = f32::MIN;
+
+        for (i, response) in infer_responses.iter().enumerate() {
+            // mean logprobs of the generated tokens
+            let sequence_logprob = response
+                .tokens
+                .iter()
+                .map(|token| token.logprob)
+                .sum::<f32>()
+                / response.tokens.len() as f32;
+
+            // set best sequence
+            if sequence_logprob > max_logprob {
+                max_index = i;
+                max_logprob = sequence_logprob;
+            }
+        }
+        let best_response = infer_responses.remove(max_index);
+        Ok((best_response, infer_responses))
+    }
+}
+
+/// Raise a exception (custom function) used in the chat templates
+fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
+    Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
+}
+
+#[derive(Clone)]
+struct ChatTemplate {
+    template: Template<'static, 'static>,
+    bos_token: Option<String>,
+    eos_token: Option<String>,
+    use_default_tool_template: bool,
+}
+
+impl ChatTemplate {
+    fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
+        let mut env = Box::new(Environment::new());
+        let template_str = template.into_boxed_str();
+        env.add_function("raise_exception", raise_exception);
+
+        // check if contains the tools variable within the template
+        let use_default_tool_template =
+            !template_str.as_ref().replace(' ', "").contains("{{tools}}");
+        // leaking env and template_str as read-only, static resources for performance.
+        let template = Box::leak(env)
+            .template_from_str(Box::leak(template_str))
+            .unwrap();
+
+        Self {
+            template,
+            bos_token,
+            eos_token,
+            use_default_tool_template,
+        }
+    }
+
+    fn apply(
+        &self,
+        mut messages: Vec<TextMessage>,
+        grammar_with_prompt: Option<(GrammarType, String)>,
+    ) -> Result<String, InferError> {
+        if self.use_default_tool_template {
+            if let Some(last_message) = messages.last_mut() {
+                if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
+                    last_message.content.push(MessageChunk::Text(Text {
+                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
+                    }));
+                }
+            }
+        }
+
+        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
+
+        self.template
+            .render(ChatTemplateInputs {
+                messages,
+                bos_token: self.bos_token.as_deref(),
+                eos_token: self.eos_token.as_deref(),
+                add_generation_prompt: true,
+                tools: None,
+                tools_prompt: None,
+            })
+            .map_err(InferError::TemplateError)
+    }
+}
+
+pub struct ToolGrammar {}
+
+impl ToolGrammar {
+    pub fn apply(
+        tools: Option<Vec<Tool>>,
+        tool_choice: Option<ToolType>,
+    ) -> Result<Option<Tools>, InferError> {
+        if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
+            // let tool_prompt = tool_prompt.unwrap_or_default();
+            let tools_to_use = match tool_choice {
+                ToolType::FunctionName(name) => {
+                    vec![req_tools
+                        .iter()
+                        .find(|tool| tool.function.name == *name)
+                        .unwrap_or_else(|| panic!("Tool with name {} not found", name))
+                        .clone()]
+                }
+                ToolType::OneOf => req_tools.to_owned(),
+            };
+
+            // adds the error notification function for LLM feedback if required
+            let mut text_response_properties = Map::new();
+            text_response_properties.insert(
+                "error".to_string(),
+                serde_json::json!({
+                    "type": "string",
+                    "description": "The error or issue to notify"
+                }),
+            );
+            text_response_properties.insert(
+                "_name".to_string(),
+                serde_json::json!({
+                    "type": "string",
+                    "const": "notify_error"
+                }),
+            );
+
+            let functions: HashMap<String, serde_json::Value> = tools_to_use
+                .iter()
+                .map(|tool| {
+                    let func = tool.function.clone();
+
+                    // Clone the existing parameters, which are expected to be a JSON object
+                    let mut params = if let Value::Object(params) = &func.arguments {
+                        params.clone()
+                    } else {
+                        Map::new()
+                    };
+
+                    // Insert the function's description at the top level, outside of properties
+                    params.insert(
+                        "description".to_string(),
+                        Value::String(func.description.clone().unwrap_or_default()),
+                    );
+
+                    // Ensure 'properties' exists and is an object
+                    let properties = params
+                        .entry("properties".to_string())
+                        .or_insert_with(|| json!({}))
+                        .as_object_mut()
+                        .unwrap();
+
+                    // Insert the constant for the function name inside 'properties'
+                    properties.insert(
+                        "_name".to_string(),
+                        json!({
+                            "type": "string",
+                            "const": func.name.clone(),
+                            // "description": "The name of the function"
+                        }),
+                    );
+
+                    // Check if 'required' exists, and it is an array. If not, create an empty array.
+                    let required = params
+                        .entry("required".to_string())
+                        .or_insert_with(|| json!([]))
+                        .as_array_mut()
+                        .unwrap();
+
+                    // Add 'name' to the 'required' array if it is not already present
+                    if !required.iter().any(|r| r == "_name") {
+                        required.push(json!("_name"));
+                    }
+
+                    (func.name, Value::Object(params))
+                })
+                .chain([(
+                    "notify_error".to_string(),
+                    serde_json::json!({
+                        "properties": text_response_properties,
+                        "required": ["error", "_name"],
+                        "type": "object"
+                    }),
+                )])
+                .collect();
+
+            let tools = Tools {
+                functions_map: FunctionsMap { functions },
+                properties: Properties {
+                    function: tools_to_use
+                        .iter()
+                        .map(|tool| FunctionRef {
+                            ref_path: format!("#/$functions/{}", tool.function.name.clone()),
+                        })
+                        .chain(std::iter::once(FunctionRef {
+                            ref_path: "#/$functions/notify_error".to_string(),
+                        }))
+                        .collect(),
+                },
+            };
+
+            return Ok(Some(tools));
+        }
+        // Err(InferError::ToolError("No tools provided".to_string()))
+        Ok(None)
+    }
+}
+
+/// Type alias for generation responses
+pub(crate) type GenerateStreamResponse = (
+    OwnedSemaphorePermit,
+    u32, // input_length
+    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
+);
+
+#[derive(Debug)]
+pub(crate) struct GeneratedText {
+    pub(crate) text: String,
+    pub(crate) generated_tokens: u32,
+    pub(crate) finish_reason: FinishReason,
+    pub(crate) seed: Option<u64>,
+}
+
+
+
+#[derive(Debug)]
+pub(crate) enum InferStreamResponse {
+    // Optional first message
+    Prefill(Vec<PrefillToken>),
+    // Intermediate messages
+    Intermediate {
+        token: Token,
+        top_tokens: Vec<Token>,
+    },
+    // Last message
+    End {
+        token: Token,
+        top_tokens: Vec<Token>,
+        generated_text: GeneratedText,
+        start: Instant,
+        queued: Instant,
+    },
+}
+
+#[derive(Debug)]
+pub(crate) struct InferResponse {
+    /// input_length is the input as perceived by the rust tokenizer in the
+    /// validation pathway. It is redundant with prefill.len() but prefill
+    /// has data only if the user asked for it. This will always be filled.
+ pub(crate) _input_length: u32, + pub(crate) prefill: Vec, + pub(crate) tokens: Vec, + pub(crate) generated_text: GeneratedText, + pub(crate) queued: Instant, + pub(crate) start: Instant, + pub(crate) top_tokens: Vec>, +} + +#[derive(Debug, Error)] +pub enum InferError { + #[error("Request failed during generation: {0}")] + GenerationError(String), + #[error("Model is overloaded")] + Overloaded(#[from] TryAcquireError), + #[error("Input validation error: {0}")] + ValidationError(#[from] ValidationError), + #[error("Incomplete generation")] + IncompleteGeneration, + #[error("Template error: {0}")] + TemplateError(#[from] minijinja::Error), + #[error("Tool error: {0}")] + ToolError(String), +} + +impl InferError { + pub(crate) fn error_type(&self) -> &str { + match self { + InferError::GenerationError(_) => "generation", + InferError::Overloaded(_) => "overloaded", + InferError::ValidationError(_) => "validation", + InferError::IncompleteGeneration => "incomplete_generation", + InferError::TemplateError(_) => "template_error", + InferError::ToolError(_) => "tool_error", + } + } +} + diff --git a/router/src/infer/v2/infer.rs b/router/src/infer/v2/batcher.rs similarity index 74% rename from router/src/infer/v2/infer.rs rename to router/src/infer/v2/batcher.rs index d91b7f41..3cfc9118 100644 --- a/router/src/infer/v2/infer.rs +++ b/router/src/infer/v2/batcher.rs @@ -1,511 +1,27 @@ /// Batching and inference logic -use crate::infer::v2::{Queue, Entry}; -use crate::validation::{Validation, ValidationError}; -use crate::{ - ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, - HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, - PrefillToken, Text, TextMessage, Token, -}; -use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; -use futures::future::try_join_all; -use minijinja::{Environment, ErrorKind, Template}; +use crate::infer::Entry; +use crate::infer::v2::{Queue}; +use crate::{FinishReason, PrefillToken, Token}; use nohash_hasher::IntMap; -use serde_json::{json, Map, Value}; -use std::collections::HashMap; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient}; -use text_generation_client::{v2, ClientError}; -use thiserror::Error; +use text_generation_client::{ClientError}; use tokio::sync::mpsc::error::SendError; -use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError}; +use tokio::sync::Notify; use tokio::time::Instant; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tokio_stream::StreamExt; -use tracing::{info_span, instrument, Instrument, Span}; - -/// Inference struct -#[derive(Clone)] -pub struct Infer { - /// Validation - validation: Validation, - /// Request queue - queue: Queue, - /// Shared state - shared: Arc, - /// Chat template - chat_template: Option, - /// Inference limit - limit_concurrent_requests: Arc, -} - -/// Infer shared state -struct Shared { - /// Batching background Tokio task notifier - batching_task: Notify, -} - -/// Raise a exception (custom function) used in the chat templates -fn raise_exception(err_text: String) -> Result { - Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text)) -} - -impl Infer { - #[allow(clippy::too_many_arguments)] - pub(crate) fn new( - client: ShardedClient, - validation: Validation, - waiting_served_ratio: f32, - max_batch_prefill_tokens: u32, - max_batch_total_tokens: u32, - max_waiting_tokens: usize, - max_batch_size: Option, - 
max_concurrent_requests: usize, - requires_padding: bool, - window_size: Option, - speculate: u32, - generation_health: Arc, - tokenizer_config: HubTokenizerConfig, - processor_config: HubProcessorConfig, - ) -> Self { - let queue = Queue::new(requires_padding, 16, window_size, speculate); - let shared = Arc::new(Shared { - batching_task: Notify::new(), - }); - - // Spawn batching background task that contains all the inference logic - tokio::spawn(batching_task( - client, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - queue.clone(), - shared.clone(), - generation_health, - )); - - let chat_template = tokenizer_config - .chat_template - .or(processor_config.chat_template) - .and_then(|t| match t { - ChatTemplateVersions::Single(template) => Some(template), - ChatTemplateVersions::Multiple(templates) => templates - .into_iter() - .find(|t| t.name == "default") - .map(|t| t.template), - }) - .map(|t| { - // .strip() is not supported in minijinja - // .capitalize() is not supported in minijinja but we can use | capitalize - let t = t - .replace(".strip()", " | trim") - .replace(".capitalize()", " | capitalize"); - ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token) - }); - - // Inference limit with a semaphore - let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); - - Self { - validation, - queue, - shared, - chat_template, - limit_concurrent_requests: semaphore, - } - } - - /// Add a new request to the queue and return a stream of InferStreamResponse - #[instrument(skip_all)] - pub(crate) async fn generate_stream( - &self, - request: GenerateRequest, - ) -> Result { - // Limit concurrent requests by acquiring a permit from the semaphore - let permit = self - .clone() - .limit_concurrent_requests - .try_acquire_owned() - .map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "overloaded"); - tracing::error!("{err}"); - err - })?; - - // Validate request - let valid_request = self.validation.validate(request).await.map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); - tracing::error!("{err}"); - err - })?; - - // MPSC channel to communicate with the background batching task - let (response_tx, response_rx) = mpsc::unbounded_channel(); - let input_length = valid_request.input_length; - - // Append the request to the queue - self.queue.append(Entry { - request: valid_request, - response_tx, - span: Span::current(), - temp_span: None, - queue_time: Instant::now(), - batch_time: None, - }); - - // Notify the background task that we have a new entry in the queue that needs - // to be batched - self.shared.batching_task.notify_one(); - - // Return stream - Ok(( - permit, - input_length, - UnboundedReceiverStream::new(response_rx), - )) - } - - /// Tokenizer the input - #[instrument(skip_all)] - pub(crate) async fn tokenize( - &self, - request: GenerateRequest, - ) -> Result, InferError> { - // Tokenize request - let inputs = request.inputs; - let truncate = request.parameters.truncate; - let encoding = self - .validation - .tokenize(inputs, truncate) - .await - .map_err(|err| { - tracing::error!("Tokenization {err}"); - err - })?; - - // Return Encoding - Ok(encoding.map(|(encoding, _)| encoding)) - } - - /// Apply the chat template to the chat request - #[instrument(skip_all)] - pub(crate) fn apply_chat_template( - &self, - messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, - ) -> Result { - 
self.chat_template - .as_ref() - .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? - .apply(messages, grammar_with_prompt) - .map_err(|e| { - metrics::increment_counter!("tgi_request_failure", "err" => "template"); - tracing::error!("{e}"); - e - }) - } - - /// Add a new request to the queue and return a InferResponse - #[instrument(skip_all)] - pub(crate) async fn generate( - &self, - request: GenerateRequest, - ) -> Result { - let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); - - // Create stream and keep semaphore permit as long as generate lives - let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; - - // Return values - let mut result_prefill = Vec::new(); - let mut result_tokens = Vec::new(); - let mut result_top_tokens = Vec::new(); - let mut result_generated_text = None; - let mut result_start = None; - let mut result_queued = None; - - // Iterate on stream - while let Some(response) = stream.next().await { - match response? { - // Add prefill tokens - InferStreamResponse::Prefill(prefill_tokens) => { - result_prefill = prefill_tokens; - } - // Push last token - InferStreamResponse::Intermediate { token, top_tokens } => { - result_tokens.push(token); - result_top_tokens.push(top_tokens); - } - // Final message - // Set return values - InferStreamResponse::End { - token, - generated_text, - start, - queued, - top_tokens, - } => { - result_tokens.push(token); - result_top_tokens.push(top_tokens); - result_generated_text = Some(generated_text); - result_start = Some(start); - result_queued = Some(queued) - } - } - } - - // Check that we received a `InferStreamResponse::End` message - if let (Some(generated_text), Some(queued), Some(start)) = - (result_generated_text, result_queued, result_start) - { - Ok(InferResponse { - prefill: result_prefill, - _input_length, - tokens: result_tokens, - generated_text, - queued, - start, - top_tokens: if use_top_tokens { - result_top_tokens - } else { - Vec::new() - }, - }) - } else { - let err = InferError::IncompleteGeneration; - metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); - tracing::error!("{err}"); - Err(err) - } - } - /// Add best_of new requests to the queue and return a InferResponse of the sequence with - /// the highest log probability per token - #[instrument(skip(self, request))] - pub(crate) async fn generate_best_of( - &self, - request: GenerateRequest, - best_of: usize, - ) -> Result<(InferResponse, Vec), InferError> { - // validate best_of parameter separately - let best_of = self.validation.validate_best_of(best_of)?; - - // create multiple generate requests - let mut infer_responses: Vec = - try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?; - - // get the sequence with the highest log probability per token - let mut max_index = 0; - let mut max_logprob: f32 = f32::MIN; - - for (i, response) in infer_responses.iter().enumerate() { - // mean logprobs of the generated tokens - let sequence_logprob = response - .tokens - .iter() - .map(|token| token.logprob) - .sum::() - / response.tokens.len() as f32; - - // set best sequence - if sequence_logprob > max_logprob { - max_index = i; - max_logprob = sequence_logprob; - } - } - let best_response = infer_responses.remove(max_index); - Ok((best_response, infer_responses)) - } -} - -#[derive(Clone)] -struct ChatTemplate { - template: Template<'static, 'static>, - bos_token: Option, - eos_token: Option, - use_default_tool_template: bool, -} - -impl 
ChatTemplate { - fn new(template: String, bos_token: Option, eos_token: Option) -> Self { - let mut env = Box::new(Environment::new()); - let template_str = template.into_boxed_str(); - env.add_function("raise_exception", raise_exception); - - // check if contains the tools variable within the template - let use_default_tool_template = - !template_str.as_ref().replace(' ', "").contains("{{tools}}"); - // leaking env and template_str as read-only, static resources for performance. - let template = Box::leak(env) - .template_from_str(Box::leak(template_str)) - .unwrap(); - - Self { - template, - bos_token, - eos_token, - use_default_tool_template, - } - } - - fn apply( - &self, - mut messages: Vec, - grammar_with_prompt: Option<(GrammarType, String)>, - ) -> Result { - if self.use_default_tool_template { - if let Some(last_message) = messages.last_mut() { - if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content.push(MessageChunk::Text(Text { - text: format!("\n---\n{}\n{}", tool_prompt, tools), - })); - } - } - } - - let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); - - self.template - .render(ChatTemplateInputs { - messages, - bos_token: self.bos_token.as_deref(), - eos_token: self.eos_token.as_deref(), - add_generation_prompt: true, - tools: None, - tools_prompt: None, - }) - .map_err(InferError::TemplateError) - } -} - -pub struct ToolGrammar {} - -impl ToolGrammar { - pub fn apply( - tools: Option>, - tool_choice: Option, - ) -> Result, InferError> { - if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { - // let tool_prompt = tool_prompt.unwrap_or_default(); - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![req_tools - .iter() - .find(|tool| tool.function.name == *name) - .unwrap_or_else(|| panic!("Tool with name {} not found", name)) - .clone()] - } - ToolType::OneOf => req_tools.to_owned(), - }; - - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); - - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); - - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; - - // Insert the function's description at the top level, outside of properties - params.insert( - "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), - ); - - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); - - // Insert the constant for the function name inside 'properties' - properties.insert( - "_name".to_string(), - json!({ - "type": "string", - "const": func.name.clone(), - // "description": "The name of the function" - }), - ); - - // Check if 'required' exists, and it is an array. If not, create an empty array. 
- let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); - - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); - } - - (func.name, Value::Object(params)) - }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) - .collect(); - - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) - .collect(), - }, - }; - - return Ok(Some(tools)); - } - // Err(InferError::ToolError("No tools provided".to_string())) - Ok(None) - } -} +use tracing::{info_span, instrument, Instrument}; +use crate::infer::{GeneratedText, InferError, InferStreamResponse}; /// Batching logic /// Will be launched in a background Tokio task /// /// Batches requests and sends them to the inference server #[allow(clippy::too_many_arguments)] -async fn batching_task( +pub(crate) async fn batching_task( mut client: ShardedClient, waiting_served_ratio: f32, max_batch_prefill_tokens: u32, @@ -513,13 +29,13 @@ async fn batching_task( max_waiting_tokens: usize, max_batch_size: Option, queue: Queue, - shared: Arc, + notifier: Arc, generation_health: Arc, ) { // Infinite loop loop { // Wait for a notification from the Infer struct - shared.batching_task.notified().await; + notifier.notified().await; // Get the next batch from the queue // This batch might be smaller than the maximum batch size if there are not enough requests @@ -880,28 +396,13 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { }); } -/// Type alias for generation responses -pub(crate) type GenerateStreamResponse = ( - OwnedSemaphorePermit, - u32, // input_length - UnboundedReceiverStream>, -); - -#[derive(Debug)] -pub(crate) struct GeneratedText { - pub(crate) text: String, - pub(crate) generated_tokens: u32, - pub(crate) finish_reason: FinishReason, - pub(crate) seed: Option, -} - -impl From for GeneratedText { - fn from(value: v2::GeneratedText) -> Self { - let v2_finish_reason = v2::FinishReason::try_from(value.finish_reason).unwrap(); +impl From for GeneratedText { + fn from(value: text_generation_client::v2::GeneratedText) -> Self { + let v2_finish_reason = text_generation_client::v2::FinishReason::try_from(value.finish_reason).unwrap(); let finish_reason = match v2_finish_reason { - v2::FinishReason::Length => FinishReason::Length, - v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken, - v2::FinishReason::StopSequence => FinishReason::StopSequence, + text_generation_client::v2::FinishReason::Length => FinishReason::Length, + text_generation_client::v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken, + text_generation_client::v2::FinishReason::StopSequence => FinishReason::StopSequence, }; Self { @@ -913,68 +414,6 @@ impl From for GeneratedText { } } -#[derive(Debug)] -pub(crate) enum InferStreamResponse { - // Optional first message - Prefill(Vec), - // Intermediate messages - Intermediate { - token: Token, - top_tokens: Vec, - }, - // Last message - End { - token: Token, - top_tokens: Vec, - generated_text: GeneratedText, - start: Instant, - queued: Instant, - }, -} - -#[derive(Debug)] 
-pub(crate) struct InferResponse { - /// input_length is the input as perceived by the rust tokenizer in the - /// validation pathway. It is redundant with prefill.len() but prefill - /// has data only if the user asked for it. This will always be filled. - pub(crate) _input_length: u32, - pub(crate) prefill: Vec, - pub(crate) tokens: Vec, - pub(crate) generated_text: GeneratedText, - pub(crate) queued: Instant, - pub(crate) start: Instant, - pub(crate) top_tokens: Vec>, -} - -#[derive(Debug, Error)] -pub enum InferError { - #[error("Request failed during generation: {0}")] - GenerationError(String), - #[error("Model is overloaded")] - Overloaded(#[from] TryAcquireError), - #[error("Input validation error: {0}")] - ValidationError(#[from] ValidationError), - #[error("Incomplete generation")] - IncompleteGeneration, - #[error("Template error: {0}")] - TemplateError(#[from] minijinja::Error), - #[error("Tool error: {0}")] - ToolError(String), -} - -impl InferError { - pub(crate) fn error_type(&self) -> &str { - match self { - InferError::GenerationError(_) => "generation", - InferError::Overloaded(_) => "overloaded", - InferError::ValidationError(_) => "validation", - InferError::IncompleteGeneration => "incomplete_generation", - InferError::TemplateError(_) => "template_error", - InferError::ToolError(_) => "tool_error", - } - } -} - // tests #[cfg(test)] mod tests { diff --git a/router/src/infer/v2/mod.rs b/router/src/infer/v2/mod.rs index 101f7b60..0af0e5cb 100644 --- a/router/src/infer/v2/mod.rs +++ b/router/src/infer/v2/mod.rs @@ -1,5 +1,5 @@ -mod infer; +mod batcher; mod queue; -pub(crate) use infer::{Infer, InferError, InferStreamResponse, InferResponse, ToolGrammar}; -pub(crate) use queue::{Entry, Queue}; +pub(crate) use batcher::batching_task; +pub(crate) use queue::Queue; diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 10fad191..d62c8cdc 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -1,6 +1,6 @@ -use crate::infer::v2::{InferError, InferStreamResponse}; +use crate::infer::{Entry, InferQueue}; use crate::validation::{ - ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters, + ValidGrammar, ValidParameters, ValidStoppingParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; @@ -15,23 +15,6 @@ use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Span}; -/// Queue entry -#[derive(Debug)] -pub(crate) struct Entry { - /// Request - pub request: ValidGenerateRequest, - /// Response sender to communicate between the Infer struct and the batching_task - pub response_tx: mpsc::UnboundedSender>, - /// Span that will live as long as entry - pub span: Span, - /// Temporary span used as a guard when logging inference, wait times... 
- pub temp_span: Option, - /// Instant when this entry was queued - pub queue_time: Instant, - /// Instant when this entry was added to a batch - pub batch_time: Option, -} - /// Request Queue #[derive(Debug, Clone)] pub(crate) struct Queue { @@ -39,6 +22,19 @@ pub(crate) struct Queue { queue_sender: mpsc::UnboundedSender, } + +impl InferQueue for Queue { + /// Append an entry to the queue + #[instrument(skip_all)] + fn append(&self, entry: Entry) { + // Send append command to the background task managing the state + // Unwrap is safe here + self.queue_sender + .send(QueueCommand::Append(Box::new(entry), Span::current())) + .unwrap(); + } +} + impl Queue { pub(crate) fn new( requires_padding: bool, @@ -61,16 +57,6 @@ impl Queue { Self { queue_sender } } - /// Append an entry to the queue - #[instrument(skip_all)] - pub(crate) fn append(&self, entry: Entry) { - // Send append command to the background task managing the state - // Unwrap is safe here - self.queue_sender - .send(QueueCommand::Append(Box::new(entry), Span::current())) - .unwrap(); - } - // Get the next batch #[instrument(skip(self))] pub(crate) async fn next_batch( diff --git a/router/src/server.rs b/router/src/server.rs index 3a3e1350..2ca49bc7 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,7 +1,8 @@ /// HTTP Server logic + use crate::config::Config; use crate::infer::HealthCheck; -use crate::infer::v2::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar}; +use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar}; use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, @@ -34,7 +35,7 @@ use std::convert::Infallible; use std::net::SocketAddr; use std::sync::atomic::AtomicBool; use std::sync::Arc; -use text_generation_client::{v2::ShardedClient, ShardInfo, ClientError}; +use text_generation_client::{v2::ShardedClient, ClientError}; use tokenizers::Tokenizer; use tokio::select; use tokio::signal; @@ -372,7 +373,7 @@ async fn generate_stream( Json(req): Json, ) -> ( HeaderMap, - Sse>>, + Sse>>, ) { let span = tracing::Span::current(); let on_message_callback = |stream_token: StreamResponse| { @@ -391,7 +392,7 @@ async fn generate_stream_internal( Json(req): Json, on_message_callback: impl Fn(StreamResponse) -> Event, span: tracing::Span, -) -> (HeaderMap, impl Stream>) { +) -> (HeaderMap, impl Stream>) { let start_time = Instant::now(); metrics::increment_counter!("tgi_request_count"); @@ -559,38 +560,38 @@ async fn generate_stream_internal( /// Generate tokens #[utoipa::path( - post, - tag = "Text Generation Inference", - path = "/v1/completions", - request_body = CompletionRequest, - responses( - (status = 200, description = "Generated Chat Completion", - content( - ("application/json" = Completion), - ("text/event-stream" = CompletionCompleteChunk), - )), - (status = 424, description = "Generation Error", body = ErrorResponse, - example = json ! ({"error": "Request failed during generation"})), - (status = 429, description = "Model is overloaded", body = ErrorResponse, - example = json ! ({"error": "Model is overloaded"})), - (status = 422, description = "Input validation error", body = ErrorResponse, - example = json ! ({"error": "Input validation error"})), - (status = 500, description = "Incomplete generation", body = ErrorResponse, - example = json ! 
({"error": "Incomplete generation"})), - ) - )] +post, +tag = "Text Generation Inference", +path = "/v1/completions", +request_body = CompletionRequest, +responses( +(status = 200, description = "Generated Chat Completion", +content( +("application/json" = Completion), +("text/event-stream" = CompletionCompleteChunk), +)), +(status = 424, description = "Generation Error", body = ErrorResponse, +example = json ! ({"error": "Request failed during generation"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded"})), +(status = 422, description = "Input validation error", body = ErrorResponse, +example = json ! ({"error": "Input validation error"})), +(status = 500, description = "Incomplete generation", body = ErrorResponse, +example = json ! ({"error": "Incomplete generation"})), +) +)] #[instrument( - skip_all, - fields( - // parameters = ? req.parameters, - total_time, - validation_time, - queue_time, - inference_time, - time_per_token, - seed, - ) - )] +skip_all, +fields( +// parameters = ? req.parameters, +total_time, +validation_time, +queue_time, +inference_time, +time_per_token, +seed, +) +)] async fn completions( Extension(infer): Extension, Extension(compute_type): Extension, @@ -726,7 +727,7 @@ async fn completions( on_message_callback, span_clone.clone(), ) - .await; + .await; // send and dont wait for response let _ = header_tx.send(header_map); @@ -833,7 +834,7 @@ async fn completions( Json(generate_request), span_clone, ) - .await; + .await; result.map(|(headers, generation)| (index, headers, generation)) }; responses.push(response_future); @@ -964,38 +965,38 @@ async fn completions( /// Generate tokens #[utoipa::path( - post, - tag = "Text Generation Inference", - path = "/v1/chat/completions", - request_body = ChatRequest, - responses( - (status = 200, description = "Generated Chat Completion", - content( - ("application/json" = ChatCompletion), - ("text/event-stream" = ChatCompletionChunk), - )), - (status = 424, description = "Generation Error", body = ErrorResponse, - example = json ! ({"error": "Request failed during generation"})), - (status = 429, description = "Model is overloaded", body = ErrorResponse, - example = json ! ({"error": "Model is overloaded"})), - (status = 422, description = "Input validation error", body = ErrorResponse, - example = json ! ({"error": "Input validation error"})), - (status = 500, description = "Incomplete generation", body = ErrorResponse, - example = json ! ({"error": "Incomplete generation"})), - ) - )] +post, +tag = "Text Generation Inference", +path = "/v1/chat/completions", +request_body = ChatRequest, +responses( +(status = 200, description = "Generated Chat Completion", +content( +("application/json" = ChatCompletion), +("text/event-stream" = ChatCompletionChunk), +)), +(status = 424, description = "Generation Error", body = ErrorResponse, +example = json ! ({"error": "Request failed during generation"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded"})), +(status = 422, description = "Input validation error", body = ErrorResponse, +example = json ! ({"error": "Input validation error"})), +(status = 500, description = "Incomplete generation", body = ErrorResponse, +example = json ! ({"error": "Incomplete generation"})), +) +)] #[instrument( - skip_all, - fields( - // parameters = ? 
req.parameters, - total_time, - validation_time, - queue_time, - inference_time, - time_per_token, - seed, - ) - )] +skip_all, +fields( +// parameters = ? req.parameters, +total_time, +validation_time, +queue_time, +inference_time, +time_per_token, +seed, +) +)] async fn chat_completions( Extension(infer): Extension, Extension(compute_type): Extension, @@ -1150,7 +1151,7 @@ async fn chat_completions( on_message_callback, span, ) - .await; + .await; let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { @@ -1220,32 +1221,32 @@ async fn chat_completions( /// Generate tokens from Vertex request #[utoipa::path( - post, - tag = "Text Generation Inference", - path = "/vertex", - request_body = VertexRequest, - responses( - (status = 200, description = "Generated Text", body = VertexResponse), - (status = 424, description = "Generation Error", body = ErrorResponse, - example = json ! ({"error": "Request failed during generation"})), - (status = 429, description = "Model is overloaded", body = ErrorResponse, - example = json ! ({"error": "Model is overloaded"})), - (status = 422, description = "Input validation error", body = ErrorResponse, - example = json ! ({"error": "Input validation error"})), - (status = 500, description = "Incomplete generation", body = ErrorResponse, - example = json ! ({"error": "Incomplete generation"})), - ) - )] +post, +tag = "Text Generation Inference", +path = "/vertex", +request_body = VertexRequest, +responses( +(status = 200, description = "Generated Text", body = VertexResponse), +(status = 424, description = "Generation Error", body = ErrorResponse, +example = json ! ({"error": "Request failed during generation"})), +(status = 429, description = "Model is overloaded", body = ErrorResponse, +example = json ! ({"error": "Model is overloaded"})), +(status = 422, description = "Input validation error", body = ErrorResponse, +example = json ! ({"error": "Input validation error"})), +(status = 500, description = "Incomplete generation", body = ErrorResponse, +example = json ! ({"error": "Incomplete generation"})), +) +)] #[instrument( - skip_all, - fields( - total_time, - validation_time, - queue_time, - inference_time, - time_per_token, - seed, - ) +skip_all, +fields( +total_time, +validation_time, +queue_time, +inference_time, +time_per_token, +seed, +) )] async fn vertex_compatibility( Extension(infer): Extension, @@ -1290,17 +1291,17 @@ async fn vertex_compatibility( Json(generate_request), span.clone(), ) - .await - .map(|(_, Json(generation))| generation.generated_text) - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Incomplete generation".into(), - error_type: "Incomplete generation".into(), - }), - ) - }) + .await + .map(|(_, Json(generation))| generation.generated_text) + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Incomplete generation".into(), + error_type: "Incomplete generation".into(), + }), + ) + }) } }) .collect::>() @@ -1313,16 +1314,16 @@ async fn vertex_compatibility( /// Tokenize inputs #[utoipa::path( - post, - tag = "Text Generation Inference", - path = "/tokenize", - request_body = GenerateRequest, - responses( - (status = 200, description = "Tokenized ids", body = TokenizeResponse), - (status = 404, description = "No tokenizer found", body = ErrorResponse, - example = json ! 
({"error": "No fast tokenizer available"})), - ) - )] +post, +tag = "Text Generation Inference", +path = "/tokenize", +request_body = GenerateRequest, +responses( +(status = 200, description = "Tokenized ids", body = TokenizeResponse), +(status = 404, description = "No tokenizer found", body = ErrorResponse, +example = json ! ({"error": "No fast tokenizer available"})), +) +)] #[instrument(skip_all)] async fn tokenize( Extension(infer): Extension, @@ -1471,88 +1472,99 @@ pub async fn run( )] struct ApiDoc; - // Instantiate sharded client from the master unix socket - let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) - .await - .map_err(WebServerError::Connection)?; - // Clear the cache; useful if the webserver rebooted - sharded_client - .clear_cache(None) - .await - .map_err(WebServerError::Cache)?; - // Get info from the shard - let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; - - // Warmup model - tracing::info!("Warming up model"); - let max_batch_total_tokens = match sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(WebServerError::Warmup)? - { - // Older models do not support automatic max-batch-total-tokens - None => { - let max_batch_total_tokens = max_batch_total_tokens - .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); - tracing::warn!("Model does not support automatic max batch total tokens"); - max_batch_total_tokens - } - // Flash attention models return their max supported total tokens - Some(max_supported_batch_total_tokens) => { - // Warn if user added his own max-batch-total-tokens as we will ignore it - if max_batch_total_tokens.is_some() { - tracing::warn!( + // Open connection, get model info and warmup + let (infer, health_ext, shard_info, max_batch_total_tokens) = { + // Helper function to check both v2 and v3 + let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + match max_supported_batch_total_tokens { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens + .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + tracing::warn!("Model does not support automatic max batch total tokens"); + Ok(max_batch_total_tokens) + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( "`--max-batch-total-tokens` is deprecated for Flash \ Attention models." 
); - tracing::warn!( + tracing::warn!( "Inferred max batch total tokens: {max_supported_batch_total_tokens}" ); - } - if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(WebServerError::NotEnoughMemory(max_total_tokens)) - } + } + if max_total_tokens as u32 > max_supported_batch_total_tokens { + return Err(WebServerError::NotEnoughMemory(max_total_tokens)); + } - max_supported_batch_total_tokens - } + Ok(max_supported_batch_total_tokens) + } + } + }; + + // Create state + let validation = Validation::new( + validation_workers, + tokenizer, + config, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + grammar_support, + ); + let generation_health = Arc::new(AtomicBool::new(false)); + + // Try to open a v3 client + // Instantiate sharded client from the master unix socket + let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) + .await + .map_err(WebServerError::Connection)?; + // Clear the cache; useful if the webserver rebooted + sharded_client + .clear_cache(None) + .await + .map_err(WebServerError::Cache)?; + // Get info from the shard + let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?; + + // Warmup model + tracing::info!("Warming up model"); + let max_batch_total_tokens = check_max_batch_total_tokens(sharded_client + .warmup( + max_input_tokens as u32, + max_batch_prefill_tokens, + max_total_tokens as u32, + max_batch_size, + ) + .await + .map_err(WebServerError::Warmup)?)?; + tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); + + let health_ext = HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); + let infer = Infer::new( + sharded_client, + validation, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + max_concurrent_requests, + shard_info.requires_padding, + shard_info.window_size, + shard_info.speculate, + generation_health, + tokenizer_config, + processor_config, + ); + + (infer, health_ext, shard_info, max_batch_total_tokens) }; - tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); - - // Create state - let validation = Validation::new( - validation_workers, - tokenizer, - config, - max_best_of, - max_stop_sequences, - max_top_n_tokens, - max_input_tokens, - max_total_tokens, - grammar_support, - ); - let generation_health = Arc::new(AtomicBool::new(false)); - let health_ext = HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone()); - let infer = Infer::new( - sharded_client, - validation, - waiting_served_ratio, - max_batch_prefill_tokens, - max_batch_total_tokens, - max_waiting_tokens, - max_batch_size, - max_concurrent_requests, - shard_info.requires_padding, - shard_info.window_size, - shard_info.speculate, - generation_health, - tokenizer_config, - processor_config, - ); // Duration buckets let duration_matcher = Matcher::Suffix(String::from("duration")); @@ -1645,8 +1657,8 @@ pub async fn run( #[derive(OpenApi)] #[openapi( - paths(vertex_compatibility), - components(schemas(VertexInstance, VertexRequest, VertexResponse)) + paths(vertex_compatibility), + components(schemas(VertexInstance, VertexRequest, VertexResponse)) )] struct VertextApiDoc; @@ -1756,7 +1768,7 @@ async fn shutdown_signal() { }; #[cfg(unix)] - let terminate = async { + let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") .recv() @@ -1764,7 
+1776,7 @@ async fn shutdown_signal() {
     };
 
     #[cfg(not(unix))]
-    let terminate = std::future::pending::<()>();
+    let terminate = std::future::pending::<()>();
 
     tokio::select! {
         _ = ctrl_c => {},
@@ -1821,5 +1833,5 @@ pub enum WebServerError {
     #[error("Not enough memory to handle `max_total_tokens={0}`")]
     NotEnoughMemory(usize),
     #[error("Axum error: {0}")]
-    Axum(#[from] axum::BoxError)
+    Axum(#[from] axum::BoxError),
 }
\ No newline at end of file