fix: remove ChatTemplateError and add index to stream messages

drbh 2024-01-08 08:52:01 -05:00
parent 3ae9cd655d
commit ddf7412a6b
2 changed files with 27 additions and 38 deletions

@@ -140,24 +140,20 @@ impl Infer {
/// Apply the chat template to the chat request
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(
&self,
chat: ChatRequest,
) -> Result<String, ChatTemplateError> {
pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
let mut env = minijinja::Environment::new();
let chat_template = self
.tokenizer_config
.chat_template
.as_ref()
.ok_or(ChatTemplateError::TemplateNotFound)?;
env.add_template("_", chat_template)
.map_err(|e| ChatTemplateError::TemplateError(e))?;
let jinja_tmpl = env
.get_template("_")
.map_err(|e| ChatTemplateError::TemplateError(e))?;
jinja_tmpl
.ok_or_else(|| {
InferError::TemplateError(minijinja::ErrorKind::TemplateNotFound.into())
})?;
env.add_template("_", chat_template)?;
env.get_template("_")?
.render(chat)
.map_err(|e| ChatTemplateError::TemplateError(e))
.map_err(InferError::TemplateError)
}
/// Add a new request to the queue and return an InferResponse
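For context, here is a minimal, self-contained sketch of the minijinja flow the rewritten method relies on; the Message struct and the template string are illustrative stand-ins, not the server's actual chat template:

    use minijinja::{context, Environment};
    use serde::Serialize;

    #[derive(Serialize)]
    struct Message {
        role: String,
        content: String,
    }

    fn main() -> Result<(), minijinja::Error> {
        let mut env = Environment::new();
        // add_template parses the template eagerly, so a malformed chat
        // template fails here rather than at render time
        env.add_template(
            "_",
            "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}",
        )?;
        let rendered = env.get_template("_")?.render(context! {
            messages => vec![Message {
                role: "user".into(),
                content: "Hello!".into(),
            }],
        })?;
        print!("{rendered}");
        Ok(())
    }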
@@ -570,9 +566,9 @@ fn send_responses(
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
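The dropped `.into_iter()` calls were redundant because `Iterator::zip` is generic over `IntoIterator`; a quick standalone check:

    fn main() {
        let ids = vec![1u32, 2, 3];
        let logprobs = vec![-0.1f32, -0.25, -0.4];
        // Passing the Vec directly compiles: zip calls IntoIterator::into_iter
        // on its argument, exactly as the explicit form did
        for (id, logprob) in ids.into_iter().zip(logprobs) {
            println!("{id}: {logprob}");
        }
    }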
@@ -681,6 +677,8 @@ pub enum InferError {
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
#[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error),
}
impl InferError {
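The `#[from]` attribute on the new variant is what lets the map_err boilerplate above disappear: thiserror derives a `From<minijinja::Error>` impl, so `?` performs the conversion on its own. A reduced sketch, with InferError trimmed to the one relevant variant:

    use thiserror::Error;

    #[derive(Debug, Error)]
    pub enum InferError {
        #[error("Template error: {0}")]
        TemplateError(#[from] minijinja::Error),
    }

    fn render(tmpl: &str) -> Result<String, InferError> {
        let mut env = minijinja::Environment::new();
        // `?` routes minijinja::Error through the derived From impl
        env.add_template("_", tmpl)?;
        Ok(env.get_template("_")?.render(minijinja::context! {})?)
    }

    fn main() {
        // An unclosed expression fails to parse, surfacing as InferError
        assert!(matches!(render("{{ 1 +"), Err(InferError::TemplateError(_))));
        assert_eq!(render("ok").unwrap(), "ok");
    }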
@@ -690,23 +688,7 @@ impl InferError {
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
}
}
}
#[derive(Debug, Error)]
pub enum ChatTemplateError {
#[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error),
#[error("Template not found")]
TemplateNotFound,
}
impl ChatTemplateError {
pub(crate) fn error_type(&self) -> &str {
match self {
ChatTemplateError::TemplateError(_) => "template_error",
ChatTemplateError::TemplateNotFound => "template_not_found",
InferError::TemplateError(_) => "template_error",
}
}
}

@@ -21,6 +21,7 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use text_generation_client::{ShardInfo, ShardedClient};
use tokenizers::Tokenizer;
@@ -343,7 +344,7 @@ async fn generate_stream(
event.json_data(stream_token).unwrap()
};
let (headers, response_stream) =
generate_stream_internal(infer, Json(req.into()), on_message_callback).await;
generate_stream_internal(infer, Json(req), on_message_callback).await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
(headers, sse)
}
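Both routes now funnel through generate_stream_internal and differ only in the callback that shapes each event. A simplified, self-contained illustration of that pattern (the types here are stand-ins, not the server's):

    // One internal producer, two presentations of the same token stream
    fn stream_internal(on_message: impl Fn(&str) -> String) -> Vec<String> {
        ["Hello", " world"].into_iter().map(on_message).collect()
    }

    fn main() {
        // /generate_stream: forward the token as-is
        let raw = stream_internal(|t| t.to_string());
        // /v1/chat/completions: wrap the token in a chunk-shaped payload
        let chat = stream_internal(|t| format!(r#"{{"content":"{t}"}}"#));
        println!("{raw:?}\n{chat:?}");
    }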
@@ -547,7 +548,7 @@ async fn generate_stream_internal(
seed,
)
)]
async fn chat(
async fn chat_completions(
Extension(infer): Extension<Infer>,
Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -557,7 +558,7 @@ async fn chat(
let stream = req.stream;
let max_new_tokens = match req.max_tokens {
Some(max_new_tokens) => Some(max_new_tokens),
None => Some(100)
None => Some(100),
};
// apply chat template to flatten the request into a single input
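As an aside, the match that supplies the default is equivalent to `Option::or`, should the more compact style ever be preferred; a quick check:

    fn main() {
        let max_tokens: Option<u32> = None;
        assert_eq!(max_tokens.or(Some(100)), Some(100));
        assert_eq!(Some(42).or(Some(100)), Some(42));
    }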
@@ -604,8 +605,9 @@ async fn chat(
// switch on stream
if stream {
let stream_count = AtomicU32::new(0);
// pass this callback to the stream generation and build the required event structure
let on_message_callback = |stream_token: StreamResponse| {
let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default();
let current_time = std::time::SystemTime::now()
@@ -613,11 +615,15 @@
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
// increment the stream count
stream_count.fetch_add(1, Ordering::SeqCst);
let current_stream_count = stream_count.load(Ordering::SeqCst);
event
.json_data(ChatCompletionChunk::new(
stream_token.token.text,
current_time,
0,
current_stream_count,
))
.unwrap_or_else(|_| {
println!("Failed to serialize ChatCompletionChunk");
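One note on the counter: `fetch_add` already returns the value before the increment, so the increment and the read can be a single atomic operation, which also avoids a read-after-write gap if the callback were ever invoked concurrently:

    use std::sync::atomic::{AtomicU32, Ordering};

    fn main() {
        let stream_count = AtomicU32::new(0);
        // fetch_add returns the previous value; +1 yields the post-increment
        // count from the same atomic operation, with no separate load()
        let current_stream_count = stream_count.fetch_add(1, Ordering::SeqCst) + 1;
        assert_eq!(current_stream_count, 1);
        assert_eq!(stream_count.load(Ordering::SeqCst), 1);
    }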
@@ -843,7 +849,7 @@ pub async fn run(
.route("/info", get(get_model_info))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat))
.route("/v1/chat/completions", post(chat_completions))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
// Base Health route
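A hypothetical client call against the renamed route, assuming the router's default port 3000 and the OpenAI-style body shape the endpoint mirrors; the field names are assumptions, not read off this diff:

    use serde_json::json;

    #[tokio::main]
    async fn main() -> Result<(), reqwest::Error> {
        // Body shape is assumed OpenAI-compatible (messages, max_tokens, stream)
        let body = json!({
            "messages": [{ "role": "user", "content": "Hello!" }],
            "max_tokens": 100,
            "stream": false
        });
        let resp = reqwest::Client::new()
            .post("http://localhost:3000/v1/chat/completions")
            .json(&body)
            .send()
            .await?;
        println!("{}", resp.text().await?);
        Ok(())
    }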
@@ -973,6 +979,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
};
(