Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 12:54:52 +00:00)
fix: improve completions to send a final chunk with usage details
commit c330491223
parent 0d06aed02d
@@ -1211,7 +1211,7 @@ pub(crate) struct ChatTokenizeResponse {
 #[serde(transparent)]
 pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
 
-#[derive(Serialize, ToSchema)]
+#[derive(Serialize, ToSchema, Debug)]
 pub(crate) struct StreamDetails {
     #[schema(example = "length")]
     pub finish_reason: FinishReason,
@@ -1219,9 +1219,11 @@ pub(crate) struct StreamDetails {
     pub generated_tokens: u32,
     #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
+    #[schema(example = 1)]
+    pub input_length: u32,
 }
 
-#[derive(Serialize, ToSchema)]
+#[derive(Serialize, ToSchema, Debug)]
 pub(crate) struct StreamResponse {
     pub index: u32,
     pub token: Token,
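A minimal sketch of what the updated StreamDetails serializes to, assuming serde (with the derive feature) and serde_json. This mirrors, rather than reuses, TGI's struct: FinishReason is simplified to a String here, and the field values are illustrative only.

    // Standalone sketch (not TGI's actual code): shows the JSON shape of the
    // `details` payload after `input_length` was added.
    use serde::Serialize;

    #[derive(Serialize, Debug)]
    struct StreamDetails {
        finish_reason: String, // a FinishReason enum in TGI; a String here for brevity
        generated_tokens: u32,
        seed: Option<u64>,
        input_length: u32, // newly added: prompt length in tokens
    }

    fn main() {
        let details = StreamDetails {
            finish_reason: "length".to_string(),
            generated_tokens: 128,
            seed: Some(42),
            input_length: 1,
        };
        // Prints: {"finish_reason":"length","generated_tokens":128,"seed":42,"input_length":1}
        println!("{}", serde_json::to_string(&details).unwrap());
    }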
@@ -533,7 +533,7 @@ async fn generate_stream_internal(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, _input_length, response_stream)) => {
+            Ok((_permit, input_length, mut response_stream)) => {
                 let mut index = 0;
                 let mut response_stream = Box::pin(response_stream);
                 // Server-Sent Event stream
@@ -576,6 +576,7 @@ async fn generate_stream_internal(
                                 finish_reason: generated_text.finish_reason,
                                 generated_tokens: generated_text.generated_tokens,
                                 seed: generated_text.seed,
+                                input_length,
                             }),
                             false => None,
                         };
@@ -801,21 +802,46 @@ async fn completions(
                     .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                     .as_secs();
 
-                event
-                    .json_data(Completion::Chunk(Chunk {
-                        id: "".to_string(),
-                        created: current_time,
-                        choices: vec![CompletionComplete {
-                            finish_reason: "".to_string(),
-                            index: index as u32,
-                            logprobs: None,
-                            text: stream_token.token.text,
-                        }],
-                        model: model_id.clone(),
-                        system_fingerprint: system_fingerprint.clone(),
-                    }))
-                    .unwrap_or_else(|_e| Event::default())
+                let message = match stream_token.details {
+                    Some(details) => {
+                        let completion_tokens = details.generated_tokens;
+                        let prompt_tokens = details.input_length;
+                        let total_tokens = prompt_tokens + completion_tokens;
+
+                        Completion::Final(CompletionFinal {
+                            id: String::new(),
+                            created: current_time,
+                            model: model_id.clone(),
+                            system_fingerprint: system_fingerprint.clone(),
+                            choices: vec![CompletionComplete {
+                                finish_reason: String::new(),
+                                index: index as u32,
+                                logprobs: None,
+                                text: stream_token.token.text,
+                            }],
+                            usage: Usage {
+                                prompt_tokens,
+                                completion_tokens,
+                                total_tokens,
+                            },
+                        })
+                    }
+                    None => Completion::Chunk(Chunk {
+                        id: String::new(),
+                        created: current_time,
+                        choices: vec![CompletionComplete {
+                            finish_reason: String::new(),
+                            index: index as u32,
+                            logprobs: None,
+                            text: stream_token.token.text,
+                        }],
+                        model: model_id.clone(),
+                        system_fingerprint: system_fingerprint.clone(),
+                    }),
+                };
+
+                event
+                    .json_data(message)
+                    .unwrap_or_else(|_e| Event::default())
             };
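For illustration only, a hedged client-side sketch of what the final SSE chunk built above might look like on the wire, and how a consumer could read the usage counters. The FinalChunk and Usage types and the sample payload are assumptions made for this example, not part of the commit; it relies on serde's default of ignoring unknown JSON fields.

    // Assumed client-side view of the Completion::Final payload; not TGI code.
    use serde::Deserialize;

    #[derive(Deserialize, Debug)]
    struct Usage {
        prompt_tokens: u32,
        completion_tokens: u32,
        total_tokens: u32,
    }

    #[derive(Deserialize, Debug)]
    struct FinalChunk {
        model: String,
        usage: Usage, // only present on the final chunk after this change
    }

    fn main() {
        // Illustrative payload shaped like CompletionFinal; values are made up.
        let data = r#"{"id":"","created":0,"model":"m","system_fingerprint":"",
            "choices":[],"usage":{"prompt_tokens":1,"completion_tokens":128,"total_tokens":129}}"#;
        let chunk: FinalChunk = serde_json::from_str(data).unwrap();
        // The server computes total_tokens = prompt_tokens + completion_tokens.
        assert_eq!(
            chunk.usage.total_tokens,
            chunk.usage.prompt_tokens + chunk.usage.completion_tokens
        );
        println!("{} used {} tokens total", chunk.model, chunk.usage.total_tokens);
    }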