diff --git a/router/src/infer.rs b/router/src/infer.rs
index bc99e72e..e917f68f 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -4,6 +4,7 @@ use crate::HubTokenizerConfig;
 use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
 use crate::{Entry, Queue, Token};
 use futures::future::try_join_all;
+use minijinja::{Environment, ErrorKind, Template};
 use nohash_hasher::IntMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
@@ -27,12 +28,12 @@ pub struct Infer {
     validation: Validation,
     /// Request queue
     queue: Queue,
-    /// Chat formatter
-    tokenizer_config: HubTokenizerConfig,
     /// Shared state
     shared: Arc<Shared>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
+    /// Chat template
+    template: Option<Template<'static, 'static>>,
 }
 
 /// Infer shared state
@@ -78,12 +79,21 @@ impl Infer {
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
+        let template = tokenizer_config.chat_template.map(|t| {
+            let env = Box::new(Environment::new());
+            let template_str = t.into_boxed_str();
+            // leaking env and template_str as read-only, static resources for performance.
+            Box::leak(env)
+                .template_from_str(Box::leak(template_str))
+                .unwrap()
+        });
+
         Self {
             validation,
             queue,
             shared,
             limit_concurrent_requests: semaphore,
-            tokenizer_config,
+            template,
         }
     }
 
@@ -140,9 +150,15 @@ impl Infer {
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
-        self.tokenizer_config
-            .apply_chat_template(chat)
-            .map_err(InferError::TemplateError)
+        self.template
+            .as_ref()
+            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
+            .render(chat)
+            .map_err(|e| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "template");
+                tracing::error!("{e}");
+                InferError::TemplateError(e)
+            })
     }
 
     /// Add a new request to the queue and return a InferResponse
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7df15630..c756065e 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -36,22 +36,6 @@ pub struct HubTokenizerConfig {
     pub chat_template: Option<String>,
 }
 
-impl HubTokenizerConfig {
-    /// Apply the chat template to the chat request
-    pub(crate) fn apply_chat_template(
-        &self,
-        chat: ChatRequest,
-    ) -> Result<String, minijinja::Error> {
-        let mut env = minijinja::Environment::new();
-        let chat_template = self
-            .chat_template
-            .as_ref()
-            .ok_or(minijinja::ErrorKind::TemplateNotFound)?;
-        env.add_template("_", chat_template)?;
-        env.get_template("_")?.render(chat)
-    }
-}
-
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
@@ -292,7 +276,7 @@ impl ChatCompletionChunk {
         finish_reason: Option<String>,
     ) -> Self {
         Self {
-            id: "".to_string(),
+            id: String::new(),
             object: "text_completion".to_string(),
             created,
             model,
@@ -312,7 +296,7 @@ impl ChatCompletionChunk {
 
 fn default_request_messages() -> Vec<Message> {
     vec![Message {
-        role: "system".to_string(),
+        role: "user".to_string(),
         content: "My name is David and I".to_string(),
     }]
 }
@@ -371,11 +355,14 @@ pub(crate) struct ChatRequest {
 
     #[serde(default = "bool::default")]
     pub stream: bool,
+
+    #[schema(nullable = true, example = 42)]
+    pub seed: Option<u64>,
 }
 
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
-    #[schema(example = "system")]
+    #[schema(example = "user")]
     pub role: String,
     #[schema(example = "My name is David and I")]
     pub content: String,
diff --git a/router/src/server.rs b/router/src/server.rs
index c5d2deba..a51c9033 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -569,6 +569,7 @@ async fn chat_completions(
         // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
         .map(|x| x + 2.0);
     let logprobs = req.logprobs.unwrap_or(false);
+    let seed = req.seed;
 
     // apply chat template to flatten the request into a single input
     let inputs = match infer.apply_chat_template(req) {
@@ -604,7 +605,7 @@
             watermark: false,
             details: true,
             decoder_input_details: false,
-            seed: None,
+            seed,
             top_n_tokens: None,
         },
     };