Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Repairing prompt token counting.

parent 3a86afc713
commit 3ed703c273
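What the change does: the OpenAI-compatible endpoints previously computed `prompt_tokens` as `details.prefill.len()`, but the prefill details appear to be populated only when decoder input details are explicitly requested, so the reported prompt count was typically zero. The commit instead threads the validated input length out of `generate_internal` to every caller and uses it for usage accounting. Below is a minimal, self-contained sketch of the bug and the fix; `PrefillToken`, `Details`, and the literal lengths are illustrative stand-ins, not the router's actual types:

struct PrefillToken; // stand-in; the real type carries id, text and logprob

struct Details {
    prefill: Vec<PrefillToken>, // empty unless decoder input details were requested
    generated_tokens: u32,
}

fn main() {
    // Typical chat/completions request: no decoder input details, so prefill is empty...
    let details = Details { prefill: Vec::new(), generated_tokens: 42 };
    // ...even though the validated prompt was, say, 7 tokens long.
    let input_length: u32 = 7;

    let old_prompt_tokens = details.prefill.len() as u32; // 0: the bug
    let new_prompt_tokens = input_length; // threaded out of generate_internal

    assert_eq!(old_prompt_tokens, 0);
    assert_eq!(new_prompt_tokens + details.generated_tokens, 49);
    println!("old prompt_tokens = {old_prompt_tokens}, fixed prompt_tokens = {new_prompt_tokens}");
}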
@@ -651,6 +651,7 @@ enum CompletionType {
 }
 
 impl ChatCompletion {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
@@ -659,6 +660,7 @@ impl ChatCompletion {
         details: Details,
         return_logprobs: bool,
         tool_calls: Option<Vec<ToolCall>>,
+        prompt_tokens: u32,
     ) -> Self {
         let message = match (output, tool_calls) {
             (Some(content), None) => OutputMessage::ChatMessage(TextMessage {
@@ -697,9 +699,9 @@ impl ChatCompletion {
                 finish_reason: details.finish_reason.format(true),
             }],
             usage: Usage {
-                prompt_tokens: details.prefill.len() as u32,
+                prompt_tokens,
                 completion_tokens: details.generated_tokens,
-                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+                total_tokens: prompt_tokens + details.generated_tokens,
             },
         }
     }
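`ChatCompletion::new` now receives the prompt length as an explicit `prompt_tokens: u32` argument (pushing the constructor past clippy's argument limit, hence the added allow). A hedged sketch of the arithmetic it performs, with `Details` pared down to the single field the computation needs:

struct Details {
    generated_tokens: u32,
}

struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

// The usage computation after the change; the comment marks what it replaced.
fn usage_from(details: &Details, prompt_tokens: u32) -> Usage {
    Usage {
        prompt_tokens, // was: details.prefill.len() as u32
        completion_tokens: details.generated_tokens,
        total_tokens: prompt_tokens + details.generated_tokens,
    }
}

fn main() {
    let usage = usage_from(&Details { generated_tokens: 42 }, 7);
    assert_eq!(usage.total_tokens, 49);
    println!("prompt={} completion={}", usage.prompt_tokens, usage.completion_tokens);
}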
@@ -271,7 +271,9 @@ async fn generate(
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
-    generate_internal(infer, ComputeType(compute_type), Json(req), span).await
+    let (headers, _, response) =
+        generate_internal(infer, ComputeType(compute_type), Json(req), span).await?;
+    Ok((headers, response))
 }
 
 pub(crate) async fn generate_internal(
@@ -279,7 +281,7 @@ pub(crate) async fn generate_internal(
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
     span: tracing::Span,
-) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
+) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
     metrics::counter!("tgi_request_count").increment(1);
 
@@ -423,7 +425,7 @@ pub(crate) async fn generate_internal(
         generated_text: output_text,
         details,
     };
-    Ok((headers, Json(response)))
+    Ok((headers, input_length, Json(response)))
 }
 
 /// Generate a stream of token using Server-Sent Events
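Note that only the internal function changes shape: `generate_internal` now returns `(HeaderMap, u32, Json<GenerateResponse>)`, while the public `generate` handler preserves its original two-element contract by destructuring and discarding the middle element. A simplified sketch of that wrapper pattern, with plain stand-in types in place of the axum extractors:

type Headers = Vec<(String, String)>;

// Inner function: now returns the validated input length as well.
fn generate_internal(req: &str) -> Result<(Headers, u32, String), String> {
    let input_length = req.split_whitespace().count() as u32; // fake token count
    Ok((Headers::new(), input_length, format!("generated for: {req}")))
}

// Public handler: same external contract as before the change.
fn generate(req: &str) -> Result<(Headers, String), String> {
    let (headers, _, response) = generate_internal(req)?; // discard input_length
    Ok((headers, response))
}

fn main() {
    let (_headers, body) = generate("hello world").unwrap();
    println!("{body}");
}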
@@ -980,7 +982,9 @@ pub(crate) async fn completions(
                 span_clone,
             )
             .await;
-            result.map(|(headers, generation)| (index, headers, generation))
+            result.map(|(headers, input_length, generation)| {
+                (index, headers, input_length, generation)
+            })
         };
         responses.push(response_future);
     }
@@ -1001,7 +1005,7 @@ pub(crate) async fn completions(
 
     let choices = generate_responses
         .into_iter()
-        .map(|(index, headers, Json(generation))| {
+        .map(|(index, headers, input_length, Json(generation))| {
             let details = generation.details.ok_or((
                 // this should never happen but handle if details are missing unexpectedly
                 StatusCode::INTERNAL_SERVER_ERROR,
@@ -1056,9 +1060,9 @@ pub(crate) async fn completions(
                 .and_then(|v| v.to_str().ok()?.parse().ok())
                 .unwrap_or(0);
 
-            prompt_tokens += details.prefill.len() as u32;
+            prompt_tokens += input_length;
             completion_tokens += details.generated_tokens;
-            total_tokens += details.prefill.len() as u32 + details.generated_tokens;
+            total_tokens += input_length + details.generated_tokens;
 
             Ok(CompletionComplete {
                 finish_reason: details.finish_reason.format(true),
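For `/v1/completions`, each of the parallel sub-requests now carries its own `input_length`, and the handler sums per-choice usage from it rather than from the prefill details. A small sketch of that aggregation, with plain tuples standing in for the real response items:

// Each pair is (input_length, generated_tokens) for one choice.
fn aggregate(choices: &[(u32, u32)]) -> (u32, u32, u32) {
    let (mut prompt_tokens, mut completion_tokens, mut total_tokens) = (0u32, 0u32, 0u32);
    for &(input_length, generated_tokens) in choices {
        prompt_tokens += input_length; // was: details.prefill.len() as u32
        completion_tokens += generated_tokens;
        total_tokens += input_length + generated_tokens;
    }
    (prompt_tokens, completion_tokens, total_tokens)
}

fn main() {
    // Two choices: prompts of 7 and 5 tokens, generating 3 and 4 tokens.
    let (prompt, completion, total) = aggregate(&[(7, 3), (5, 4)]);
    assert_eq!((prompt, completion, total), (12, 7, 19));
    println!("prompt={prompt} completion={completion} total={total}");
}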
@@ -1381,7 +1385,7 @@ pub(crate) async fn chat_completions(
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        let (headers, Json(generation)) =
+        let (headers, input_length, Json(generation)) =
             generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
 
         let current_time = std::time::SystemTime::now()
@@ -1452,6 +1456,7 @@ pub(crate) async fn chat_completions(
             generation.details.unwrap(),
             logprobs,
             tool_calls,
+            input_length,
         ));
 
         // wrap generation inside a Vec to match api-inference
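In the non-streaming chat path, the `input_length` returned by `generate_internal` is forwarded as the new trailing argument to `ChatCompletion::new`, so chat usage is computed from the same validated count as the completions path.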
@@ -122,7 +122,7 @@ pub(crate) async fn vertex_compatibility(
                 span_clone,
             )
             .await
-            .map(|(_, Json(generation))| generation.generated_text)
+            .map(|(_, _, Json(generation))| generation.generated_text)
             .map_err(|_| {
                 (
                     StatusCode::INTERNAL_SERVER_ERROR,
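The Vertex compatibility route reports no usage, so it simply discards the extra tuple element with `_` and keeps returning only the generated text.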