diff --git a/router/src/server.rs b/router/src/server.rs
index 73b54321..94d89703 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -42,6 +42,7 @@ use hf_hub::{Cache, Repo, RepoType};
 use http::header::AUTHORIZATION;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use pyo3::types::IntoPyDict;
+use regex::Regex;
 use serde_json::Value;
 use std::convert::Infallible;
 use std::fs::File;
@@ -452,13 +453,27 @@ async fn generate_stream(
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
     let span = tracing::Span::current();
-    let on_message_callback = |stream_token: StreamResponse| {
-        let event = Event::default();
-        event.json_data(stream_token).unwrap()
-    };
     let (headers, response_stream) =
-        generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
-    let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+        generate_stream_internal(infer, compute_type, Json(req), span).await;
+
+    let final_response_stream = async_stream::stream! {
+        let mut response_stream = Box::pin(response_stream);
+        while let Some(raw_event) = response_stream.next().await {
+            match raw_event {
+                Ok(stream_token) => {
+                    let event = Event::default();
+                    let event = event.json_data(stream_token).unwrap();
+                    yield Ok(event);
+                }
+                // surface the error to the client rather than an empty event
+                Err(err) => {
+                    yield Ok(Event::from(err));
+                }
+            }
+        }
+    };
+
+    let sse = Sse::new(final_response_stream).keep_alive(KeepAlive::default());
     (headers, sse)
 }
 
@@ -466,9 +481,11 @@ async fn generate_stream_internal(
     infer: Infer,
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
-    on_message_callback: impl Fn(StreamResponse) -> Event,
     span: tracing::Span,
-) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
+) -> (
+    HeaderMap,
+    impl Stream<Item = Result<StreamResponse, InferError>>,
+) {
     let start_time = Instant::now();
     metrics::counter!("tgi_request_count").increment(1);
 
@@ -500,12 +517,12 @@ async fn generate_stream_internal(
             let err = InferError::from(ValidationError::BestOfStream);
             metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
             tracing::error!("{err}");
-            yield Ok(Event::from(err));
+            yield Err(err);
         } else if req.parameters.decoder_input_details {
             let err = InferError::from(ValidationError::PrefillDetailsStream);
             metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
             tracing::error!("{err}");
-            yield Ok(Event::from(err));
+            yield Err(err);
         } else {
             match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
@@ -535,8 +552,7 @@ async fn generate_stream_internal(
                             generated_text: None,
                             details: None,
                         };
-                        let event = on_message_callback(stream_token);
-                        yield Ok(event);
+                        yield Ok(stream_token);
                     }
                     // Yield event for last token and compute timings
                     InferStreamResponse::End {
@@ -600,9 +616,7 @@ async fn generate_stream_internal(
                             details
                         };
-
-                        let event = on_message_callback(stream_token);
-                        yield Ok(event);
+                        yield Ok(stream_token);
                         break;
                     }
                 }
@@ -610,7 +624,7 @@ async fn generate_stream_internal(
                 // yield error
                 Err(err) => {
                     error = true;
-                    yield Ok(Event::from(err));
+                    yield Err(err);
                     break;
                 }
             }
@@ -619,7 +633,7 @@ async fn generate_stream_internal(
             // yield error
             Err(err) => {
                 error = true;
-                yield Ok(Event::from(err));
+                yield Err(err);
             }
         }
         // Check if generation reached the end
@@ -628,7 +642,7 @@ async fn generate_stream_internal(
             let err = InferError::IncompleteGenerationStream;
             metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
             tracing::error!("{err}");
-            yield Ok(Event::from(err));
+            yield Err(err);
         }
     }
 };
@@ -771,75 +785,88 @@ async fn completions(
     // Create a future for each generate_stream_internal call.
     let generate_future = async move {
-        let on_message_callback = move |stream_token: StreamResponse| {
-            let event = Event::default();
-
-            let current_time = std::time::SystemTime::now()
-                .duration_since(std::time::UNIX_EPOCH)
-                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-                .as_secs();
-
-            let message = match stream_token.details {
-                Some(details) => {
-                    let completion_tokens = details.generated_tokens;
-                    let prompt_tokens = details.input_length;
-                    let total_tokens = prompt_tokens + completion_tokens;
-
-                    Completion::Final(CompletionFinal {
-                        id: String::new(),
-                        created: current_time,
-                        model: model_id.clone(),
-                        system_fingerprint: system_fingerprint.clone(),
-                        choices: vec![CompletionComplete {
-                            finish_reason: details.finish_reason.to_string(),
-                            index: index as u32,
-                            logprobs: None,
-                            text: stream_token.token.text,
-                        }],
-                        usage: Usage {
-                            prompt_tokens,
-                            completion_tokens,
-                            total_tokens,
-                        },
-                    })
-                }
-                None => Completion::Chunk(Chunk {
-                    id: String::new(),
-                    created: current_time,
-                    choices: vec![CompletionComplete {
-                        finish_reason: String::new(),
-                        index: index as u32,
-                        logprobs: None,
-                        text: stream_token.token.text,
-                    }],
-                    model: model_id.clone(),
-                    system_fingerprint: system_fingerprint.clone(),
-                }),
-            };
-
-            event
-                .json_data(message)
-                .unwrap_or_else(|_e| Event::default())
-        };
-
         let (header_tx, header_rx) = oneshot::channel();
         let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
 
         tokio::spawn(async move {
-            let (header_map, sse) = generate_stream_internal(
+            let (headers, response_stream) = generate_stream_internal(
                 infer_clone.clone(),
                 compute_type_clone.clone(),
                 Json(generate_request),
-                on_message_callback,
                 span_clone.clone(),
             )
             .await;
+
+            let final_response_stream = async_stream::stream! {
+                let mut response_stream = Box::pin(response_stream);
+
+                while let Some(stream_token) = response_stream.next().await {
+                    match stream_token {
+                        Ok(stream_token) => {
+                            let event = Event::default();
+
+                            let current_time = std::time::SystemTime::now()
+                                .duration_since(std::time::UNIX_EPOCH)
+                                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                                .as_secs();
+
+                            let message = match stream_token.details {
+                                Some(details) => {
+                                    let completion_tokens = details.generated_tokens;
+                                    let prompt_tokens = details.input_length;
+                                    let total_tokens = prompt_tokens + completion_tokens;
+
+                                    Completion::Final(CompletionFinal {
+                                        id: String::new(),
+                                        created: current_time,
+                                        model: model_id.clone(),
+                                        system_fingerprint: system_fingerprint.clone(),
+                                        choices: vec![CompletionComplete {
+                                            finish_reason: details.finish_reason.to_string(),
+                                            index: index as u32,
+                                            logprobs: None,
+                                            text: stream_token.token.text,
+                                        }],
+                                        usage: Usage {
+                                            prompt_tokens,
+                                            completion_tokens,
+                                            total_tokens,
+                                        },
+                                    })
+                                }
+                                None => Completion::Chunk(Chunk {
+                                    id: String::new(),
+                                    created: current_time,
+                                    choices: vec![CompletionComplete {
+                                        finish_reason: String::new(),
+                                        index: index as u32,
+                                        logprobs: None,
+                                        text: stream_token.token.text,
+                                    }],
+                                    model: model_id.clone(),
+                                    system_fingerprint: system_fingerprint.clone(),
+                                }),
+                            };
+
+                            let event = event
+                                .json_data(message)
+                                .unwrap_or_else(|_e| Event::default());
+
+                            yield Ok(event);
+                        }
+                        // surface the error to the client rather than an empty event
+                        Err(err) => {
+                            yield Ok(Event::from(err));
+                        }
+                    }
+                }
+            };
+
             // send and don't wait for response
-            let _ = header_tx.send(header_map);
+            let _ = header_tx.send(headers);
 
             // pin and emit messages to the sse_tx
-            let mut sse = Box::pin(sse);
+            let mut sse = Box::pin(final_response_stream);
             while let Some(event) = sse.next().await {
                 if sse_tx.send(event).is_err() {
                     tracing::error!("Failed to send event. Receiver dropped.");
Receiver dropped."); @@ -1072,6 +1099,84 @@ async fn completions( } } +enum StreamState { + Buffering, + BufferTrailing, + Content { skip_close_quote: bool }, +} + +/// Convert a StreamResponse into an Event to be sent over SSE +fn create_event_from_stream_token( + stream_token: &StreamResponse, + logprobs: bool, + stream_options: Option, + inner_using_tools: bool, + system_fingerprint: String, + model_id: String, +) -> Event { + let event = Event::default(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + + let logprobs = logprobs.then(|| { + ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone())) + }); + + // replace the content with the tool calls if grammar is present + let (content, tool_calls) = if inner_using_tools { + (None, Some(vec![stream_token.token.text.clone()])) + } else { + let content = if !stream_token.token.special { + Some(stream_token.token.text.clone()) + } else { + None + }; + + (content, None) + }; + + let (usage, finish_reason) = match &stream_token.details { + Some(details) => { + let usage = if stream_options + .as_ref() + .map(|s| s.include_usage) + .unwrap_or(false) + { + let completion_tokens = details.generated_tokens; + let prompt_tokens = details.input_length; + let total_tokens = prompt_tokens + completion_tokens; + Some(Usage { + completion_tokens, + prompt_tokens, + total_tokens, + }) + } else { + None + }; + (usage, Some(details.finish_reason.format(true))) + } + None => (None, None), + }; + + let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + content, + tool_calls, + current_time, + logprobs, + finish_reason, + usage, + )); + + event.json_data(chat_complete).unwrap_or_else(|e| { + println!("Failed to serialize ChatCompletionChunk: {:?}", e); + Event::default() + }) +} + /// Generate tokens #[utoipa::path( post, @@ -1128,90 +1233,160 @@ async fn chat_completions( // static values that will be returned in all cases let model_id = info.model_id.clone(); let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); + let send_function_name = false; // TODO: fix to send function name // switch on stream if stream { - // pass this callback to the stream generation and build the required event structure - let on_message_callback = move |stream_token: StreamResponse| { - let event = Event::default(); + let (headers, response_stream) = + generate_stream_internal(infer, compute_type, Json(generate_request), span).await; - let current_time = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_else(|_| std::time::Duration::from_secs(0)) - .as_secs(); - - let logprobs = logprobs.then(|| { - ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens)) - }); - - // replace the content with the tool calls if grammar is present - let (content, tool_calls) = if using_tools { - (None, Some(vec![stream_token.token.text])) + let final_response_stream = async_stream::stream! 
+            let mut response_stream = Box::pin(response_stream);
+            let mut buffer = Vec::new();
+            let mut json_buffer = String::new();
+            // let mut content_buffer = String::new();
+            let mut state = if using_tools {
+                StreamState::Buffering
             } else {
-                let content = if !stream_token.token.special {
-                    Some(stream_token.token.text)
-                } else {
-                    None
-                };
-
-                (content, None)
-            };
-
-            let (usage, finish_reason) = match stream_token.details {
-                Some(details) => {
-                    let usage = if stream_options
-                        .as_ref()
-                        .map(|s| s.include_usage)
-                        .unwrap_or(false)
-                    {
-                        let completion_tokens = details.generated_tokens;
-                        let prompt_tokens = details.input_length;
-                        let total_tokens = prompt_tokens + completion_tokens;
-                        Some(Usage {
-                            completion_tokens,
-                            prompt_tokens,
-                            total_tokens,
-                        })
-                    } else {
-                        None
-                    };
-                    (usage, Some(details.finish_reason.format(true)))
+                StreamState::Content {
+                    skip_close_quote: false,
                 }
-                None => (None, None),
             };
 
-            event
-                .json_data(CompletionType::ChatCompletionChunk(
-                    ChatCompletionChunk::new(
-                        model_id.clone(),
-                        system_fingerprint.clone(),
-                        content,
-                        tool_calls,
-                        current_time,
-                        logprobs,
-                        finish_reason,
-                        usage,
-                    ),
-                ))
-                .unwrap_or_else(|e| {
-                    println!("Failed to serialize ChatCompletionChunk: {:?}", e);
-                    Event::default()
-                })
+            let mut response_as_tool = using_tools;
+
+            // Regex to match any function name
+            let function_regex = Regex::new(r#"\{"function":\{"_name":"([^"]+)""#).unwrap();
+
+            while let Some(result) = response_stream.next().await {
+                match result {
+                    Ok(stream_token) => {
+                        let token_text = &stream_token.token.text.clone();
+
+                        match state {
+                            StreamState::Buffering => {
+                                json_buffer.push_str(&token_text.replace(" ", ""));
+                                buffer.push(stream_token);
+
+                                if let Some(captures) = function_regex.captures(&json_buffer) {
+                                    let function_name = captures[1].to_string();
+                                    if function_name == "notify_error" {
+                                        state = StreamState::BufferTrailing;
+                                        response_as_tool = false;
+                                        buffer.clear();
+                                        json_buffer.clear();
+                                    } else {
+                                        state = StreamState::Content {
+                                            skip_close_quote: false,
+                                        };
+
+                                        if send_function_name {
+                                            // send a message with the function name
+                                            let event = Event::default();
+                                            let current_time = std::time::SystemTime::now()
+                                                .duration_since(std::time::UNIX_EPOCH)
+                                                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                                                .as_secs();
+
+                                            let chat_complete = CompletionType::ChatCompletionChunk(
+                                                ChatCompletionChunk::new(
+                                                    model_id.clone(),
+                                                    system_fingerprint.clone(),
+                                                    None,
+                                                    Some(vec![function_name.clone()]),
+                                                    current_time,
+                                                    None,
+                                                    None,
+                                                    None,
+                                                ),
+                                            );
+
+                                            let event = event.json_data(chat_complete).unwrap();
+                                            yield Ok(event);
+                                        }
+
+                                        // send all the buffered messages
+                                        for stream_token in &buffer {
+                                            let event = create_event_from_stream_token(
+                                                stream_token,
+                                                logprobs,
+                                                stream_options.clone(),
+                                                response_as_tool,
+                                                system_fingerprint.clone(),
+                                                model_id.clone(),
+                                            );
+                                            yield Ok::<Event, Infallible>(event);
+                                        }
+                                    }
+                                }
+                            }
+                            // if we skipped sending the buffer we need to avoid sending the following json key and quotes
+                            StreamState::BufferTrailing => {
+                                let infix_text = "\"error\":\"";
+                                json_buffer.push_str(&token_text.replace(" ", ""));
+                                if !json_buffer.contains(infix_text) {
+                                    continue;
+                                }
+
+                                // drop everything up to and including the "error":" key
+                                let error_index = json_buffer.find(infix_text).unwrap();
+                                json_buffer =
+                                    json_buffer[error_index + infix_text.len()..].to_string();
+
+                                // if text remains after the infix, send it as content
+                                if !json_buffer.is_empty() {
+                                    let event = Event::default();
+                                    let current_time = std::time::SystemTime::now()
+                                        .duration_since(std::time::UNIX_EPOCH)
+                                        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                                        .as_secs();
+
+                                    let chat_complete =
+                                        CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
+                                            model_id.clone(),
+                                            system_fingerprint.clone(),
+                                            Some(json_buffer.clone()),
+                                            None,
+                                            current_time,
+                                            None,
+                                            None,
+                                            None,
+                                        ));
+
+                                    let event = event.json_data(chat_complete).unwrap();
+                                    yield Ok(event);
+                                }
+
+                                state = StreamState::Content {
+                                    skip_close_quote: true,
+                                };
+                            }
+                            StreamState::Content { skip_close_quote } => {
+                                if skip_close_quote && token_text.contains('"') {
+                                    break;
+                                }
+
+                                // send the content
+                                let event = create_event_from_stream_token(
+                                    &stream_token,
+                                    logprobs,
+                                    stream_options.clone(),
+                                    response_as_tool,
+                                    system_fingerprint.clone(),
+                                    model_id.clone(),
+                                );
+
+                                yield Ok::<Event, Infallible>(event);
+                            }
+                        }
+                    }
+                    // surface the error to the client rather than an empty event
+                    Err(err) => {
+                        yield Ok::<Event, Infallible>(Event::from(err));
+                        break;
+                    }
+                }
+            }
+            yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
+        };
 
-        let (headers, response_stream) = generate_stream_internal(
-            infer,
-            compute_type,
-            Json(generate_request),
-            on_message_callback,
-            span,
-        )
-        .await;
-
-        let response_stream = response_stream.chain(futures::stream::once(async {
-            Ok(Event::default().data("[DONE]"))
-        }));
-
-        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+        let sse = Sse::new(final_response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
         let (headers, Json(generation)) =
@@ -1246,17 +1421,33 @@
             if let Value::Object(ref mut props) = arguments {
                 props.remove("_name");
             }
-
-            let tool_calls = vec![ToolCall {
-                id: "0".to_string(),
-                r#type: "function".to_string(),
-                function: FunctionDefinition {
-                    description: None,
-                    name,
-                    arguments,
-                },
-            }];
-            (Some(tool_calls), None)
+            match name.as_str() {
+                "notify_error" => {
+                    // parse the error message
+                    let error_message = arguments
+                        .get("error")
+                        .and_then(Value::as_str)
+                        .ok_or_else(|| {
+                            InferError::ToolError(
+                                "No error message found in generated text".to_string(),
+                            )
+                        })?
+                        .to_string();
+                    (None, Some(error_message))
+                }
+                _ => {
+                    let tool_calls = vec![ToolCall {
+                        id: "0".to_string(),
+                        r#type: "function".to_string(),
+                        function: FunctionDefinition {
+                            description: None,
+                            name,
+                            arguments,
+                        },
+                    }];
+                    (Some(tool_calls), None)
+                }
+            }
         } else {
             (None, Some(generation.generated_text))
         };
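
Note on the refactor: `generate_stream_internal` no longer takes an `on_message_callback` and no longer emits ready-made SSE events; it now yields `Result<StreamResponse, InferError>`, and each caller (`generate_stream`, `completions`, `chat_completions`) decides at the edge how to turn tokens and errors into `Event`s. A minimal, self-contained sketch of that pattern with toy `Token`/`Error` types (`tokio` and `futures` assumed as dependencies; an illustration, not TGI code):

```rust
use futures::{stream, StreamExt};

#[derive(Debug)]
struct Token(&'static str);
#[derive(Debug)]
struct Error(&'static str);

#[tokio::main]
async fn main() {
    // stand-in for the stream returned by generate_stream_internal
    let inner = stream::iter(vec![
        Ok(Token("Hello")),
        Ok(Token(" world")),
        Err(Error("generation failed")),
    ]);

    // consumer-side mapping, as each HTTP handler in the diff now does
    let events = inner.map(|item| match item {
        Ok(token) => format!("data: {token:?}"),
        Err(err) => format!("event: error\ndata: {err:?}"),
    });

    events.for_each(|e| async move { println!("{e}") }).await;
}
```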
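The `StreamState::Buffering` arm matches against a whitespace-stripped concatenation of the generated tokens, so `function_regex` fires as soon as the grammar-constrained prefix `{"function":{"_name":"..."` is complete, however tokenization happens to split it. A hypothetical, runnable illustration (the token split below is invented for the example):

```rust
use regex::Regex;

fn main() {
    // same pattern as in the diff
    let function_regex = Regex::new(r#"\{"function":\{"_name":"([^"]+)""#).unwrap();

    let mut json_buffer = String::new();
    // an invented tokenization of a tool-call prefix
    for token in [r#"{""#, "function", r#"":{"_"#, r#"name":""#, r#"notify_error""#] {
        json_buffer.push_str(&token.replace(' ', ""));
        match function_regex.captures(&json_buffer) {
            // "notify_error" flips the stream into BufferTrailing; any other
            // name flushes the buffered tokens as tool-call content
            Some(captures) => println!("matched function name: {}", &captures[1]),
            None => println!("no match yet, keep buffering: {json_buffer}"),
        }
    }
}
```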
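In the non-streaming branch, `notify_error` is special-cased: rather than surfacing a tool call, the handler extracts the `error` string from the generated arguments and returns it as plain content (or an `InferError::ToolError` if the field is missing). A small sketch of just that extraction with `serde_json`, with the error type simplified to `String`:

```rust
use serde_json::{json, Value};

// mirrors the .get("error").and_then(Value::as_str) chain in the diff
fn extract_error_message(arguments: &Value) -> Result<String, String> {
    arguments
        .get("error")
        .and_then(Value::as_str)
        .map(str::to_string)
        .ok_or_else(|| "No error message found in generated text".to_string())
}

fn main() {
    let arguments = json!({ "error": "Tool input failed validation" });
    assert_eq!(
        extract_error_message(&arguments),
        Ok("Tool input failed validation".to_string())
    );
}
```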