From f82ff3f64a5d485ea87aa1b80cc8b964cd4136c9 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 9 Jan 2024 11:54:20 -0500
Subject: [PATCH] fix: adds index, model id, system fingerprint and updates
 do_sample param

---
 router/src/lib.rs    | 20 ++++++++++++++------
 router/src/server.rs | 32 +++++++++++++++++++-------------
 2 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index 411df519..53033d5d 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -158,7 +158,7 @@ fn default_parameters() -> GenerateParameters {
         top_k: None,
         top_p: None,
         typical_p: None,
-        do_sample: false,
+        do_sample: true,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
         stop: Vec::new(),
@@ -253,21 +253,29 @@ pub(crate) struct ChatCompletionDelta {
 }
 
 impl ChatCompletionChunk {
-    pub(crate) fn new(delta: String, created: u64, index: u32) -> Self {
+    pub(crate) fn new(
+        model: String,
+        system_fingerprint: String,
+        delta: String,
+        created: u64,
+        index: u32,
+        logprobs: Option<Vec<f32>>,
+        finish_reason: Option<String>,
+    ) -> Self {
         Self {
             id: "".to_string(),
             object: "text_completion".to_string(),
             created,
-            model: "".to_string(),
-            system_fingerprint: "".to_string(),
+            model,
+            system_fingerprint,
             choices: vec![ChatCompletionChoice {
                 index,
                 delta: ChatCompletionDelta {
                     role: "assistant".to_string(),
                     content: delta,
                 },
-                logprobs: None,
-                finish_reason: None,
+                logprobs,
+                finish_reason,
             }],
         }
     }
diff --git a/router/src/server.rs b/router/src/server.rs
index 323bf35f..393c3ae0 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -21,7 +21,6 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
-use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
@@ -339,7 +338,7 @@ async fn generate_stream(
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
-    let on_message_callback = |stream_token: StreamResponse| {
+    let on_message_callback = |_: u32, stream_token: StreamResponse| {
         let event = Event::default();
         event.json_data(stream_token).unwrap()
     };
@@ -353,7 +352,7 @@ async fn generate_stream(
 async fn generate_stream_internal(
     infer: Infer,
     Json(req): Json<GenerateRequest>,
-    on_message_callback: impl Fn(StreamResponse) -> Event,
+    on_message_callback: impl Fn(u32, StreamResponse) -> Event,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
     let span = tracing::Span::current();
     let start_time = Instant::now();
@@ -397,8 +396,10 @@ async fn generate_stream_internal(
             match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
                 Ok((_permit, mut response_stream)) => {
+                    let mut index = 0;
                     // Server-Sent Event stream
                     while let Some(response) = response_stream.next().await {
+                        index += 1;
                         match response {
                             Ok(response) => {
                                 match response {
@@ -418,8 +419,7 @@ async fn generate_stream_internal(
                                             generated_text: None,
                                             details: None,
                                         };
-
-                                        let event = on_message_callback(stream_token);
+                                        let event = on_message_callback(index, stream_token);
                                         yield Ok(event);
                                     }
                                     // Yield event for last token and compute timings
@@ -483,7 +483,7 @@ async fn generate_stream_internal(
                                         };
 
-                                        let event = on_message_callback(stream_token);
+                                        let event = on_message_callback(index, stream_token);
                                         yield Ok(event);
                                         break;
                                     }
                                 }
@@ -550,6 +550,7 @@ async fn generate_stream_internal(
 )]
 async fn chat_completions(
     Extension(infer): Extension<Infer>,
+    Extension(info): Extension<Info>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     metrics::increment_counter!("tgi_request_count");
@@ -605,9 +606,14 @@ async fn chat_completions(
 
     // switch on stream
     if stream {
-        let stream_count = AtomicU32::new(0);
+        let model_id = info.model_id.clone();
+        let system_fingerprint = format!(
+            "{}-{}",
+            info.version,
+            info.docker_label.unwrap_or("native")
+        );
         // pass this callback to the stream generation and build the required event structure
-        let on_message_callback = move |stream_token: StreamResponse| {
+        let on_message_callback = move |index: u32, stream_token: StreamResponse| {
             let event = Event::default();
 
             let current_time = std::time::SystemTime::now()
@@ -615,15 +621,15 @@ async fn chat_completions(
                 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                 .as_secs();
 
-            // increment the stream count
-            stream_count.fetch_add(1, Ordering::SeqCst);
-            let current_stream_count = stream_count.load(Ordering::SeqCst);
-
             event
                 .json_data(ChatCompletionChunk::new(
+                    model_id.clone(),
+                    system_fingerprint.clone(),
                     stream_token.token.text,
                     current_time,
-                    current_stream_count,
+                    index,
+                    None,
+                    None,
                 ))
                 .unwrap_or_else(|_| {
                     println!("Failed to serialize ChatCompletionChunk");
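
Note (illustrative, not part of the patch): a minimal sketch of a call to the
widened ChatCompletionChunk::new, mirroring the call site in chat_completions
above. Every value below is a placeholder; in the server, model comes from
info.model_id, the fingerprint from format!("{}-{}", info.version,
info.docker_label.unwrap_or("native")), and index from the per-stream counter
that replaced the AtomicU32.

    // All values hypothetical; the shapes match the new constructor above.
    let chunk = ChatCompletionChunk::new(
        "my-model-id".to_string(),   // model: placeholder for info.model_id
        "1.3.4-native".to_string(),  // system_fingerprint: "{version}-{docker_label}"
        "Hello".to_string(),         // delta: token text for this event
        1704818060,                  // created: unix seconds (placeholder)
        1,                           // index: the counter is 1 for the first token
        None,                        // logprobs: not populated by this patch
        None,                        // finish_reason: None for intermediate chunks
    );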