fix: remove ChatTemplateError and add index to stream messages

drbh 2024-01-08 08:52:01 -05:00
parent 3ae9cd655d
commit ddf7412a6b
2 changed files with 27 additions and 38 deletions

View File

@@ -140,24 +140,20 @@ impl Infer {
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
-    pub(crate) fn apply_chat_template(
-        &self,
-        chat: ChatRequest,
-    ) -> Result<String, ChatTemplateError> {
+    pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
         let mut env = minijinja::Environment::new();
         let chat_template = self
             .tokenizer_config
             .chat_template
             .as_ref()
-            .ok_or(ChatTemplateError::TemplateNotFound)?;
-        env.add_template("_", chat_template)
-            .map_err(|e| ChatTemplateError::TemplateError(e))?;
-        let jinja_tmpl = env
-            .get_template("_")
-            .map_err(|e| ChatTemplateError::TemplateError(e))?;
-        jinja_tmpl
+            .ok_or_else(|| {
+                InferError::TemplateError(minijinja::ErrorKind::TemplateNotFound.into())
+            })?;
+
+        env.add_template("_", chat_template)?;
+        env.get_template("_")?
             .render(chat)
-            .map_err(|e| ChatTemplateError::TemplateError(e))
+            .map_err(InferError::TemplateError)
     }

     /// Add a new request to the queue and return a InferResponse
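Note: the rewritten apply_chat_template leans on a single thiserror enum with #[from] minijinja::Error, so every template call can use ? and the missing-template case folds into the same variant via minijinja::ErrorKind. A self-contained sketch of that pattern (TemplateFailure and render_greeting are illustrative names, not repository code; thiserror and minijinja assumed as dependencies):

use thiserror::Error;

#[derive(Debug, Error)]
enum TemplateFailure {
    #[error("Template error: {0}")]
    Template(#[from] minijinja::Error),
}

fn render_greeting(template: Option<&str>) -> Result<String, TemplateFailure> {
    let mut env = minijinja::Environment::new();
    // A missing template becomes the same error variant as a render failure.
    let source = template.ok_or_else(|| {
        TemplateFailure::Template(minijinja::ErrorKind::TemplateNotFound.into())
    })?;
    env.add_template("_", source)?; // `?` converts via the #[from] impl
    env.get_template("_")?
        .render(minijinja::context! { name => "world" })
        .map_err(TemplateFailure::Template)
}

fn main() {
    println!("{:?}", render_greeting(Some("Hello {{ name }}!")));
    println!("{:?}", render_greeting(None));
}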
@@ -570,9 +566,9 @@ fn send_responses(
     let mut iterator = tokens_
         .ids
         .into_iter()
-        .zip(tokens_.logprobs.into_iter())
-        .zip(tokens_.texts.into_iter())
-        .zip(tokens_.is_special.into_iter())
+        .zip(tokens_.logprobs)
+        .zip(tokens_.texts)
+        .zip(tokens_.is_special)
         .enumerate()
         .peekable();
     while let Some((i, (((id, logprob), text), special))) = iterator.next() {
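Aside: the .zip cleanup works because Iterator::zip accepts any IntoIterator, so the vectors can be passed directly. A minimal standalone illustration (toy data, not the send_responses types):

fn main() {
    let ids = vec![42u32, 7];
    let logprobs = vec![-0.25f32, -1.5];
    // Both forms produce the same pairs; the explicit .into_iter() is redundant.
    let explicit: Vec<_> = ids.clone().into_iter().zip(logprobs.clone().into_iter()).collect();
    let idiomatic: Vec<_> = ids.into_iter().zip(logprobs).collect();
    assert_eq!(explicit, idiomatic);
    println!("{idiomatic:?}");
}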
@@ -681,6 +677,8 @@ pub enum InferError {
     ValidationError(#[from] ValidationError),
     #[error("Incomplete generation")]
     IncompleteGeneration,
+    #[error("Template error: {0}")]
+    TemplateError(#[from] minijinja::Error),
 }

 impl InferError {
@@ -690,23 +688,7 @@ impl InferError {
             InferError::Overloaded(_) => "overloaded",
             InferError::ValidationError(_) => "validation",
             InferError::IncompleteGeneration => "incomplete_generation",
+            InferError::TemplateError(_) => "template_error",
         }
     }
 }
-
-#[derive(Debug, Error)]
-pub enum ChatTemplateError {
-    #[error("Template error: {0}")]
-    TemplateError(#[from] minijinja::Error),
-    #[error("Template not found")]
-    TemplateNotFound,
-}
-
-impl ChatTemplateError {
-    pub(crate) fn error_type(&self) -> &str {
-        match self {
-            ChatTemplateError::TemplateError(_) => "template_error",
-            ChatTemplateError::TemplateNotFound => "template_not_found",
-        }
-    }
-}

View File

@@ -21,6 +21,7 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
+use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
@@ -343,7 +344,7 @@ async fn generate_stream(
         event.json_data(stream_token).unwrap()
     };
     let (headers, response_stream) =
-        generate_stream_internal(infer, Json(req.into()), on_message_callback).await;
+        generate_stream_internal(infer, Json(req), on_message_callback).await;
     let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
     (headers, sse)
 }
@@ -547,7 +548,7 @@ async fn generate_stream_internal(
         seed,
     )
 )]
-async fn chat(
+async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -557,7 +558,7 @@ async fn chat(
     let stream = req.stream;
     let max_new_tokens = match req.max_tokens {
         Some(max_new_tokens) => Some(max_new_tokens),
-        None => Some(100)
+        None => Some(100),
     };

     // apply chat template to flatten the request into a single input
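For reference, the max_new_tokens match is the long form of Option::or. A tiny sketch of the equivalence (u32 chosen for illustration; the actual request field type may differ):

fn main() {
    let requested: Option<u32> = None;
    // Same behavior as the match above: keep the caller's value,
    // otherwise fall back to 100 new tokens.
    let max_new_tokens = requested.or(Some(100));
    assert_eq!(max_new_tokens, Some(100));
}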
@@ -604,8 +605,9 @@ async fn chat(
     // switch on stream
     if stream {
+        let stream_count = AtomicU32::new(0);
         // pass this callback to the stream generation and build the required event structure
-        let on_message_callback = |stream_token: StreamResponse| {
+        let on_message_callback = move |stream_token: StreamResponse| {
             let event = Event::default();
             let current_time = std::time::SystemTime::now()
@@ -613,11 +615,15 @@ async fn chat(
                 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                 .as_secs();

+            // increment the stream count
+            stream_count.fetch_add(1, Ordering::SeqCst);
+            let current_stream_count = stream_count.load(Ordering::SeqCst);
+
             event
                 .json_data(ChatCompletionChunk::new(
                     stream_token.token.text,
                     current_time,
-                    0,
+                    current_stream_count,
                 ))
                 .unwrap_or_else(|_| {
                     println!("Failed to serialize ChatCompletionChunk");
@@ -843,7 +849,7 @@ pub async fn run(
         .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
-        .route("/v1/chat/completions", post(chat))
+        .route("/v1/chat/completions", post(chat_completions))
         // AWS Sagemaker route
         .route("/invocations", post(compat_generate))
         // Base Health route
@ -973,6 +979,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS, InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
}; };
( (