feat: add chat template struct to avoid tuple ordering errors (#1570)

Authored by OlivierDehaene on 2024-02-16 16:37:32 +01:00, committed by Karol Damaszke
parent 31b5e37f49
commit cf946b3984

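Note (not part of the diff): the tuple this commit replaces stored bos_token and eos_token as two adjacent Option<String> values, so a call site that swapped them would still compile. A minimal standalone sketch of that hazard, using simplified placeholder names rather than the router's actual types:

fn main() {
    // Tuple: both token slots have the same type, so a reversed pair compiles silently.
    let _by_position: (Option<String>, Option<String>) =
        (Some("</s>".into()), Some("<s>".into())); // bos/eos accidentally swapped

    // Struct: each token is named where the value is built, so a swap is visible
    // at the call site and reordering the fields cannot change their meaning.
    #[allow(dead_code)]
    struct Tokens {
        bos_token: Option<String>,
        eos_token: Option<String>,
    }
    let _by_name = Tokens {
        bos_token: Some("<s>".into()),
        eos_token: Some("</s>".into()),
    };
}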

@@ -33,14 +33,10 @@ pub struct Infer {
     queue: Queue,
     /// Shared state
     shared: Arc<Shared>,
+    /// Chat template
+    chat_template: Option<ChatTemplate>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
-    /// Chat template (template, bos_token, eos_token)
-    template: (
-        Option<Template<'static, 'static>>,
-        Option<String>,
-        Option<String>,
-    ),
 }
 
 /// Infer shared state
@@ -99,32 +95,19 @@ impl Infer {
             generation_health,
         ));
 
+        let chat_template = tokenizer_config
+            .chat_template
+            .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
+
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
-        let template = tokenizer_config.chat_template.map(|t| {
-            let mut env = Box::new(Environment::new());
-            let template_str = t.into_boxed_str();
-            env.add_function("raise_exception", raise_exception);
-            // leaking env and template_str as read-only, static resources for performance.
-            Box::leak(env)
-                .template_from_str(Box::leak(template_str))
-                .unwrap()
-        });
-
-        let eos_token = tokenizer_config
-            .eos_token
-            .map_or_else(String::new, |t| t)
-            .into();
-
-        let bos_token = tokenizer_config
-            .bos_token
-            .map_or_else(String::new, |t| t)
-            .into();
-
         Self {
            validation,
            queue,
            shared,
+            chat_template,
            limit_concurrent_requests: semaphore,
-            template: (template, bos_token, eos_token),
        }
    }
@@ -203,20 +186,14 @@ impl Infer {
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
-        let (template, bos_token, eos_token) = &self.template;
-        template
+        self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .render(ChatTemplateInputs {
-                messages,
-                eos_token: eos_token.as_deref(),
-                bos_token: bos_token.as_deref(),
-                add_generation_prompt: true,
-            })
+            .apply(messages)
             .map_err(|e| {
                 metrics::increment_counter!("tgi_request_failure", "err" => "template");
                 tracing::error!("{e}");
-                InferError::TemplateError(e)
+                e
             })
     }
@@ -340,6 +317,42 @@ impl Infer {
     }
 }
 
+#[derive(Clone)]
+struct ChatTemplate {
+    template: Template<'static, 'static>,
+    bos_token: Option<String>,
+    eos_token: Option<String>,
+}
+
+impl ChatTemplate {
+    fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
+        let mut env = Box::new(Environment::new());
+        let template_str = template.into_boxed_str();
+        env.add_function("raise_exception", raise_exception);
+        // leaking env and template_str as read-only, static resources for performance.
+        let template = Box::leak(env)
+            .template_from_str(Box::leak(template_str))
+            .unwrap();
+
+        Self {
+            template,
+            bos_token,
+            eos_token,
+        }
+    }
+
+    fn apply(&self, messages: Vec<Message>) -> Result<String, InferError> {
+        self.template
+            .render(ChatTemplateInputs {
+                messages,
+                bos_token: self.bos_token.as_deref(),
+                eos_token: self.eos_token.as_deref(),
+                add_generation_prompt: true,
+            })
+            .map_err(InferError::TemplateError)
+    }
+}
+
 /// Batching logic
 /// Will be launched in a background Tokio task
 ///
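Note (not part of the diff): ChatTemplate::new deliberately leaks the minijinja Environment and the template source, as the in-code comment says, because Box::leak turns a one-time heap allocation into a reference that lives for the rest of the process, which is what gives the stored Template its 'static lifetimes. A minimal standalone sketch of that leak pattern, independent of the minijinja types used above:

fn main() {
    // Box::leak never frees the allocation; in exchange it returns a reference
    // that is valid for the remainder of the program (i.e. &'static).
    let template_src: &'static str =
        Box::leak(String::from("{{ bos_token }}{{ messages }}").into_boxed_str());
    println!("{template_src}");
}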