Repairing prompt token counting.

Nicolas Patry 2024-12-04 19:18:22 +01:00
parent 3a86afc713
commit 3ed703c273
3 changed files with 18 additions and 11 deletions


@@ -651,6 +651,7 @@ enum CompletionType {
 }
 
 impl ChatCompletion {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
@@ -659,6 +660,7 @@ impl ChatCompletion {
         details: Details,
         return_logprobs: bool,
         tool_calls: Option<Vec<ToolCall>>,
+        prompt_tokens: u32,
     ) -> Self {
         let message = match (output, tool_calls) {
             (Some(content), None) => OutputMessage::ChatMessage(TextMessage {
@@ -697,9 +699,9 @@ impl ChatCompletion {
                 finish_reason: details.finish_reason.format(true),
             }],
             usage: Usage {
-                prompt_tokens: details.prefill.len() as u32,
+                prompt_tokens,
                 completion_tokens: details.generated_tokens,
-                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+                total_tokens: prompt_tokens + details.generated_tokens,
             },
         }
     }
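
The heart of the fix is in the `Usage` block above: `details.prefill` is only populated when the client explicitly requests input token details, so deriving `prompt_tokens` from `prefill.len()` routinely reported zero for chat requests. The new `#[allow(clippy::too_many_arguments)]` is needed because the extra `prompt_tokens` parameter pushes `ChatCompletion::new` past clippy's default limit of seven arguments. A minimal sketch of the before/after accounting, using trimmed stand-in types (`Usage` and `Details` here are simplified, not the real router structs):

    // Trimmed stand-ins for the router's Usage and Details structs.
    struct Usage {
        prompt_tokens: u32,
        completion_tokens: u32,
        total_tokens: u32,
    }

    struct Details {
        generated_tokens: u32,
        prefill: Vec<u32>, // empty unless input token details were requested
    }

    // Before: `prefill` is usually empty here, so this under-counts.
    fn usage_before(details: &Details) -> Usage {
        Usage {
            prompt_tokens: details.prefill.len() as u32,
            completion_tokens: details.generated_tokens,
            total_tokens: details.prefill.len() as u32 + details.generated_tokens,
        }
    }

    // After: the caller supplies the measured input length.
    fn usage_after(details: &Details, prompt_tokens: u32) -> Usage {
        Usage {
            prompt_tokens,
            completion_tokens: details.generated_tokens,
            total_tokens: prompt_tokens + details.generated_tokens,
        }
    }

    fn main() {
        let details = Details { generated_tokens: 7, prefill: vec![] };
        // With no prefill returned, the old path reports 0 prompt tokens.
        assert_eq!(usage_before(&details).prompt_tokens, 0);
        assert_eq!(usage_after(&details, 42).total_tokens, 49);
    }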


@@ -271,7 +271,9 @@ async fn generate(
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
-    generate_internal(infer, ComputeType(compute_type), Json(req), span).await
+    let (headers, _, response) =
+        generate_internal(infer, ComputeType(compute_type), Json(req), span).await?;
+    Ok((headers, response))
 }
 
 pub(crate) async fn generate_internal(
@@ -279,7 +281,7 @@ pub(crate) async fn generate_internal(
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
     span: tracing::Span,
-) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
+) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
     metrics::counter!("tgi_request_count").increment(1);
 
@@ -423,7 +425,7 @@ pub(crate) async fn generate_internal(
         generated_text: output_text,
         details,
     };
-    Ok((headers, Json(response)))
+    Ok((headers, input_length, Json(response)))
 }
 
 /// Generate a stream of token using Server-Sent Events
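
`generate_internal` now returns the tokenized input length as a middle tuple element, and the plain `generate` handler destructures and discards it so its public signature stays unchanged. The same pattern in miniature, with hypothetical names (`inner`, `handler`) standing in for the real functions:

    // Hypothetical miniature of the tuple-widening pattern.
    // `inner` gains a u32 (the input length) in its Ok tuple; the public
    // wrapper drops it with `_` so existing callers see no change.
    fn inner(input: &str) -> Result<(String, u32), String> {
        // Stand-in for real tokenization: count whitespace-separated words.
        let input_length = input.split_whitespace().count() as u32;
        Ok((format!("echo: {input}"), input_length))
    }

    fn handler(input: &str) -> Result<String, String> {
        let (response, _input_length) = inner(input)?;
        Ok(response)
    }

    fn main() {
        assert_eq!(handler("two words").unwrap(), "echo: two words");
        assert_eq!(inner("two words").unwrap().1, 2);
    }
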
@@ -980,7 +982,9 @@ pub(crate) async fn completions(
                     span_clone,
                 )
                 .await;
-                result.map(|(headers, generation)| (index, headers, generation))
+                result.map(|(headers, input_length, generation)| {
+                    (index, headers, input_length, generation)
+                })
             };
             responses.push(response_future);
         }
@@ -1001,7 +1005,7 @@ pub(crate) async fn completions(
 
         let choices = generate_responses
             .into_iter()
-            .map(|(index, headers, Json(generation))| {
+            .map(|(index, headers, input_length, Json(generation))| {
                 let details = generation.details.ok_or((
                     // this should never happen but handle if details are missing unexpectedly
                     StatusCode::INTERNAL_SERVER_ERROR,
@@ -1056,9 +1060,9 @@ pub(crate) async fn completions(
                     .and_then(|v| v.to_str().ok()?.parse().ok())
                     .unwrap_or(0);
 
-                prompt_tokens += details.prefill.len() as u32;
+                prompt_tokens += input_length;
                 completion_tokens += details.generated_tokens;
-                total_tokens += details.prefill.len() as u32 + details.generated_tokens;
+                total_tokens += input_length + details.generated_tokens;
 
                 Ok(CompletionComplete {
                     finish_reason: details.finish_reason.format(true),
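
Within `completions`, each of the `n` spawned sub-requests now carries its own measured input length, and the usage totals are summed per choice. A sketch of that aggregation under a simplified per-choice record (`ChoiceUsage` and `total_usage` are illustrative names, not part of the codebase):

    // Sketch: aggregate OpenAI-style usage across several choices.
    // `input_length` replaces the old `details.prefill.len()` source.
    struct ChoiceUsage {
        input_length: u32,
        generated_tokens: u32,
    }

    fn total_usage(choices: &[ChoiceUsage]) -> (u32, u32, u32) {
        let mut prompt_tokens = 0u32;
        let mut completion_tokens = 0u32;
        for choice in choices {
            prompt_tokens += choice.input_length;
            completion_tokens += choice.generated_tokens;
        }
        (prompt_tokens, completion_tokens, prompt_tokens + completion_tokens)
    }

    fn main() {
        let choices = [
            ChoiceUsage { input_length: 10, generated_tokens: 5 },
            ChoiceUsage { input_length: 10, generated_tokens: 8 },
        ];
        // Two choices over the same prompt: 20 prompt, 13 completion, 33 total.
        assert_eq!(total_usage(&choices), (20, 13, 33));
    }
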
@@ -1381,7 +1385,7 @@ pub(crate) async fn chat_completions(
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        let (headers, Json(generation)) =
+        let (headers, input_length, Json(generation)) =
             generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
 
         let current_time = std::time::SystemTime::now()
@@ -1452,6 +1456,7 @@ pub(crate) async fn chat_completions(
             generation.details.unwrap(),
             logprobs,
             tool_calls,
+            input_length,
         ));
 
         // wrap generation inside a Vec to match api-inference


@@ -122,7 +122,7 @@ pub(crate) async fn vertex_compatibility(
                 span_clone,
             )
             .await
-            .map(|(_, Json(generation))| generation.generated_text)
+            .map(|(_, _, Json(generation))| generation.generated_text)
            .map_err(|_| {
                 (
                     StatusCode::INTERNAL_SERVER_ERROR,
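
The vertex handler only needs the generated text, so it absorbs the widened tuple by ignoring the new element with an extra `_`. A tiny self-contained sketch of that call-site change, with a plain tuple standing in for the real return type:

    fn main() {
        // Stand-in for generate_internal's Ok value: (headers, input_length, body).
        let result: Result<((), u32, String), ()> = Ok(((), 42, "generated".to_string()));
        // Mirrors `.map(|(_, _, Json(generation))| generation.generated_text)`:
        // callers that don't need the token count just add another `_`.
        let generated_text = result.map(|(_, _, body)| body);
        assert_eq!(generated_text, Ok("generated".to_string()));
    }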