fix: initialize chat template a single time, fix defaults, and add seed param

drbh 2024-01-10 13:29:20 -05:00
parent 47ad7bfbe4
commit 55455a16c7
3 changed files with 30 additions and 26 deletions
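The central change: rather than building a fresh minijinja Environment and recompiling the chat template on every request, Infer::new now compiles the template once at startup and stores it on the struct as a Template<'static, 'static>. A minimal standalone sketch of that pattern (illustrative only, not the commit's exact code; compile_static_template is a hypothetical name):

use minijinja::{Environment, Template};

// Compile a template once and leak the backing allocations, yielding a
// Template<'static, 'static> that can be stored in a long-lived struct field.
fn compile_static_template(source: String) -> Template<'static, 'static> {
    let env = Box::new(Environment::new());
    let source = source.into_boxed_str();
    // Deliberate, bounded leak: both allocations live for the rest of the
    // process, which is acceptable for a one-time, read-only resource.
    Box::leak(env)
        .template_from_str(Box::leak(source))
        .expect("invalid chat template")
}

Leaking is what makes the 'static lifetimes possible here; keeping the Environment as an owned field instead would tie the template to a self-referential borrow.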

View File

@@ -4,6 +4,7 @@ use crate::HubTokenizerConfig;
 use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
 use crate::{Entry, Queue, Token};
 use futures::future::try_join_all;
+use minijinja::{Environment, ErrorKind, Template};
 use nohash_hasher::IntMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
@@ -27,12 +28,12 @@ pub struct Infer {
     validation: Validation,
     /// Request queue
     queue: Queue,
-    /// Chat formatter
-    tokenizer_config: HubTokenizerConfig,
     /// Shared state
     shared: Arc<Shared>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
+    /// Chat template
+    template: Option<Template<'static, 'static>>,
 }

 /// Infer shared state
@@ -78,12 +79,21 @@ impl Infer {
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

+        let template = tokenizer_config.chat_template.map(|t| {
+            let env = Box::new(Environment::new());
+            let template_str = t.into_boxed_str();
+            // leaking env and template_str as read-only, static resources for performance.
+            Box::leak(env)
+                .template_from_str(Box::leak(template_str))
+                .unwrap()
+        });
+
         Self {
             validation,
             queue,
             shared,
             limit_concurrent_requests: semaphore,
-            tokenizer_config,
+            template,
         }
     }
@@ -140,9 +150,15 @@ impl Infer {
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
-        self.tokenizer_config
-            .apply_chat_template(chat)
-            .map_err(InferError::TemplateError)
+        self.template
+            .as_ref()
+            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
+            .render(chat)
+            .map_err(|e| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "template");
+                tracing::error!("{e}");
+                InferError::TemplateError(e)
+            })
     }

     /// Add a new request to the queue and return a InferResponse
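For context on the new error path: minijinja's Template::render accepts any serde::Serialize value, which is why the ChatRequest can be passed straight to render, and failures now increment the tgi_request_failure counter before being wrapped in InferError::TemplateError. A hedged, self-contained sketch of the rendering call (simplified types, not TGI's actual structs):

use minijinja::Environment;
use serde::Serialize;

#[derive(Serialize)]
struct Message {
    role: String,
    content: String,
}

fn main() -> Result<(), minijinja::Error> {
    let mut env = Environment::new();
    env.add_template(
        "chat",
        "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}",
    )?;
    // render takes any Serialize value; here, a context with one message.
    let out = env.get_template("chat")?.render(minijinja::context! {
        messages => vec![Message {
            role: "user".to_string(),
            content: "My name is David and I".to_string(),
        }],
    })?;
    print!("{out}");
    Ok(())
}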

View File

@@ -36,22 +36,6 @@ pub struct HubTokenizerConfig {
     pub chat_template: Option<String>,
 }

-impl HubTokenizerConfig {
-    /// Apply the chat template to the chat request
-    pub(crate) fn apply_chat_template(
-        &self,
-        chat: ChatRequest,
-    ) -> Result<String, minijinja::Error> {
-        let mut env = minijinja::Environment::new();
-        let chat_template = self
-            .chat_template
-            .as_ref()
-            .ok_or(minijinja::ErrorKind::TemplateNotFound)?;
-        env.add_template("_", chat_template)?;
-        env.get_template("_")?.render(chat)
-    }
-}
-
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
@@ -292,7 +276,7 @@ impl ChatCompletionChunk {
         finish_reason: Option<String>,
     ) -> Self {
         Self {
-            id: "".to_string(),
+            id: String::new(),
             object: "text_completion".to_string(),
             created,
             model,
@@ -312,7 +296,7 @@ impl ChatCompletionChunk {
 fn default_request_messages() -> Vec<Message> {
     vec![Message {
-        role: "system".to_string(),
+        role: "user".to_string(),
         content: "My name is David and I".to_string(),
     }]
 }
@@ -371,11 +355,14 @@ pub(crate) struct ChatRequest {
     #[serde(default = "bool::default")]
     pub stream: bool,
+
+    #[schema(nullable = true, example = 42)]
+    pub seed: Option<u64>,
 }

 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
-    #[schema(example = "system")]
+    #[schema(example = "user")]
     pub role: String,
     #[schema(example = "My name is David and I")]
     pub content: String,
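Because seed is Option<u64>, serde deserializes a missing field to None, so existing clients that never send seed keep working; the #[schema(...)] attribute only annotates the generated OpenAPI spec. A small sketch of that behavior (trimmed struct, assumes serde_json is available):

use serde::Deserialize;

#[derive(Deserialize)]
struct ChatRequest {
    #[serde(default)]
    stream: bool,
    // Option fields fall back to None when absent; no serde(default) needed.
    seed: Option<u64>,
}

fn main() {
    let with_seed: ChatRequest = serde_json::from_str(r#"{"seed": 42}"#).unwrap();
    assert_eq!(with_seed.seed, Some(42));

    let legacy: ChatRequest = serde_json::from_str("{}").unwrap();
    assert!(legacy.seed.is_none() && !legacy.stream);
}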

View File

@@ -569,6 +569,7 @@ async fn chat_completions(
         // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
         .map(|x| x + 2.0);
     let logprobs = req.logprobs.unwrap_or(false);
+    let seed = req.seed;

     // apply chat template to flatten the request into a single input
     let inputs = match infer.apply_chat_template(req) {
@@ -604,7 +605,7 @@ async fn chat_completions(
             watermark: false,
             details: true,
             decoder_input_details: false,
-            seed: None,
+            seed,
             top_n_tokens: None,
         },
     };
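One detail worth noting in the handler: req.seed is copied out before req is moved into apply_chat_template, then forwarded in place of the previous hardcoded seed: None. A reduced sketch of that move-then-forward shape (hypothetical trimmed types):

struct Request {
    seed: Option<u64>,
    messages: String,
}

// Stands in for apply_chat_template, which consumes the request by value.
fn flatten(req: Request) -> String {
    req.messages
}

fn handle(req: Request) -> (String, Option<u64>) {
    let seed = req.seed; // read before `req` is moved on the next line
    let inputs = flatten(req);
    (inputs, seed)
}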