mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: use existing add_generation_prompt variable from config in template
This commit is contained in:
parent
0da00be52c
commit
76093c79ac
@ -38,6 +38,7 @@ pub struct Infer {
|
|||||||
Option<Template<'static, 'static>>,
|
Option<Template<'static, 'static>>,
|
||||||
Option<String>,
|
Option<String>,
|
||||||
Option<String>,
|
Option<String>,
|
||||||
|
bool,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,12 +107,13 @@ impl Infer {
|
|||||||
.bos_token
|
.bos_token
|
||||||
.map_or_else(String::new, |t| t)
|
.map_or_else(String::new, |t| t)
|
||||||
.into();
|
.into();
|
||||||
|
let add_generation_prompt = tokenizer_config.add_generation_prompt.unwrap_or(false);
|
||||||
Self {
|
Self {
|
||||||
validation,
|
validation,
|
||||||
queue,
|
queue,
|
||||||
shared,
|
shared,
|
||||||
limit_concurrent_requests: semaphore,
|
limit_concurrent_requests: semaphore,
|
||||||
template: (template, eos_token, bos_token),
|
template: (template, eos_token, bos_token, add_generation_prompt),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,7 +192,7 @@ impl Infer {
|
|||||||
/// Apply the chat template to the chat request
|
/// Apply the chat template to the chat request
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
||||||
let (template, bos_token, eos_token) = &self.template;
|
let (template, bos_token, eos_token, add_generation_prompt) = &self.template;
|
||||||
template
|
template
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||||
@ -198,6 +200,7 @@ impl Infer {
|
|||||||
messages,
|
messages,
|
||||||
eos_token: eos_token.as_deref(),
|
eos_token: eos_token.as_deref(),
|
||||||
bos_token: bos_token.as_deref(),
|
bos_token: bos_token.as_deref(),
|
||||||
|
add_generation_prompt: *add_generation_prompt,
|
||||||
})
|
})
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
||||||
@ -806,6 +809,7 @@ mod tests {
|
|||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: Some("[EOS]"),
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
@ -878,6 +882,7 @@ magic!"#
|
|||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: Some("[EOS]"),
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
||||||
@ -943,6 +948,7 @@ magic!"#
|
|||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: Some("[EOS]"),
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
|
@ -34,6 +34,7 @@ pub struct HubTokenizerConfig {
|
|||||||
pub chat_template: Option<String>,
|
pub chat_template: Option<String>,
|
||||||
pub bos_token: Option<String>,
|
pub bos_token: Option<String>,
|
||||||
pub eos_token: Option<String>,
|
pub eos_token: Option<String>,
|
||||||
|
pub add_generation_prompt: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HubTokenizerConfig {
|
impl HubTokenizerConfig {
|
||||||
@ -398,6 +399,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
|||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
bos_token: Option<&'a str>,
|
bos_token: Option<&'a str>,
|
||||||
eos_token: Option<&'a str>,
|
eos_token: Option<&'a str>,
|
||||||
|
add_generation_prompt: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
Loading…
Reference in New Issue
Block a user