From 2d79e7744aa0f6e5317c34ebe10783dede056789 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 20 Feb 2024 14:42:31 -0500 Subject: [PATCH] fix: prefer name completion over fill in middle --- router/src/infer.rs | 26 +++++++++++++------------- router/src/lib.rs | 4 ++-- router/src/server.rs | 6 +++++- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 424c493b..dae6f53b 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1,7 +1,7 @@ /// Batching and inference logic use crate::validation::{Validation, ValidationError}; use crate::{ - ChatTemplateInputs, Entry, FillInMiddleInputs, GenerateRequest, GenerateStreamResponse, + ChatTemplateInputs, Entry, CompletionTemplateInputs, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig, Message, PrefillToken, Queue, Token, }; use futures::future::try_join_all; @@ -33,8 +33,8 @@ pub struct Infer { shared: Arc, /// Chat template chat_template: Option, - /// Fill in middle template - fill_in_middle_template: Option, + /// Completion template + completion_template: Option, /// Inference limit limit_concurrent_requests: Arc, } @@ -90,9 +90,9 @@ impl Infer { .chat_template .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)); - let fill_in_middle_template = tokenizer_config - .fill_in_middle_template - .map(FillInMiddleTemplate::new); + let completion_template = tokenizer_config + .completion_template + .map(CompletionTemplate::new); // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); @@ -102,7 +102,7 @@ impl Infer { queue, shared, chat_template, - fill_in_middle_template, + completion_template, limit_concurrent_requests: semaphore, } } @@ -193,15 +193,15 @@ impl Infer { }) } - /// Apply the fill in the middle template to the request + /// Apply the completion template to the request #[instrument(skip_all)] - pub(crate) fn apply_fill_in_middle_template( + pub(crate) fn apply_completion_template( &self, prompt: String, prefix: Option, suffix: Option, ) -> Result { - self.fill_in_middle_template + self.completion_template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .apply(prompt, prefix, suffix) @@ -369,11 +369,11 @@ impl ChatTemplate { } #[derive(Clone)] -struct FillInMiddleTemplate { +struct CompletionTemplate { template: Template<'static, 'static>, } -impl FillInMiddleTemplate { +impl CompletionTemplate { fn new(template: String) -> Self { let mut env = Box::new(Environment::new()); let template_str = template.into_boxed_str(); @@ -393,7 +393,7 @@ impl FillInMiddleTemplate { suffix: Option, ) -> Result { self.template - .render(FillInMiddleInputs { + .render(CompletionTemplateInputs { prefix: prefix.as_deref(), prompt: prompt.as_str(), suffix: suffix.as_deref(), diff --git a/router/src/lib.rs b/router/src/lib.rs index bdb01d95..c8578d8c 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -51,7 +51,7 @@ pub struct HubModelInfo { #[derive(Clone, Deserialize, Default)] pub struct HubTokenizerConfig { pub chat_template: Option, - pub fill_in_middle_template: Option, + pub completion_template: Option, #[serde(deserialize_with = "token_serde::deserialize")] pub bos_token: Option, #[serde(deserialize_with = "token_serde::deserialize")] @@ -617,7 +617,7 @@ pub(crate) struct ChatTemplateInputs<'a> { } #[derive(Clone, Serialize, Deserialize)] -pub(crate) struct FillInMiddleInputs<'a> { +pub(crate) struct CompletionTemplateInputs<'a> { prefix: Option<&'a str>, prompt: &'a str, suffix: Option<&'a str>, diff --git a/router/src/server.rs b/router/src/server.rs index c6cd3d5b..19c63776 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -577,7 +577,7 @@ async fn completions( let stream = req.stream.unwrap_or_default(); let seed = req.seed; - let inputs = match infer.apply_fill_in_middle_template(req.prompt, None, req.suffix) { + let inputs = match infer.apply_completion_template(req.prompt, None, req.suffix) { Ok(inputs) => inputs, Err(err) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); @@ -1067,6 +1067,7 @@ pub async fn run( generate, generate_stream, chat_completions, + completions, tokenize, metrics, ), @@ -1081,6 +1082,9 @@ pub async fn run( ChatCompletionDelta, ChatCompletionChunk, ChatCompletion, + CompletionRequest, + CompletionComplete, + CompletionCompleteChunk, GenerateParameters, PrefillToken, Token,