diff --git a/router/src/lib.rs b/router/src/lib.rs index 53033d5d..05d5e5a2 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -469,6 +469,7 @@ pub(crate) struct StreamDetails { #[derive(Serialize, ToSchema)] pub(crate) struct StreamResponse { + pub index: u32, pub token: Token, #[serde(skip_serializing_if = "Vec::is_empty")] pub top_tokens: Vec, diff --git a/router/src/server.rs b/router/src/server.rs index 393c3ae0..47880678 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -338,7 +338,7 @@ async fn generate_stream( HeaderMap, Sse>>, ) { - let on_message_callback = |_: u32, stream_token: StreamResponse| { + let on_message_callback = |stream_token: StreamResponse| { let event = Event::default(); event.json_data(stream_token).unwrap() }; @@ -352,7 +352,7 @@ async fn generate_stream( async fn generate_stream_internal( infer: Infer, Json(req): Json, - on_message_callback: impl Fn(u32, StreamResponse) -> Event, + on_message_callback: impl Fn(StreamResponse) -> Event, ) -> (HeaderMap, impl Stream>) { let span = tracing::Span::current(); let start_time = Instant::now(); @@ -414,12 +414,13 @@ async fn generate_stream_internal( // StreamResponse let stream_token = StreamResponse { + index, token, top_tokens, generated_text: None, details: None, }; - let event = on_message_callback(index, stream_token); + let event = on_message_callback(stream_token); yield Ok(event); } // Yield event for last token and compute timings @@ -476,6 +477,7 @@ async fn generate_stream_internal( tracing::info!(parent: &span, "Success"); let stream_token = StreamResponse { + index, token, top_tokens, generated_text: Some(output_text), @@ -483,7 +485,7 @@ async fn generate_stream_internal( }; - let event = on_message_callback(index, stream_token); + let event = on_message_callback(stream_token); yield Ok(event); break; } @@ -607,13 +609,10 @@ async fn chat_completions( // switch on stream if stream { let model_id = info.model_id.clone(); - let system_fingerprint = format!( - "{}-{}", - info.version, - info.docker_label.unwrap_or("native") - ); + let system_fingerprint = + format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); // pass this callback to the stream generation and build the required event structure - let on_message_callback = move |index: u32, stream_token: StreamResponse| { + let on_message_callback = move |stream_token: StreamResponse| { let event = Event::default(); let current_time = std::time::SystemTime::now() @@ -627,7 +626,7 @@ async fn chat_completions( system_fingerprint.clone(), stream_token.token.text, current_time, - index, + stream_token.index, None, None, ))