From 76093c79ac090af7836066130c5e7ad0393a5fc9 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 5 Feb 2024 10:01:40 -0500 Subject: [PATCH] feat: use existing add_generation_prompt variable from config in template --- router/src/infer.rs | 10 ++++++++-- router/src/lib.rs | 2 ++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 5f078ba0..23c8655d 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -38,6 +38,7 @@ pub struct Infer { Option>, Option, Option, + bool, ), } @@ -106,12 +107,13 @@ impl Infer { .bos_token .map_or_else(String::new, |t| t) .into(); + let add_generation_prompt = tokenizer_config.add_generation_prompt.unwrap_or(false); Self { validation, queue, shared, limit_concurrent_requests: semaphore, - template: (template, eos_token, bos_token), + template: (template, eos_token, bos_token, add_generation_prompt), } } @@ -190,7 +192,7 @@ impl Infer { /// Apply the chat template to the chat request #[instrument(skip_all)] pub(crate) fn apply_chat_template(&self, messages: Vec) -> Result { - let (template, bos_token, eos_token) = &self.template; + let (template, bos_token, eos_token, add_generation_prompt) = &self.template; template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? @@ -198,6 +200,7 @@ impl Infer { messages, eos_token: eos_token.as_deref(), bos_token: bos_token.as_deref(), + add_generation_prompt: *add_generation_prompt, }) .map_err(|e| { metrics::increment_counter!("tgi_request_failure", "err" => "template"); @@ -806,6 +809,7 @@ mod tests { ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: false, }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); @@ -878,6 +882,7 @@ magic!"# ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: false, }; let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); @@ -943,6 +948,7 @@ magic!"# ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: false, }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); diff --git a/router/src/lib.rs b/router/src/lib.rs index 07360e78..23b08a82 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -34,6 +34,7 @@ pub struct HubTokenizerConfig { pub chat_template: Option, pub bos_token: Option, pub eos_token: Option, + pub add_generation_prompt: Option, } impl HubTokenizerConfig { @@ -398,6 +399,7 @@ pub(crate) struct ChatTemplateInputs<'a> { messages: Vec, bos_token: Option<&'a str>, eos_token: Option<&'a str>, + add_generation_prompt: bool, } #[derive(Clone, Deserialize, ToSchema, Serialize)]