feat: use existing add_generation_prompt variable from config in template

This commit is contained in:
drbh 2024-02-05 10:01:40 -05:00
parent 0da00be52c
commit 76093c79ac
2 changed files with 10 additions and 2 deletions

View File

@ -38,6 +38,7 @@ pub struct Infer {
Option<Template<'static, 'static>>, Option<Template<'static, 'static>>,
Option<String>, Option<String>,
Option<String>, Option<String>,
bool,
), ),
} }
@ -106,12 +107,13 @@ impl Infer {
.bos_token .bos_token
.map_or_else(String::new, |t| t) .map_or_else(String::new, |t| t)
.into(); .into();
let add_generation_prompt = tokenizer_config.add_generation_prompt.unwrap_or(false);
Self { Self {
validation, validation,
queue, queue,
shared, shared,
limit_concurrent_requests: semaphore, limit_concurrent_requests: semaphore,
template: (template, eos_token, bos_token), template: (template, eos_token, bos_token, add_generation_prompt),
} }
} }
@ -190,7 +192,7 @@ impl Infer {
/// Apply the chat template to the chat request /// Apply the chat template to the chat request
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> { pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
let (template, bos_token, eos_token) = &self.template; let (template, bos_token, eos_token, add_generation_prompt) = &self.template;
template template
.as_ref() .as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
@ -198,6 +200,7 @@ impl Infer {
messages, messages,
eos_token: eos_token.as_deref(), eos_token: eos_token.as_deref(),
bos_token: bos_token.as_deref(), bos_token: bos_token.as_deref(),
add_generation_prompt: *add_generation_prompt,
}) })
.map_err(|e| { .map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template"); metrics::increment_counter!("tgi_request_failure", "err" => "template");
@ -806,6 +809,7 @@ mod tests {
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"), eos_token: Some("[EOS]"),
add_generation_prompt: false,
}; };
let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
@ -878,6 +882,7 @@ magic!"#
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"), eos_token: Some("[EOS]"),
add_generation_prompt: false,
}; };
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
@ -943,6 +948,7 @@ magic!"#
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"), eos_token: Some("[EOS]"),
add_generation_prompt: false,
}; };
let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); let result = tmpl.unwrap().render(chat_template_inputs).unwrap();

View File

@ -34,6 +34,7 @@ pub struct HubTokenizerConfig {
pub chat_template: Option<String>, pub chat_template: Option<String>,
pub bos_token: Option<String>, pub bos_token: Option<String>,
pub eos_token: Option<String>, pub eos_token: Option<String>,
pub add_generation_prompt: Option<bool>,
} }
impl HubTokenizerConfig { impl HubTokenizerConfig {
@ -398,6 +399,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<Message>, messages: Vec<Message>,
bos_token: Option<&'a str>, bos_token: Option<&'a str>,
eos_token: Option<&'a str>, eos_token: Option<&'a str>,
add_generation_prompt: bool,
} }
#[derive(Clone, Deserialize, ToSchema, Serialize)] #[derive(Clone, Deserialize, ToSchema, Serialize)]