feat: support logprobs in streaming and non streaming chat

drbh 2024-01-09 14:04:31 -05:00
parent 65c913b55d
commit 9a79c2f867
2 changed files with 24 additions and 19 deletions

@@ -213,7 +213,7 @@ pub(crate) struct ChatCompletionComplete {
pub index: u32,
pub message: Message,
pub logprobs: Option<Vec<f32>>,
pub finish_reason: Option<String>,
pub finish_reason: String,
}
#[derive(Clone, Deserialize, Serialize)]
@@ -227,24 +227,26 @@ impl ChatCompletion {
pub(crate) fn new(
model: String,
system_fingerprint: String,
ouput: String,
output: String,
created: u64,
details: Details,
return_logprobs: bool,
) -> Self {
Self {
id: "".to_string(),
object: "text_completion".to_string(),
id: String::new(),
object: "text_completion".into(),
created,
model,
system_fingerprint,
choices: vec![ChatCompletionComplete {
index: 0,
message: Message {
role: "assistant".to_string(),
content: ouput,
role: "assistant".into(),
content: output,
},
logprobs: None,
finish_reason: details.finish_reason.to_string().into(),
logprobs: return_logprobs
.then(|| details.tokens.iter().map(|t| t.logprob).collect()),
finish_reason: details.finish_reason.to_string(),
}],
usage: Usage {
prompt_tokens: details.prompt_token_count,
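
The non-streaming path now derives choice-level logprobs directly from the generation details. Below is a minimal, hedged sketch of the bool::then pattern used in the hunk above, with a stripped-down Token stand-in; the router's real Details/Token types carry more fields than shown here.

struct Token {
    // Assumed stand-in: only the log probability matters for this sketch.
    logprob: f32,
}

fn collect_logprobs(return_logprobs: bool, tokens: &[Token]) -> Option<Vec<f32>> {
    // `bool::then` evaluates the closure only when the flag is true,
    // mapping the flag straight to Some(per-token logprobs) or None.
    return_logprobs.then(|| tokens.iter().map(|t| t.logprob).collect())
}

fn main() {
    let tokens = vec![Token { logprob: -0.25 }, Token { logprob: -1.5 }];
    assert_eq!(collect_logprobs(true, &tokens), Some(vec![-0.25, -1.5]));
    assert_eq!(collect_logprobs(false, &tokens), None);
}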
@@ -269,7 +271,7 @@ pub(crate) struct ChatCompletionChunk {
pub(crate) struct ChatCompletionChoice {
pub index: u32,
pub delta: ChatCompletionDelta,
pub logprobs: Option<Vec<f32>>,
pub logprobs: Option<f32>,
pub finish_reason: Option<String>,
}
@@ -286,7 +288,7 @@ impl ChatCompletionChunk {
delta: String,
created: u64,
index: u32,
logprobs: Option<Vec<f32>>,
logprobs: Option<f32>,
finish_reason: Option<String>,
) -> Self {
Self {
@@ -340,12 +342,10 @@ pub(crate) struct ChatRequest {
#[serde(default)]
pub logit_bias: Option<Vec<f32>>,
/// UNUSED
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
/// output token returned in the content of message. This option is currently not available on the gpt-4-vision-preview
/// model.
/// output token returned in the content of message.
#[serde(default)]
pub logprobs: Option<u32>,
pub logprobs: Option<bool>,
/// UNUSED
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with

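With logprobs now a plain Option<bool> behind #[serde(default)], a request only needs to add "logprobs": true to opt in, and an absent field falls back to false in the handler. A hedged deserialization sketch (assuming serde and serde_json, which the router already uses) against a stand-in struct that mirrors just the two fields from the hunk above; the real ChatRequest has many more:

use serde::Deserialize;

#[derive(Deserialize)]
struct ChatRequestSketch {
    // Stand-in fields mirroring the diff; everything else is omitted.
    #[serde(default)]
    logit_bias: Option<Vec<f32>>,
    #[serde(default)]
    logprobs: Option<bool>,
}

fn main() {
    let body = r#"{ "logprobs": true }"#;
    let req: ChatRequestSketch = serde_json::from_str(body).unwrap();
    assert_eq!(req.logprobs, Some(true));
    assert!(req.logit_bias.is_none());
    // The handler hunks below default a missing flag to false the same way.
    let logprobs = req.logprobs.unwrap_or(false);
    assert!(logprobs);
}
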
@@ -564,6 +564,7 @@ async fn chat_completions(
.frequency_penalty
// rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
.map(|x| x + 2.0);
let logprobs = req.logprobs.unwrap_or(false);
// apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(req) {
@@ -626,13 +627,16 @@
stream_token.token.text,
current_time,
stream_token.index,
None,
logprobs.then(|| stream_token.token.logprob),
stream_token.details.map(|d| d.finish_reason.to_string()),
))
.unwrap_or_else(|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
})
.map_or_else(
|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
},
|data| data,
)
};
let (headers, response_stream) =
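
The stream callback's fallback handling changes from unwrap_or_else to map_or_else with an identity success arm; on a Result the two are behaviourally equivalent. A small sketch with a stand-in Event type (an assumption here, not axum's real SSE Event):

#[derive(Debug, PartialEq, Default)]
struct Event(String);

fn main() {
    // Success case: map_or_else with an identity success arm passes the
    // Ok value through, exactly like unwrap_or_else would.
    let ok: Result<Event, String> = Ok(Event("data: {...}".into()));
    let event = ok.map_or_else(
        |e| {
            println!("Failed to serialize ChatCompletionChunk: {:?}", e);
            Event::default()
        },
        |data| data,
    );
    assert_eq!(event, Event("data: {...}".into()));

    // Error case: both forms log the error and fall back to a default event.
    let err: Result<Event, String> = Err("serialization failed".into());
    let fallback = err.unwrap_or_else(|e| {
        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
        Event::default()
    });
    assert_eq!(fallback, Event::default());
}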
@@ -655,6 +659,7 @@
system_fingerprint,
current_time,
generation.details.unwrap(),
logprobs,
);
// wrap generation inside a Vec to match api-inference
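
For reference, a hedged sketch of the choice-level JSON a client could see once the flag is threaded through; the struct below mirrors only the two fields this commit touches in ChatCompletionComplete, not the full type, and the finish reason string is a placeholder:

use serde::Serialize;

#[derive(Serialize)]
struct ChoiceSketch {
    logprobs: Option<Vec<f32>>,
    finish_reason: String,
}

fn main() {
    let with_logprobs = ChoiceSketch {
        logprobs: Some(vec![-0.25, -1.5]),
        finish_reason: "length".to_string(), // placeholder value
    };
    let without = ChoiceSketch {
        logprobs: None,
        finish_reason: "length".to_string(),
    };
    // Without a skip_serializing_if attribute, a None logprobs serializes as null.
    println!("{}", serde_json::to_string(&with_logprobs).unwrap());
    // {"logprobs":[-0.25,-1.5],"finish_reason":"length"}
    println!("{}", serde_json::to_string(&without).unwrap());
    // {"logprobs":null,"finish_reason":"length"}
}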