diff --git a/router/src/lib.rs b/router/src/lib.rs index 0a15c495..d478083b 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1211,7 +1211,7 @@ pub(crate) struct ChatTokenizeResponse { #[serde(transparent)] pub(crate) struct TokenizeResponse(Vec); -#[derive(Serialize, ToSchema)] +#[derive(Serialize, ToSchema, Debug)] pub(crate) struct StreamDetails { #[schema(example = "length")] pub finish_reason: FinishReason, @@ -1219,9 +1219,11 @@ pub(crate) struct StreamDetails { pub generated_tokens: u32, #[schema(nullable = true, example = 42)] pub seed: Option, + #[schema(example = 1)] + pub input_length: u32, } -#[derive(Serialize, ToSchema)] +#[derive(Serialize, ToSchema, Debug)] pub(crate) struct StreamResponse { pub index: u32, pub token: Token, diff --git a/router/src/server.rs b/router/src/server.rs index 8c0bd762..0e1635bd 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -533,7 +533,7 @@ async fn generate_stream_internal( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, _input_length, response_stream)) => { + Ok((_permit, input_length, response_stream)) => { let mut index = 0; let mut response_stream = Box::pin(response_stream); // Server-Sent Event stream @@ -576,6 +576,7 @@ async fn generate_stream_internal( finish_reason: generated_text.finish_reason, generated_tokens: generated_text.generated_tokens, seed: generated_text.seed, + input_length, }), false => None, }; @@ -801,21 +802,46 @@ async fn completions( .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); - event - .json_data(Completion::Chunk(Chunk { - id: "".to_string(), - created: current_time, + let message = match stream_token.details { + Some(details) => { + let completion_tokens = details.generated_tokens; + let prompt_tokens = details.input_length; + let total_tokens = prompt_tokens + completion_tokens; + Completion::Final(CompletionFinal { 
id: String::new(), + created: current_time, + model: model_id.clone(), + system_fingerprint: system_fingerprint.clone(), + choices: vec![CompletionComplete { + finish_reason: String::new(), + index: index as u32, + logprobs: None, + text: stream_token.token.text, + }], + usage: Usage { + prompt_tokens, + completion_tokens, + total_tokens, + }, + }) + } + None => Completion::Chunk(Chunk { + id: String::new(), + created: current_time, choices: vec![CompletionComplete { - finish_reason: "".to_string(), + finish_reason: String::new(), index: index as u32, logprobs: None, text: stream_token.token.text, }], - model: model_id.clone(), system_fingerprint: system_fingerprint.clone(), - })) + }), + }; + + event + .json_data(message) .unwrap_or_else(|_e| Event::default()) };