diff --git a/router/src/lib.rs b/router/src/lib.rs
index 00c493c4..b1569d0d 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -651,6 +651,7 @@ enum CompletionType {
 }
 
 impl ChatCompletion {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
@@ -659,6 +660,7 @@ impl ChatCompletion {
         details: Details,
         return_logprobs: bool,
         tool_calls: Option<Vec<ToolCall>>,
+        prompt_tokens: u32,
     ) -> Self {
         let message = match (output, tool_calls) {
             (Some(content), None) => OutputMessage::ChatMessage(TextMessage {
@@ -697,9 +699,9 @@ impl ChatCompletion {
                 finish_reason: details.finish_reason.format(true),
             }],
             usage: Usage {
-                prompt_tokens: details.prefill.len() as u32,
+                prompt_tokens,
                 completion_tokens: details.generated_tokens,
-                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+                total_tokens: prompt_tokens + details.generated_tokens,
             },
         }
     }
diff --git a/router/src/server.rs b/router/src/server.rs
index f253cb63..c1fb7aab 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -271,7 +271,9 @@ async fn generate(
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
-    generate_internal(infer, ComputeType(compute_type), Json(req), span).await
+    let (headers, _, response) =
+        generate_internal(infer, ComputeType(compute_type), Json(req), span).await?;
+    Ok((headers, response))
 }
 
 pub(crate) async fn generate_internal(
@@ -279,7 +281,7 @@ pub(crate) async fn generate_internal(
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
     span: tracing::Span,
-) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
+) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
     metrics::counter!("tgi_request_count").increment(1);
 
@@ -423,7 +425,7 @@ pub(crate) async fn generate_internal(
         generated_text: output_text,
         details,
     };
-    Ok((headers, Json(response)))
+    Ok((headers, input_length, Json(response)))
 }
 
 /// Generate a stream of token using Server-Sent Events
@@ -980,7 +982,9 @@ pub(crate) async fn completions(
                 span_clone,
             )
             .await;
-            result.map(|(headers, generation)| (index, headers, generation))
+            result.map(|(headers, input_length, generation)| {
+                (index, headers, input_length, generation)
+            })
         };
         responses.push(response_future);
     }
@@ -1001,7 +1005,7 @@ pub(crate) async fn completions(
 
         let choices = generate_responses
             .into_iter()
-            .map(|(index, headers, Json(generation))| {
+            .map(|(index, headers, input_length, Json(generation))| {
                 let details = generation.details.ok_or((
                     // this should never happen but handle if details are missing unexpectedly
                     StatusCode::INTERNAL_SERVER_ERROR,
@@ -1056,9 +1060,9 @@ pub(crate) async fn completions(
                     .and_then(|v| v.to_str().ok()?.parse().ok())
                     .unwrap_or(0);
 
-                prompt_tokens += details.prefill.len() as u32;
+                prompt_tokens += input_length;
                 completion_tokens += details.generated_tokens;
-                total_tokens += details.prefill.len() as u32 + details.generated_tokens;
+                total_tokens += input_length + details.generated_tokens;
 
                 Ok(CompletionComplete {
                     finish_reason: details.finish_reason.format(true),
@@ -1381,7 +1385,7 @@ pub(crate) async fn chat_completions(
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        let (headers, Json(generation)) =
+        let (headers, input_length, Json(generation)) =
             generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
 
         let current_time = std::time::SystemTime::now()
@@ -1452,6 +1456,7 @@ pub(crate) async fn chat_completions(
             generation.details.unwrap(),
             logprobs,
             tool_calls,
+            input_length,
         ));
 
         // wrap generation inside a Vec to match api-inference
diff --git a/router/src/vertex.rs b/router/src/vertex.rs
index a532c9ec..0a8c2278 100644
--- a/router/src/vertex.rs
+++ b/router/src/vertex.rs
@@ -122,7 +122,7 @@ pub(crate) async fn vertex_compatibility(
                 span_clone,
             )
             .await
-            .map(|(_, Json(generation))| generation.generated_text)
+            .map(|(_, _, Json(generation))| generation.generated_text)
            .map_err(|_| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
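
For context, a minimal sketch of the corrected accounting, outside the patch. It assumes, as the diff suggests, that `input_length` is the prompt token count computed during request validation, whereas `details.prefill` is typically empty unless the client asks for decoder input details, so `details.prefill.len()` could report 0 prompt tokens. The `Usage` struct below is a simplified stand-in for the router's, not its actual definition:

// Simplified stand-in for the router's Usage type (illustrative only).
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

// After the patch, usage is derived from the validated `input_length`
// threaded through `generate_internal`, rather than from the prefill
// token list, which may be empty.
fn usage_from(input_length: u32, generated_tokens: u32) -> Usage {
    Usage {
        prompt_tokens: input_length,
        completion_tokens: generated_tokens,
        total_tokens: input_length + generated_tokens,
    }
}

fn main() {
    // A prompt that validated to 42 tokens and produced 7 completion tokens:
    let usage = usage_from(42, 7);
    assert_eq!(usage.prompt_tokens, 42);
    assert_eq!(usage.total_tokens, 49);
}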