diff --git a/router/src/infer.rs b/router/src/infer.rs index ddd3c763..a61331d5 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -34,7 +34,11 @@ pub struct Infer { /// Inference limit limit_concurrent_requests: Arc, /// Chat template (template, bos_token, eos_token) - template: (Option>, String, String), + template: ( + Option>, + Option, + Option, + ), } /// Infer shared state @@ -92,18 +96,20 @@ impl Infer { .template_from_str(Box::leak(template_str)) .unwrap() }); - + let eos_token = tokenizer_config + .eos_token + .map_or_else(String::new, |t| t) + .into(); + let bos_token = tokenizer_config + .bos_token + .map_or_else(String::new, |t| t) + .into(); Self { validation, queue, shared, limit_concurrent_requests: semaphore, - template: ( - template, - // initialize bos_token and eos_token to empty strings if not provided - tokenizer_config.bos_token.unwrap_or_default(), - tokenizer_config.eos_token.unwrap_or_default(), - ), + template: (template, eos_token, bos_token), } } @@ -160,14 +166,14 @@ impl Infer { /// Apply the chat template to the chat request #[instrument(skip_all)] pub(crate) fn apply_chat_template(&self, messages: Vec) -> Result { - let (template, bos_token, eos_token) = self.template.clone(); + let (template, bos_token, eos_token) = &self.template; template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .render(ChatTemplateInputs { messages, - eos_token, - bos_token, + eos_token: eos_token.as_deref(), + bos_token: bos_token.as_deref(), }) .map_err(|e| { metrics::increment_counter!("tgi_request_failure", "err" => "template"); @@ -773,8 +779,8 @@ mod tests { content: "magic!".to_string(), }, ], - bos_token: "[BOS]".to_string(), - eos_token: "[EOS]".to_string(), + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); @@ -852,8 +858,8 @@ magic!"# content: "magic!".to_string(), }, ], - bos_token: "[BOS]".to_string(), - eos_token: "[EOS]".to_string(), + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), }; let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); @@ -924,8 +930,8 @@ magic!"# content: "magic!".to_string(), }, ], - bos_token: "[BOS]".to_string(), - eos_token: "[EOS]".to_string(), + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); diff --git a/router/src/lib.rs b/router/src/lib.rs index 213df7bf..983079d6 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -368,10 +368,10 @@ pub(crate) struct ChatRequest { } #[derive(Clone, Serialize, Deserialize)] -pub(crate) struct ChatTemplateInputs { +pub(crate) struct ChatTemplateInputs<'a> { messages: Vec, - bos_token: String, - eos_token: String, + bos_token: Option<&'a str>, + eos_token: Option<&'a str>, } #[derive(Clone, Deserialize, ToSchema, Serialize)] diff --git a/router/src/server.rs b/router/src/server.rs index e721e97d..530a935b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -659,9 +659,9 @@ async fn chat_completions( // build the complete response object with the full text let response = ChatCompletion::new( - generation.generated_text, model_id, system_fingerprint, + generation.generated_text, current_time, generation.details.unwrap(), logprobs,