mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 23:42:06 +00:00
feat: add chat template struct to avoid tuple ordering errors (#1570)
This commit is contained in:
parent
31b5e37f49
commit
cf946b3984
@ -33,14 +33,10 @@ pub struct Infer {
|
|||||||
queue: Queue,
|
queue: Queue,
|
||||||
/// Shared state
|
/// Shared state
|
||||||
shared: Arc<Shared>,
|
shared: Arc<Shared>,
|
||||||
|
/// Chat template
|
||||||
|
chat_template: Option<ChatTemplate>,
|
||||||
/// Inference limit
|
/// Inference limit
|
||||||
limit_concurrent_requests: Arc<Semaphore>,
|
limit_concurrent_requests: Arc<Semaphore>,
|
||||||
/// Chat template (template, bos_token, eos_token)
|
|
||||||
template: (
|
|
||||||
Option<Template<'static, 'static>>,
|
|
||||||
Option<String>,
|
|
||||||
Option<String>,
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infer shared state
|
/// Infer shared state
|
||||||
@ -99,32 +95,19 @@ impl Infer {
|
|||||||
generation_health,
|
generation_health,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
let chat_template = tokenizer_config
|
||||||
|
.chat_template
|
||||||
|
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
|
||||||
|
|
||||||
// Inference limit with a semaphore
|
// Inference limit with a semaphore
|
||||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||||
|
|
||||||
let template = tokenizer_config.chat_template.map(|t| {
|
|
||||||
let mut env = Box::new(Environment::new());
|
|
||||||
let template_str = t.into_boxed_str();
|
|
||||||
env.add_function("raise_exception", raise_exception);
|
|
||||||
// leaking env and template_str as read-only, static resources for performance.
|
|
||||||
Box::leak(env)
|
|
||||||
.template_from_str(Box::leak(template_str))
|
|
||||||
.unwrap()
|
|
||||||
});
|
|
||||||
let eos_token = tokenizer_config
|
|
||||||
.eos_token
|
|
||||||
.map_or_else(String::new, |t| t)
|
|
||||||
.into();
|
|
||||||
let bos_token = tokenizer_config
|
|
||||||
.bos_token
|
|
||||||
.map_or_else(String::new, |t| t)
|
|
||||||
.into();
|
|
||||||
Self {
|
Self {
|
||||||
validation,
|
validation,
|
||||||
queue,
|
queue,
|
||||||
shared,
|
shared,
|
||||||
|
chat_template,
|
||||||
limit_concurrent_requests: semaphore,
|
limit_concurrent_requests: semaphore,
|
||||||
template: (template, bos_token, eos_token),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,20 +186,14 @@ impl Infer {
|
|||||||
/// Apply the chat template to the chat request
|
/// Apply the chat template to the chat request
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
||||||
let (template, bos_token, eos_token) = &self.template;
|
self.chat_template
|
||||||
template
|
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||||
.render(ChatTemplateInputs {
|
.apply(messages)
|
||||||
messages,
|
|
||||||
eos_token: eos_token.as_deref(),
|
|
||||||
bos_token: bos_token.as_deref(),
|
|
||||||
add_generation_prompt: true,
|
|
||||||
})
|
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
||||||
tracing::error!("{e}");
|
tracing::error!("{e}");
|
||||||
InferError::TemplateError(e)
|
e
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -340,6 +317,42 @@ impl Infer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct ChatTemplate {
|
||||||
|
template: Template<'static, 'static>,
|
||||||
|
bos_token: Option<String>,
|
||||||
|
eos_token: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatTemplate {
|
||||||
|
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
||||||
|
let mut env = Box::new(Environment::new());
|
||||||
|
let template_str = template.into_boxed_str();
|
||||||
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
// leaking env and template_str as read-only, static resources for performance.
|
||||||
|
let template = Box::leak(env)
|
||||||
|
.template_from_str(Box::leak(template_str))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
template,
|
||||||
|
bos_token,
|
||||||
|
eos_token,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
||||||
|
self.template
|
||||||
|
.render(ChatTemplateInputs {
|
||||||
|
messages,
|
||||||
|
bos_token: self.bos_token.as_deref(),
|
||||||
|
eos_token: self.eos_token.as_deref(),
|
||||||
|
add_generation_prompt: true,
|
||||||
|
})
|
||||||
|
.map_err(InferError::TemplateError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Batching logic
|
/// Batching logic
|
||||||
/// Will be launched in a background Tokio task
|
/// Will be launched in a background Tokio task
|
||||||
///
|
///
|
||||||
|
Loading…
Reference in New Issue
Block a user