huggingface/text-generation-inference
commit ddf7412a6b (parent 3ae9cd655d)

fix: remove ChatTemplateError and add index to stream messages
@@ -140,24 +140,20 @@ impl Infer {
 
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
-    pub(crate) fn apply_chat_template(
-        &self,
-        chat: ChatRequest,
-    ) -> Result<String, ChatTemplateError> {
+    pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
         let mut env = minijinja::Environment::new();
 
         let chat_template = self
             .tokenizer_config
             .chat_template
             .as_ref()
-            .ok_or(ChatTemplateError::TemplateNotFound)?;
-        env.add_template("_", chat_template)
-            .map_err(|e| ChatTemplateError::TemplateError(e))?;
-        let jinja_tmpl = env
-            .get_template("_")
-            .map_err(|e| ChatTemplateError::TemplateError(e))?;
-        jinja_tmpl
+            .ok_or_else(|| {
+                InferError::TemplateError(minijinja::ErrorKind::TemplateNotFound.into())
+            })?;
+        env.add_template("_", chat_template)?;
+        env.get_template("_")?
             .render(chat)
-            .map_err(|e| ChatTemplateError::TemplateError(e))
+            .map_err(InferError::TemplateError)
     }
 
     /// Add a new request to the queue and return a InferResponse
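Note: the bare ? operators in the new body compile because InferError gains a TemplateError(#[from] minijinja::Error) variant later in this diff, and #[from] derives the From conversion that ? relies on. Below is a minimal standalone sketch of that pattern, assuming the thiserror and minijinja crates; AppError and render_greeting are hypothetical names, not part of this commit.

use thiserror::Error;

#[derive(Debug, Error)]
pub enum AppError {
    #[error("Template error: {0}")]
    Template(#[from] minijinja::Error),
}

fn render_greeting(name: &str) -> Result<String, AppError> {
    let mut env = minijinja::Environment::new();
    // Any minijinja::Error is converted into AppError::Template by `?` via the derived From impl.
    env.add_template("_", "Hello, {{ name }}!")?;
    let out = env.get_template("_")?.render(minijinja::context! { name => name })?;
    Ok(out)
}

fn main() {
    println!("{}", render_greeting("world").unwrap());
}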
@@ -570,9 +566,9 @@ fn send_responses(
     let mut iterator = tokens_
         .ids
         .into_iter()
-        .zip(tokens_.logprobs.into_iter())
-        .zip(tokens_.texts.into_iter())
-        .zip(tokens_.is_special.into_iter())
+        .zip(tokens_.logprobs)
+        .zip(tokens_.texts)
+        .zip(tokens_.is_special)
         .enumerate()
         .peekable();
     while let Some((i, (((id, logprob), text), special))) = iterator.next() {
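Note: Iterator::zip accepts any IntoIterator, so the explicit .into_iter() calls inside zip were redundant; .enumerate() is what supplies the index i consumed by the while let. A self-contained sketch of the same iteration shape, using made-up token data rather than the router's Tokens struct:

fn main() {
    let ids: Vec<u32> = vec![5, 9, 2];
    let logprobs: Vec<f32> = vec![-0.1, -0.7, -0.3];
    let texts: Vec<&str> = vec!["Hello", " world", "!"];
    let is_special: Vec<bool> = vec![false, false, false];

    let mut iterator = ids
        .into_iter()
        .zip(logprobs) // same as .zip(logprobs.into_iter())
        .zip(texts)
        .zip(is_special)
        .enumerate()
        .peekable();

    while let Some((i, (((id, logprob), text), special))) = iterator.next() {
        // peek() shows whether another token follows
        let is_last = iterator.peek().is_none();
        println!("{i}: id={id} logprob={logprob} text={text:?} special={special} last={is_last}");
    }
}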
@@ -681,6 +677,8 @@ pub enum InferError {
     ValidationError(#[from] ValidationError),
     #[error("Incomplete generation")]
     IncompleteGeneration,
+    #[error("Template error: {0}")]
+    TemplateError(#[from] minijinja::Error),
 }
 
 impl InferError {
@@ -690,23 +688,7 @@ impl InferError {
             InferError::Overloaded(_) => "overloaded",
             InferError::ValidationError(_) => "validation",
             InferError::IncompleteGeneration => "incomplete_generation",
+            InferError::TemplateError(_) => "template_error",
         }
     }
 }
-
-#[derive(Debug, Error)]
-pub enum ChatTemplateError {
-    #[error("Template error: {0}")]
-    TemplateError(#[from] minijinja::Error),
-    #[error("Template not found")]
-    TemplateNotFound,
-}
-
-impl ChatTemplateError {
-    pub(crate) fn error_type(&self) -> &str {
-        match self {
-            ChatTemplateError::TemplateError(_) => "template_error",
-            ChatTemplateError::TemplateNotFound => "template_not_found",
-        }
-    }
-}
@@ -21,6 +21,7 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
+use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
@@ -343,7 +344,7 @@ async fn generate_stream(
         event.json_data(stream_token).unwrap()
     };
     let (headers, response_stream) =
-        generate_stream_internal(infer, Json(req.into()), on_message_callback).await;
+        generate_stream_internal(infer, Json(req), on_message_callback).await;
     let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
     (headers, sse)
 }
@@ -547,7 +548,7 @@ async fn generate_stream_internal(
         seed,
     )
 )]
-async fn chat(
+async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -557,7 +558,7 @@ async fn chat(
     let stream = req.stream;
     let max_new_tokens = match req.max_tokens {
         Some(max_new_tokens) => Some(max_new_tokens),
-        None => Some(100)
+        None => Some(100),
     };
 
     // apply chat template to flatten the request into a single input
@@ -604,8 +605,9 @@ async fn chat(
 
     // switch on stream
     if stream {
+        let stream_count = AtomicU32::new(0);
         // pass this callback to the stream generation and build the required event structure
-        let on_message_callback = |stream_token: StreamResponse| {
+        let on_message_callback = move |stream_token: StreamResponse| {
             let event = Event::default();
 
             let current_time = std::time::SystemTime::now()
@@ -613,11 +615,15 @@ async fn chat(
                 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                 .as_secs();
 
+            // increment the stream count
+            stream_count.fetch_add(1, Ordering::SeqCst);
+            let current_stream_count = stream_count.load(Ordering::SeqCst);
+
             event
                 .json_data(ChatCompletionChunk::new(
                     stream_token.token.text,
                     current_time,
-                    0,
+                    current_stream_count,
                 ))
                 .unwrap_or_else(|_| {
                     println!("Failed to serialize ChatCompletionChunk");
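Note: the callback is now a move closure so it takes ownership of stream_count, and fetch_add followed by load yields the running index attached to each chunk. A standalone sketch of the same counter pattern, using plain strings in place of StreamResponse and ChatCompletionChunk:

use std::sync::atomic::{AtomicU32, Ordering};

fn main() {
    let stream_count = AtomicU32::new(0);

    // `move` transfers ownership of the counter into the callback, as in the handler above.
    let on_message = move |token_text: &str| -> String {
        stream_count.fetch_add(1, Ordering::SeqCst);
        let index = stream_count.load(Ordering::SeqCst);
        format!("chunk {index}: {token_text}")
    };

    for tok in ["Hel", "lo", "!"] {
        println!("{}", on_message(tok));
    }
}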
@@ -843,7 +849,7 @@ pub async fn run(
         .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
-        .route("/v1/chat/completions", post(chat))
+        .route("/v1/chat/completions", post(chat_completions))
        // AWS Sagemaker route
         .route("/invocations", post(compat_generate))
        // Base Health route
@@ -973,6 +979,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
             InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
+            InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
         };
 
         (
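Note: the new TemplateError arm is what surfaces template failures to clients as HTTP 422 (UNPROCESSABLE_ENTITY). A simplified sketch of that conversion, assuming axum and serde as dependencies; MyError and the trimmed-down ErrorResponse here are illustrative stand-ins, not the router's types:

use axum::{http::StatusCode, Json};
use serde::Serialize;

#[derive(Serialize)]
struct ErrorResponse {
    error: String,
}

enum MyError {
    Template(String),
    IncompleteGeneration,
}

impl From<MyError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: MyError) -> Self {
        // Map each error kind to a status code, then wrap the message in a JSON body.
        let (status, message) = match err {
            MyError::Template(e) => (StatusCode::UNPROCESSABLE_ENTITY, format!("Template error: {e}")),
            MyError::IncompleteGeneration => {
                (StatusCode::INTERNAL_SERVER_ERROR, "Incomplete generation".to_string())
            }
        };
        (status, Json(ErrorResponse { error: message }))
    }
}

fn main() {
    // A handler returning Result<_, (StatusCode, Json<ErrorResponse>)> can use map_err(Into::into).
    let (status, Json(body)): (StatusCode, Json<ErrorResponse>) =
        MyError::Template("unclosed tag".into()).into();
    println!("{status} {}", body.error);
}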