diff --git a/router/src/infer.rs b/router/src/infer.rs
index 472b7d66..424c493b 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -1,8 +1,8 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::{
-    ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
-    Message, PrefillToken, Queue, Token,
+    ChatTemplateInputs, Entry, FillInMiddleInputs, GenerateRequest, GenerateStreamResponse,
+    HubTokenizerConfig, Message, PrefillToken, Queue, Token,
 };
 use futures::future::try_join_all;
 use minijinja::{Environment, ErrorKind, Template};
@@ -33,6 +33,8 @@ pub struct Infer {
     shared: Arc<Shared>,
     /// Chat template
    chat_template: Option<ChatTemplate>,
+    /// Fill in middle template
+    fill_in_middle_template: Option<FillInMiddleTemplate>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
 }
@@ -88,6 +90,10 @@ impl Infer {
             .chat_template
             .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
 
+        let fill_in_middle_template = tokenizer_config
+            .fill_in_middle_template
+            .map(FillInMiddleTemplate::new);
+
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
@@ -96,6 +102,7 @@ impl Infer {
             queue,
             shared,
             chat_template,
+            fill_in_middle_template,
             limit_concurrent_requests: semaphore,
         }
     }
@@ -186,6 +193,25 @@ impl Infer {
         })
     }
 
+    /// Apply the fill in the middle template to the request
+    #[instrument(skip_all)]
+    pub(crate) fn apply_fill_in_middle_template(
+        &self,
+        prompt: String,
+        prefix: Option<String>,
+        suffix: Option<String>,
+    ) -> Result<String, InferError> {
+        self.fill_in_middle_template
+            .as_ref()
+            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
+            .apply(prompt, prefix, suffix)
+            .map_err(|e| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "template");
+                tracing::error!("{e}");
+                e
+            })
+    }
+
     /// Add a new request to the queue and return a InferResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate(
@@ -342,6 +368,40 @@ impl ChatTemplate {
     }
 }
 
+#[derive(Clone)]
+struct FillInMiddleTemplate {
+    template: Template<'static, 'static>,
+}
+
+impl FillInMiddleTemplate {
+    fn new(template: String) -> Self {
+        let mut env = Box::new(Environment::new());
+        let template_str = template.into_boxed_str();
+        env.add_function("raise_exception", raise_exception);
+        // leaking env and template_str as read-only, static resources for performance.
+        let template = Box::leak(env)
+            .template_from_str(Box::leak(template_str))
+            .unwrap();
+
+        Self { template }
+    }
+
+    fn apply(
+        &self,
+        prompt: String,
+        prefix: Option<String>,
+        suffix: Option<String>,
+    ) -> Result<String, InferError> {
+        self.template
+            .render(FillInMiddleInputs {
+                prefix: prefix.as_deref(),
+                prompt: prompt.as_str(),
+                suffix: suffix.as_deref(),
+            })
+            .map_err(InferError::TemplateError)
+    }
+}
+
 /// Batching logic
 /// Will be launched in a background Tokio task
 ///
diff --git a/router/src/lib.rs b/router/src/lib.rs
index ae6e7be5..bdb01d95 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -51,6 +51,7 @@ pub struct HubModelInfo {
 #[derive(Clone, Deserialize, Default)]
 pub struct HubTokenizerConfig {
     pub chat_template: Option<String>,
+    pub fill_in_middle_template: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
     pub bos_token: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
@@ -615,6 +616,13 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }
 
+#[derive(Clone, Serialize, Deserialize)]
+pub(crate) struct FillInMiddleInputs<'a> {
+    prefix: Option<&'a str>,
+    prompt: &'a str,
+    suffix: Option<&'a str>,
+}
+
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
     #[schema(example = "user")]
diff --git a/router/src/server.rs b/router/src/server.rs
index e9884781..df0bc1e7 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -570,15 +570,28 @@ async fn completions(
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     metrics::increment_counter!("tgi_request_count");
 
-    let max_new_tokens = req.max_tokens.or(Some(100));
     let stream = req.stream.unwrap_or_default();
     let seed = req.seed;
-    let suffix = req.suffix.unwrap_or_default();
+
+    let inputs = match infer.apply_fill_in_middle_template(req.prompt, None, req.suffix) {
+        Ok(inputs) => inputs,
+        Err(err) => {
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            return Err((
+                StatusCode::UNPROCESSABLE_ENTITY,
+                Json(ErrorResponse {
+                    error: err.to_string(),
+                    error_type: err.error_type().to_string(),
+                }),
+            ));
+        }
+    };
 
     // build the request passing some parameters
     let generate_request = GenerateRequest {
-        inputs: req.prompt.to_string(),
+        inputs: inputs.to_string(),
         parameters: GenerateParameters {
             best_of: None,
             temperature: req.temperature,
@@ -685,7 +698,7 @@ async fn completions(
                 finish_reason: details.finish_reason.to_string(),
                 index: 0,
                 logprobs: None,
-                text: generation.generated_text + &suffix,
+                text: generation.generated_text,
             }],
             usage: Usage {
                 prompt_tokens: details.prefill.len() as u32,
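
For context, the new `fill_in_middle_template` read from `tokenizer_config.json` is an ordinary minijinja template rendered against the `FillInMiddleInputs` fields (`prefix`, `prompt`, `suffix`). The sketch below shows how such a template could be rendered in the same way `FillInMiddleTemplate::apply` does; the template string and the `<PRE>`/`<SUF>`/`<MID>` sentinel tokens are illustrative assumptions only, not values shipped by this diff.

```rust
// Minimal sketch (illustration only, not part of this diff): render a
// hypothetical fill-in-middle template with minijinja, mirroring
// FillInMiddleTemplate::apply. The sentinel tokens and template string are
// assumptions; a real model would ship its own format under
// `fill_in_middle_template` in tokenizer_config.json.
use minijinja::{context, Environment};

fn main() {
    let env = Environment::new();
    let template = env
        .template_from_str("<PRE> {{ prompt }} <SUF>{{ suffix }} <MID>")
        .unwrap();

    // Mirrors FillInMiddleInputs { prefix, prompt, suffix }; prefix is omitted
    // here because the /v1/completions handler passes `None` for it.
    let rendered = template
        .render(context! {
            prompt => "def add(a, b):",
            suffix => "    return result"
        })
        .unwrap();

    println!("{rendered}");
    // -> <PRE> def add(a, b): <SUF>    return result <MID>
}
```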