Mirror of https://github.com/huggingface/text-generation-inference.git
fix: remove ChatTemplateError and add index to stream messages
This commit is contained in:
parent 3ae9cd655d
commit ddf7412a6b
@@ -140,24 +140,20 @@ impl Infer {
 
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
-    pub(crate) fn apply_chat_template(
-        &self,
-        chat: ChatRequest,
-    ) -> Result<String, ChatTemplateError> {
+    pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
         let mut env = minijinja::Environment::new();
         let chat_template = self
             .tokenizer_config
             .chat_template
             .as_ref()
-            .ok_or(ChatTemplateError::TemplateNotFound)?;
-        env.add_template("_", chat_template)
-            .map_err(|e| ChatTemplateError::TemplateError(e))?;
-        let jinja_tmpl = env
-            .get_template("_")
-            .map_err(|e| ChatTemplateError::TemplateError(e))?;
-        jinja_tmpl
+            .ok_or_else(|| {
+                InferError::TemplateError(minijinja::ErrorKind::TemplateNotFound.into())
+            })?;
+        env.add_template("_", chat_template)?;
+        env.get_template("_")?
             .render(chat)
-            .map_err(|e| ChatTemplateError::TemplateError(e))
+            .map_err(InferError::TemplateError)
     }
 
     /// Add a new request to the queue and return a InferResponse
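The rewritten body is plain minijinja usage: build an `Environment`, register the template source under a throwaway name, fetch it back, and render it against any serde-serializable value. A minimal standalone sketch of that flow, with hypothetical `Chat`/`Message` types standing in for the router's `ChatRequest`:

    use minijinja::Environment;
    use serde::Serialize;

    // Hypothetical stand-ins for the router's chat request types.
    #[derive(Serialize)]
    struct Message {
        role: String,
        content: String,
    }

    #[derive(Serialize)]
    struct Chat {
        messages: Vec<Message>,
    }

    fn render_chat(template: &str, chat: Chat) -> Result<String, minijinja::Error> {
        let mut env = Environment::new();
        // Register the source under a placeholder name, then render it
        // with the chat struct as the template context.
        env.add_template("_", template)?;
        env.get_template("_")?.render(chat)
    }

With the `TemplateError(#[from] minijinja::Error)` variant added in a later hunk, each `?` above converts into the router's `InferError` automatically.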
@@ -570,9 +566,9 @@ fn send_responses(
     let mut iterator = tokens_
         .ids
         .into_iter()
-        .zip(tokens_.logprobs.into_iter())
-        .zip(tokens_.texts.into_iter())
-        .zip(tokens_.is_special.into_iter())
+        .zip(tokens_.logprobs)
+        .zip(tokens_.texts)
+        .zip(tokens_.is_special)
         .enumerate()
         .peekable();
     while let Some((i, (((id, logprob), text), special))) = iterator.next() {
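The three dropped `.into_iter()` calls are redundant: `Iterator::zip` takes any `IntoIterator`, so the `Vec`s can be passed directly and behavior is unchanged. A standalone sketch mirroring the nested-tuple shape the `while let` destructures:

    fn main() {
        let ids: Vec<u32> = vec![42, 7, 99];
        let logprobs: Vec<f32> = vec![-0.1, -0.5, -0.2];
        let texts: Vec<&str> = vec!["Hello", " world", "!"];

        // `zip` accepts the Vec itself because Vec implements IntoIterator;
        // nesting two zips yields ((id, logprob), text) tuples.
        for (i, ((id, logprob), text)) in ids.into_iter().zip(logprobs).zip(texts).enumerate() {
            println!("{i}: id={id} logprob={logprob} text={text:?}");
        }
    }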
@@ -681,6 +677,8 @@ pub enum InferError {
     ValidationError(#[from] ValidationError),
     #[error("Incomplete generation")]
     IncompleteGeneration,
+    #[error("Template error: {0}")]
+    TemplateError(#[from] minijinja::Error),
 }
 
 impl InferError {
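Declaring the variant with `#[from]` makes thiserror generate `impl From<minijinja::Error> for InferError`, which is exactly what lets the rewritten `apply_chat_template` use bare `?` on `add_template` and `get_template` instead of explicit `map_err` calls. A minimal sketch, assuming only the `thiserror` and `minijinja` crates:

    use thiserror::Error;

    #[derive(Debug, Error)]
    pub enum InferError {
        #[error("Template error: {0}")]
        TemplateError(#[from] minijinja::Error),
    }

    fn compile(source: &str) -> Result<(), InferError> {
        let mut env = minijinja::Environment::new();
        // `?` converts minijinja::Error into InferError via the derived From.
        env.add_template("_", source)?;
        Ok(())
    }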
@@ -690,23 +688,7 @@ impl InferError {
             InferError::Overloaded(_) => "overloaded",
             InferError::ValidationError(_) => "validation",
             InferError::IncompleteGeneration => "incomplete_generation",
-        }
-    }
-}
-
-#[derive(Debug, Error)]
-pub enum ChatTemplateError {
-    #[error("Template error: {0}")]
-    TemplateError(#[from] minijinja::Error),
-    #[error("Template not found")]
-    TemplateNotFound,
-}
-
-impl ChatTemplateError {
-    pub(crate) fn error_type(&self) -> &str {
-        match self {
-            ChatTemplateError::TemplateError(_) => "template_error",
-            ChatTemplateError::TemplateNotFound => "template_not_found",
+            InferError::TemplateError(_) => "template_error",
         }
     }
 }
@@ -21,6 +21,7 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
+use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
@@ -343,7 +344,7 @@ async fn generate_stream(
         event.json_data(stream_token).unwrap()
     };
     let (headers, response_stream) =
-        generate_stream_internal(infer, Json(req.into()), on_message_callback).await;
+        generate_stream_internal(infer, Json(req), on_message_callback).await;
     let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
     (headers, sse)
 }
@@ -547,7 +548,7 @@ async fn generate_stream_internal(
         seed,
     )
 )]
-async fn chat(
+async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -557,7 +558,7 @@ async fn chat(
     let stream = req.stream;
     let max_new_tokens = match req.max_tokens {
         Some(max_new_tokens) => Some(max_new_tokens),
-        None => Some(100)
+        None => Some(100),
     };
 
     // apply chat template to flatten the request into a single input
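The only change in this hunk is the trailing comma rustfmt wants. As an aside, the whole match is equivalent to a one-liner with `Option::or`:

    let max_new_tokens = req.max_tokens.or(Some(100));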
@@ -604,8 +605,9 @@ async fn chat(
 
     // switch on stream
     if stream {
+        let stream_count = AtomicU32::new(0);
         // pass this callback to the stream generation and build the required event structure
-        let on_message_callback = |stream_token: StreamResponse| {
+        let on_message_callback = move |stream_token: StreamResponse| {
             let event = Event::default();
 
             let current_time = std::time::SystemTime::now()
@@ -613,11 +615,15 @@ async fn chat(
                 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                 .as_secs();
 
+            // increment the stream count
+            stream_count.fetch_add(1, Ordering::SeqCst);
+            let current_stream_count = stream_count.load(Ordering::SeqCst);
+
             event
                 .json_data(ChatCompletionChunk::new(
                     stream_token.token.text,
                     current_time,
-                    0,
+                    current_stream_count,
                 ))
                 .unwrap_or_else(|_| {
                     println!("Failed to serialize ChatCompletionChunk");
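Combined with the `move` added in the previous hunk, this is what gives each streamed chunk its index: the closure takes ownership of the `AtomicU32`, and since `fetch_add`/`load` only need `&self`, it stays callable for every message. A self-contained sketch of the pattern; note that `fetch_add` returns the pre-increment value, so the separate `load` could be folded into a single call if a zero-based index were wanted:

    use std::sync::atomic::{AtomicU32, Ordering};

    fn main() {
        let stream_count = AtomicU32::new(0);
        // `move` transfers ownership of the counter into the callback;
        // the atomic ops take &self, so the closure is repeatedly callable.
        let on_message = move |text: &str| {
            stream_count.fetch_add(1, Ordering::SeqCst);
            let current_stream_count = stream_count.load(Ordering::SeqCst);
            println!("chunk {current_stream_count}: {text}");
        };
        on_message("Hello");
        on_message(" world"); // prints "chunk 2:  world"
    }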
@@ -843,7 +849,7 @@ pub async fn run(
         .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
-        .route("/v1/chat/completions", post(chat))
+        .route("/v1/chat/completions", post(chat_completions))
         // AWS Sagemaker route
         .route("/invocations", post(compat_generate))
         // Base Health route
@@ -973,6 +979,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
             InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
+            InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
         };
 
         (