mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 12:54:52 +00:00
fix: improve completions to send a final chunk with usage details
This commit is contained in:
parent
0d06aed02d
commit
c330491223
@ -1211,7 +1211,7 @@ pub(crate) struct ChatTokenizeResponse {
|
|||||||
#[serde(transparent)]
|
#[serde(transparent)]
|
||||||
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
|
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema, Debug)]
|
||||||
pub(crate) struct StreamDetails {
|
pub(crate) struct StreamDetails {
|
||||||
#[schema(example = "length")]
|
#[schema(example = "length")]
|
||||||
pub finish_reason: FinishReason,
|
pub finish_reason: FinishReason,
|
||||||
@ -1219,9 +1219,11 @@ pub(crate) struct StreamDetails {
|
|||||||
pub generated_tokens: u32,
|
pub generated_tokens: u32,
|
||||||
#[schema(nullable = true, example = 42)]
|
#[schema(nullable = true, example = 42)]
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
|
#[schema(example = 1)]
|
||||||
|
pub input_length: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema, Debug)]
|
||||||
pub(crate) struct StreamResponse {
|
pub(crate) struct StreamResponse {
|
||||||
pub index: u32,
|
pub index: u32,
|
||||||
pub token: Token,
|
pub token: Token,
|
||||||
|
@ -533,7 +533,7 @@ async fn generate_stream_internal(
|
|||||||
} else {
|
} else {
|
||||||
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||||
// Keep permit as long as generate_stream lives
|
// Keep permit as long as generate_stream lives
|
||||||
Ok((_permit, _input_length, response_stream)) => {
|
Ok((_permit, input_length, mut response_stream)) => {
|
||||||
let mut index = 0;
|
let mut index = 0;
|
||||||
let mut response_stream = Box::pin(response_stream);
|
let mut response_stream = Box::pin(response_stream);
|
||||||
// Server-Sent Event stream
|
// Server-Sent Event stream
|
||||||
@ -576,6 +576,7 @@ async fn generate_stream_internal(
|
|||||||
finish_reason: generated_text.finish_reason,
|
finish_reason: generated_text.finish_reason,
|
||||||
generated_tokens: generated_text.generated_tokens,
|
generated_tokens: generated_text.generated_tokens,
|
||||||
seed: generated_text.seed,
|
seed: generated_text.seed,
|
||||||
|
input_length,
|
||||||
}),
|
}),
|
||||||
false => None,
|
false => None,
|
||||||
};
|
};
|
||||||
@ -801,21 +802,46 @@ async fn completions(
|
|||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||||
.as_secs();
|
.as_secs();
|
||||||
|
|
||||||
event
|
let message = match stream_token.details {
|
||||||
.json_data(Completion::Chunk(Chunk {
|
Some(details) => {
|
||||||
id: "".to_string(),
|
let completion_tokens = details.generated_tokens;
|
||||||
created: current_time,
|
let prompt_tokens = details.input_length;
|
||||||
|
let total_tokens = prompt_tokens + completion_tokens;
|
||||||
|
|
||||||
|
Completion::Final(CompletionFinal {
|
||||||
|
id: String::new(),
|
||||||
|
created: current_time,
|
||||||
|
model: model_id.clone(),
|
||||||
|
system_fingerprint: system_fingerprint.clone(),
|
||||||
|
choices: vec![CompletionComplete {
|
||||||
|
finish_reason: String::new(),
|
||||||
|
index: index as u32,
|
||||||
|
logprobs: None,
|
||||||
|
text: stream_token.token.text,
|
||||||
|
}],
|
||||||
|
usage: Usage {
|
||||||
|
prompt_tokens,
|
||||||
|
completion_tokens,
|
||||||
|
total_tokens,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
None => Completion::Chunk(Chunk {
|
||||||
|
id: String::new(),
|
||||||
|
created: current_time,
|
||||||
choices: vec![CompletionComplete {
|
choices: vec![CompletionComplete {
|
||||||
finish_reason: "".to_string(),
|
finish_reason: String::new(),
|
||||||
index: index as u32,
|
index: index as u32,
|
||||||
logprobs: None,
|
logprobs: None,
|
||||||
text: stream_token.token.text,
|
text: stream_token.token.text,
|
||||||
}],
|
}],
|
||||||
|
|
||||||
model: model_id.clone(),
|
model: model_id.clone(),
|
||||||
system_fingerprint: system_fingerprint.clone(),
|
system_fingerprint: system_fingerprint.clone(),
|
||||||
}))
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
event
|
||||||
|
.json_data(message)
|
||||||
.unwrap_or_else(|_e| Event::default())
|
.unwrap_or_else(|_e| Event::default())
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user