fix: improve completions to send a final chunk with usage details

drbh 2024-07-30 21:02:32 +00:00
parent 0d06aed02d
commit c330491223
2 changed files with 38 additions and 10 deletions
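
In plain terms: when streaming through the completions endpoint, intermediate events are still emitted as Completion::Chunk, but the event that carries the generation details (the last one) is now emitted as a Completion::Final that includes token-usage accounting. A rough sketch of the two payload shapes a client sees, assuming serde_json; the field names come from the diff below, while the timestamps, model name, fingerprint, and token counts are made-up illustrative values:

use serde_json::json;

fn main() {
    // Intermediate streaming event: a plain chunk, no usage block.
    let chunk = json!({
        "id": "",
        "created": 1722373352,
        "choices": [
            { "finish_reason": "", "index": 0, "logprobs": null, "text": " world" }
        ],
        "model": "some-model-id",
        "system_fingerprint": "some-fingerprint"
    });

    // Final streaming event: same choice shape plus usage accounting.
    let final_event = json!({
        "id": "",
        "created": 1722373352,
        "model": "some-model-id",
        "system_fingerprint": "some-fingerprint",
        "choices": [
            { "finish_reason": "", "index": 0, "logprobs": null, "text": "!" }
        ],
        "usage": { "prompt_tokens": 4, "completion_tokens": 16, "total_tokens": 20 }
    });

    println!("{chunk}\n{final_event}");
}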


@@ -1211,7 +1211,7 @@ pub(crate) struct ChatTokenizeResponse {
 #[serde(transparent)]
 pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
 
-#[derive(Serialize, ToSchema)]
+#[derive(Serialize, ToSchema, Debug)]
 pub(crate) struct StreamDetails {
     #[schema(example = "length")]
     pub finish_reason: FinishReason,
@@ -1219,9 +1219,11 @@ pub(crate) struct StreamDetails {
     pub generated_tokens: u32,
     #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
+    #[schema(example = 1)]
+    pub input_length: u32,
 }
 
-#[derive(Serialize, ToSchema)]
+#[derive(Serialize, ToSchema, Debug)]
 pub(crate) struct StreamResponse {
     pub index: u32,
     pub token: Token,
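
The first file's change is additive on the schema side: StreamDetails gains an input_length field (the validated prompt length in tokens) and both structs pick up Debug. A minimal standalone sketch of how the extended details serialize, using a local stand-in for the struct (the real definitions live in the router; FinishReason is simplified to a plain string here):

use serde::Serialize;

// Local stand-in for the router's StreamDetails; FinishReason is
// simplified to a String for this sketch.
#[derive(Serialize, Debug)]
struct StreamDetails {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
    input_length: u32, // new field: prompt token count
}

fn main() {
    let details = StreamDetails {
        finish_reason: "length".into(),
        generated_tokens: 16,
        seed: Some(42),
        input_length: 4,
    };
    // The Debug derive makes logging/tracing possible; Serialize puts
    // input_length on the wire alongside the existing fields.
    println!("{details:?}");
    println!("{}", serde_json::to_string(&details).unwrap());
}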


@@ -533,7 +533,7 @@ async fn generate_stream_internal(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, _input_length, response_stream)) => {
+            Ok((_permit, input_length, mut response_stream)) => {
                 let mut index = 0;
                 let mut response_stream = Box::pin(response_stream);
                 // Server-Sent Event stream
@@ -576,6 +576,7 @@ async fn generate_stream_internal(
                                 finish_reason: generated_text.finish_reason,
                                 generated_tokens: generated_text.generated_tokens,
                                 seed: generated_text.seed,
+                                input_length,
                             }),
                             false => None,
                         };
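
Why this plumbing matters: generate_stream already knows the validated prompt length when the stream starts, and the rename from _input_length to input_length threads that value into the final StreamDetails, so the completions handler below can compute usage without re-tokenizing. The accounting itself is a single addition; a hypothetical helper (not part of this commit) spelling it out:

/// Hypothetical helper mirroring the usage arithmetic in the
/// completions handler: totals derived from the final stream details.
fn usage_totals(prompt_tokens: u32, completion_tokens: u32) -> (u32, u32, u32) {
    // The handler uses a plain `+`; a saturating add is used here only
    // to keep the sketch well-defined for extreme token counts.
    (
        prompt_tokens,
        completion_tokens,
        prompt_tokens.saturating_add(completion_tokens),
    )
}

fn main() {
    assert_eq!(usage_totals(4, 16), (4, 16, 20));
}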
@@ -801,21 +802,46 @@ async fn completions(
                     .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                     .as_secs();
 
-                event
-                    .json_data(Completion::Chunk(Chunk {
-                        id: "".to_string(),
-                        created: current_time,
+                let message = match stream_token.details {
+                    Some(details) => {
+                        let completion_tokens = details.generated_tokens;
+                        let prompt_tokens = details.input_length;
+                        let total_tokens = prompt_tokens + completion_tokens;
+                        Completion::Final(CompletionFinal {
+                            id: String::new(),
+                            created: current_time,
+                            model: model_id.clone(),
+                            system_fingerprint: system_fingerprint.clone(),
+                            choices: vec![CompletionComplete {
+                                finish_reason: String::new(),
+                                index: index as u32,
+                                logprobs: None,
+                                text: stream_token.token.text,
+                            }],
+                            usage: Usage {
+                                prompt_tokens,
+                                completion_tokens,
+                                total_tokens,
+                            },
+                        })
+                    }
+                    None => Completion::Chunk(Chunk {
+                        id: String::new(),
+                        created: current_time,
                         choices: vec![CompletionComplete {
-                            finish_reason: "".to_string(),
+                            finish_reason: String::new(),
                             index: index as u32,
                             logprobs: None,
                             text: stream_token.token.text,
                         }],
                         model: model_id.clone(),
                         system_fingerprint: system_fingerprint.clone(),
-                    }))
+                    }),
+                };
+                event
+                    .json_data(message)
                     .unwrap_or_else(|_e| Event::default())
             };
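
From the consumer side, the presence of the usage block is what distinguishes the terminal event from intermediate chunks, since the Chunk struct in this diff has no usage field. A hypothetical client-side sketch (the frame strings stand in for data: payloads read off the SSE stream; field names as above):

use serde_json::Value;

// Hypothetical consumer: decide whether an SSE data frame is the
// final completion event by checking for the usage block.
fn handle_frame(data: &str) {
    let v: Value = serde_json::from_str(data).expect("valid JSON frame");
    match v.get("usage") {
        Some(usage) => println!("final chunk: total_tokens = {}", usage["total_tokens"]),
        None => println!("intermediate chunk: {}", v["choices"][0]["text"]),
    }
}

fn main() {
    handle_frame(r#"{"id":"","created":0,"choices":[{"finish_reason":"","index":0,"logprobs":null,"text":"Hi"}],"model":"m","system_fingerprint":"fp"}"#);
    handle_frame(r#"{"id":"","created":0,"model":"m","system_fingerprint":"fp","choices":[{"finish_reason":"","index":0,"logprobs":null,"text":"!"}],"usage":{"prompt_tokens":4,"completion_tokens":16,"total_tokens":20}}"#);
}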