fix: make eos and bos optional and only init template once

Author: drbh
Date: 2024-01-17 18:24:17 -05:00
parent db835509ed
commit f378c60517
3 changed files with 27 additions and 21 deletions

View File

@@ -34,7 +34,11 @@ pub struct Infer {
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
     /// Chat template (template, bos_token, eos_token)
-    template: (Option<Template<'static, 'static>>, String, String),
+    template: (
+        Option<Template<'static, 'static>>,
+        Option<String>,
+        Option<String>,
+    ),
 }
 
 /// Infer shared state
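The bos_token and eos_token slots in the tuple move from String to Option<String> because a tokenizer config does not always define them. A minimal sketch of that motivation, using serde and a hypothetical TokenizerConfig stand-in (not the project's actual config type), with field names mirroring the bos_token/eos_token fields touched in this commit:

use serde::Deserialize;

// Hypothetical stand-in for the tokenizer config the router reads.
#[derive(Deserialize, Debug)]
struct TokenizerConfig {
    bos_token: Option<String>,
    eos_token: Option<String>,
}

fn main() {
    // A config that only sets eos_token: bos_token deserializes to None
    // instead of being forced into a placeholder empty string.
    let cfg: TokenizerConfig =
        serde_json::from_str(r#"{ "eos_token": "</s>" }"#).unwrap();
    assert_eq!(cfg.bos_token, None);
    assert_eq!(cfg.eos_token.as_deref(), Some("</s>"));
}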
@@ -92,18 +96,20 @@ impl Infer {
                 .template_from_str(Box::leak(template_str))
                 .unwrap()
         });
+        let eos_token = tokenizer_config
+            .eos_token
+            .map_or_else(String::new, |t| t)
+            .into();
+        let bos_token = tokenizer_config
+            .bos_token
+            .map_or_else(String::new, |t| t)
+            .into();
 
         Self {
             validation,
             queue,
             shared,
             limit_concurrent_requests: semaphore,
-            template: (
-                template,
-                // initialize bos_token and eos_token to empty strings if not provided
-                tokenizer_config.bos_token.unwrap_or_default(),
-                tokenizer_config.eos_token.unwrap_or_default(),
-            ),
+            template: (template, eos_token, bos_token),
         }
     }
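For reference on the conversion added above: .map_or_else(String::new, |t| t) collapses the optional token to a String, and the trailing .into() wraps it back into an Option<String> via std's blanket impl<T> From<T> for Option<T>. A small sketch of that behaviour, plain Rust and nothing project-specific:

fn main() {
    let configured: Option<String> = Some("</s>".to_string());
    let missing: Option<String> = None;

    // Same shape as the constructor code: Option<String> -> String -> Option<String>.
    let eos: Option<String> = configured.map_or_else(String::new, |t| t).into();
    let bos: Option<String> = missing.map_or_else(String::new, |t| t).into();

    assert_eq!(eos, Some("</s>".to_string()));
    // An absent token ends up as Some("") rather than None after the round trip.
    assert_eq!(bos, Some(String::new()));
}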
@@ -160,14 +166,14 @@ impl Infer {
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
-        let (template, bos_token, eos_token) = self.template.clone();
+        let (template, bos_token, eos_token) = &self.template;
         template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
             .render(ChatTemplateInputs {
                 messages,
-                eos_token,
-                bos_token,
+                eos_token: eos_token.as_deref(),
+                bos_token: bos_token.as_deref(),
             })
             .map_err(|e| {
                 metrics::increment_counter!("tgi_request_failure", "err" => "template");
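Borrowing &self.template instead of cloning it per request is what makes the one-time template initialization stick, and as_deref() turns the stored Option<String> into the Option<&str> the new ChatTemplateInputs expects. A minimal illustration of that Option<String> -> Option<&str> step:

fn main() {
    let eos_token: Option<String> = Some("</s>".to_string());
    // as_deref() borrows the stored String, so nothing is cloned per request.
    let borrowed: Option<&str> = eos_token.as_deref();
    assert_eq!(borrowed, Some("</s>"));

    let missing: Option<String> = None;
    assert_eq!(missing.as_deref(), None);
}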
@@ -773,8 +779,8 @@ mod tests {
                     content: "magic!".to_string(),
                 },
             ],
-            bos_token: "[BOS]".to_string(),
-            eos_token: "[EOS]".to_string(),
+            bos_token: Some("[BOS]"),
+            eos_token: Some("[EOS]"),
         };
 
         let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
@@ -852,8 +858,8 @@ magic!"#
                     content: "magic!".to_string(),
                 },
             ],
-            bos_token: "[BOS]".to_string(),
-            eos_token: "[EOS]".to_string(),
+            bos_token: Some("[BOS]"),
+            eos_token: Some("[EOS]"),
         };
 
         let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
@@ -924,8 +930,8 @@ magic!"#
                     content: "magic!".to_string(),
                 },
             ],
-            bos_token: "[BOS]".to_string(),
-            eos_token: "[EOS]".to_string(),
+            bos_token: Some("[BOS]"),
+            eos_token: Some("[EOS]"),
         };
 
         let result = tmpl.unwrap().render(chat_template_inputs).unwrap();

View File

@@ -368,10 +368,10 @@ pub(crate) struct ChatRequest {
 }
 
 #[derive(Clone, Serialize, Deserialize)]
-pub(crate) struct ChatTemplateInputs {
+pub(crate) struct ChatTemplateInputs<'a> {
     messages: Vec<Message>,
-    bos_token: String,
-    eos_token: String,
+    bos_token: Option<&'a str>,
+    eos_token: Option<&'a str>,
 }
 
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
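With bos_token and eos_token as Option<&'a str>, a missing token serializes to Jinja's none, which is falsy, so a chat template can guard on it rather than interpolating an empty string. A minimal sketch assuming the minijinja crate (which the Template type and template_from_str call above come from), a simplified stand-in struct without the messages field, and a made-up template string:

use minijinja::Environment;
use serde::Serialize;

// Simplified stand-in for ChatTemplateInputs, messages omitted.
#[derive(Serialize)]
struct Inputs<'a> {
    bos_token: Option<&'a str>,
    eos_token: Option<&'a str>,
}

fn main() {
    let env = Environment::new();
    let tmpl = env
        .template_from_str("{% if bos_token %}{{ bos_token }}{% endif %}hi{{ eos_token }}")
        .unwrap();
    let out = tmpl
        .render(Inputs { bos_token: None, eos_token: Some("</s>") })
        .unwrap();
    // bos_token is none, so the {% if %} branch is skipped entirely.
    assert_eq!(out, "hi</s>");
}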

View File

@@ -659,9 +659,9 @@ async fn chat_completions(
     // build the complete response object with the full text
     let response = ChatCompletion::new(
-        generation.generated_text,
         model_id,
         system_fingerprint,
+        generation.generated_text,
         current_time,
         generation.details.unwrap(),
         logprobs,