feat: defaults add_generation_prompt true

This commit is contained in:
drbh 2024-02-06 09:08:35 -05:00
parent 76093c79ac
commit ff0428a351
2 changed files with 3 additions and 6 deletions

View File

@ -38,7 +38,6 @@ pub struct Infer {
Option<Template<'static, 'static>>,
Option<String>,
Option<String>,
bool,
),
}
@ -107,13 +106,12 @@ impl Infer {
.bos_token
.map_or_else(String::new, |t| t)
.into();
let add_generation_prompt = tokenizer_config.add_generation_prompt.unwrap_or(false);
Self {
validation,
queue,
shared,
limit_concurrent_requests: semaphore,
template: (template, eos_token, bos_token, add_generation_prompt),
template: (template, eos_token, bos_token),
}
}
@ -192,7 +190,7 @@ impl Infer {
/// Apply the chat template to the chat request
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
let (template, bos_token, eos_token, add_generation_prompt) = &self.template;
let (template, bos_token, eos_token) = &self.template;
template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
@ -200,7 +198,7 @@ impl Infer {
messages,
eos_token: eos_token.as_deref(),
bos_token: bos_token.as_deref(),
add_generation_prompt: *add_generation_prompt,
add_generation_prompt: true,
})
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");

View File

@ -34,7 +34,6 @@ pub struct HubTokenizerConfig {
pub chat_template: Option<String>,
pub bos_token: Option<String>,
pub eos_token: Option<String>,
pub add_generation_prompt: Option<bool>,
}
impl HubTokenizerConfig {