Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
fix: initialize chat template single time, fix defaults and add seed param

commit 55455a16c7 (parent 47ad7bfbe4)
@@ -4,6 +4,7 @@ use crate::HubTokenizerConfig;
 use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
 use crate::{Entry, Queue, Token};
 use futures::future::try_join_all;
+use minijinja::{Environment, ErrorKind, Template};
 use nohash_hasher::IntMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
@@ -27,12 +28,12 @@ pub struct Infer {
     validation: Validation,
     /// Request queue
     queue: Queue,
-    /// Chat formatter
-    tokenizer_config: HubTokenizerConfig,
     /// Shared state
     shared: Arc<Shared>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
+    /// Chat template
+    template: Option<Template<'static, 'static>>,
 }

 /// Infer shared state
@@ -78,12 +79,21 @@ impl Infer {
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

+        let template = tokenizer_config.chat_template.map(|t| {
+            let env = Box::new(Environment::new());
+            let template_str = t.into_boxed_str();
+            // leaking env and template_str as read-only, static resources for performance.
+            Box::leak(env)
+                .template_from_str(Box::leak(template_str))
+                .unwrap()
+        });
+
         Self {
             validation,
             queue,
             shared,
             limit_concurrent_requests: semaphore,
-            tokenizer_config,
+            template,
         }
     }

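The Box::leak calls above are what make the one-time initialization work: leaking the boxed Environment and the template source yields &'static references, so the compiled Template<'static, 'static> can be stored in Infer without tying its lifetime to the config. A minimal standalone sketch of the same pattern (assuming only a minijinja dependency; compile_static is an illustrative name, not TGI code):

use minijinja::{Environment, Template};

/// Compile a chat template once and return a handle that lives as long as the process.
fn compile_static(source: String) -> Template<'static, 'static> {
    // Leaking turns both allocations into &'static references. The memory is
    // never reclaimed, which is acceptable for a single process-long object
    // and avoids re-parsing the template on every request.
    let env: &'static mut Environment = Box::leak(Box::new(Environment::new()));
    let source: &'static str = Box::leak(source.into_boxed_str());
    env.template_from_str(source).expect("invalid chat template")
}

The trade-off is one deliberate, never-freed allocation at startup in place of the per-request Environment::new() and add_template work done by the removed HubTokenizerConfig::apply_chat_template further down.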
@@ -140,9 +150,15 @@ impl Infer {
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
-        self.tokenizer_config
-            .apply_chat_template(chat)
-            .map_err(InferError::TemplateError)
+        self.template
+            .as_ref()
+            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
+            .render(chat)
+            .map_err(|e| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "template");
+                tracing::error!("{e}");
+                InferError::TemplateError(e)
+            })
     }

     /// Add a new request to the queue and return a InferResponse
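Template::render accepts any serde::Serialize value as its context, which is why the whole ChatRequest can be passed straight to render(chat). A toy illustration of that rendering path (a standalone sketch with hypothetical trimmed-down structs, not the actual TGI types; assumes serde with the derive feature):

use minijinja::Environment;
use serde::Serialize;

#[derive(Serialize)]
struct Msg {
    role: String,
    content: String,
}

#[derive(Serialize)]
struct Chat {
    messages: Vec<Msg>,
}

fn main() {
    let mut env = Environment::new();
    env.add_template(
        "chat",
        "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}",
    )
    .unwrap();
    let chat = Chat {
        messages: vec![Msg {
            role: "user".to_string(),
            content: "My name is David and I".to_string(),
        }],
    };
    // Serialize exposes the struct's fields as template variables.
    // Prints "user: My name is David and I"
    println!("{}", env.get_template("chat").unwrap().render(&chat).unwrap());
}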
@@ -36,22 +36,6 @@ pub struct HubTokenizerConfig {
     pub chat_template: Option<String>,
 }

-impl HubTokenizerConfig {
-    /// Apply the chat template to the chat request
-    pub(crate) fn apply_chat_template(
-        &self,
-        chat: ChatRequest,
-    ) -> Result<String, minijinja::Error> {
-        let mut env = minijinja::Environment::new();
-        let chat_template = self
-            .chat_template
-            .as_ref()
-            .ok_or(minijinja::ErrorKind::TemplateNotFound)?;
-        env.add_template("_", chat_template)?;
-        env.get_template("_")?.render(chat)
-    }
-}
-
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
@@ -292,7 +276,7 @@ impl ChatCompletionChunk {
         finish_reason: Option<String>,
     ) -> Self {
         Self {
-            id: "".to_string(),
+            id: String::new(),
             object: "text_completion".to_string(),
             created,
             model,
@@ -312,7 +296,7 @@ impl ChatCompletionChunk {

 fn default_request_messages() -> Vec<Message> {
     vec![Message {
-        role: "system".to_string(),
+        role: "user".to_string(),
         content: "My name is David and I".to_string(),
     }]
 }
@@ -371,11 +355,14 @@ pub(crate) struct ChatRequest {

     #[serde(default = "bool::default")]
     pub stream: bool,
+
+    #[schema(nullable = true, example = 42)]
+    pub seed: Option<u64>,
 }

 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
-    #[schema(example = "system")]
+    #[schema(example = "user")]
     pub role: String,
     #[schema(example = "My name is David and I")]
     pub content: String,
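Because seed is declared as Option<u64>, no serde default attribute is needed: a request that omits the field deserializes to None (non-deterministic sampling), while "seed": 42 requests reproducible sampling. A small sketch of that behavior (ChatRequestLite is a hypothetical stand-in for the real struct; assumes serde and serde_json):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ChatRequestLite {
    #[serde(default = "bool::default")]
    stream: bool,
    // Option fields fall back to None when the key is absent.
    seed: Option<u64>,
}

fn main() {
    let with_seed: ChatRequestLite =
        serde_json::from_str(r#"{"stream": true, "seed": 42}"#).unwrap();
    let without_seed: ChatRequestLite = serde_json::from_str("{}").unwrap();
    assert_eq!(with_seed.seed, Some(42));
    assert_eq!(without_seed.seed, None);
    println!("{with_seed:?} {without_seed:?}");
}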
@@ -569,6 +569,7 @@ async fn chat_completions(
         // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
         .map(|x| x + 2.0);
     let logprobs = req.logprobs.unwrap_or(false);
+    let seed = req.seed;

     // apply chat template to flatten the request into a single input
     let inputs = match infer.apply_chat_template(req) {
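One ordering detail in the hunk above: seed is copied out before req is moved into infer.apply_chat_template(req); reading req.seed after that call would not compile, because apply_chat_template takes the request by value. A reduced illustration with hypothetical types:

struct Req {
    seed: Option<u64>,
    messages: Vec<String>,
}

// Stand-in for infer.apply_chat_template(req): consumes the whole request.
fn apply_template(req: Req) -> String {
    req.messages.join("\n")
}

fn main() {
    let req = Req {
        seed: Some(42),
        messages: vec!["user: hi".to_string()],
    };
    let seed = req.seed; // Option<u64> is Copy, so req remains fully usable here.
    let inputs = apply_template(req); // req is moved; its fields can't be read afterwards.
    println!("inputs={inputs:?} seed={seed:?}");
}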
@@ -604,7 +605,7 @@ async fn chat_completions(
             watermark: false,
             details: true,
             decoder_input_details: false,
-            seed: None,
+            seed,
             top_n_tokens: None,
         },
     };