feat: support logprobs in streaming and non streaming chat

drbh 2024-01-09 14:04:31 -05:00
parent 65c913b55d
commit 9a79c2f867
2 changed files with 24 additions and 19 deletions

@@ -213,7 +213,7 @@ pub(crate) struct ChatCompletionComplete {
     pub index: u32,
     pub message: Message,
     pub logprobs: Option<Vec<f32>>,
-    pub finish_reason: Option<String>,
+    pub finish_reason: String,
 }

 #[derive(Clone, Deserialize, Serialize)]
@@ -227,24 +227,26 @@ impl ChatCompletion {
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
-        ouput: String,
+        output: String,
         created: u64,
         details: Details,
+        return_logprobs: bool,
     ) -> Self {
         Self {
-            id: "".to_string(),
-            object: "text_completion".to_string(),
+            id: String::new(),
+            object: "text_completion".into(),
             created,
             model,
             system_fingerprint,
             choices: vec![ChatCompletionComplete {
                 index: 0,
                 message: Message {
-                    role: "assistant".to_string(),
-                    content: ouput,
+                    role: "assistant".into(),
+                    content: output,
                 },
-                logprobs: None,
-                finish_reason: details.finish_reason.to_string().into(),
+                logprobs: return_logprobs
+                    .then(|| details.tokens.iter().map(|t| t.logprob).collect()),
+                finish_reason: details.finish_reason.to_string(),
             }],
             usage: Usage {
                 prompt_tokens: details.prompt_token_count,
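
The non-streaming path now fills the response's `logprobs: Option<Vec<f32>>` with `bool::then`: the closure runs, and the per-token logprobs are collected, only when the caller asked for them. A minimal standalone sketch of that pattern, with a stand-in `Token` type in place of the router's own:

    // Stand-in for the router's generated-token type; only the field used here.
    struct Token {
        logprob: f32,
    }

    fn collect_logprobs(return_logprobs: bool, tokens: &[Token]) -> Option<Vec<f32>> {
        // `bool::then` evaluates the closure only when true and yields None otherwise,
        // so no logprob vector is built unless the client requested it.
        return_logprobs.then(|| tokens.iter().map(|t| t.logprob).collect())
    }

    fn main() {
        let tokens = vec![Token { logprob: -0.25 }, Token { logprob: -1.5 }];
        assert_eq!(collect_logprobs(true, &tokens), Some(vec![-0.25, -1.5]));
        assert_eq!(collect_logprobs(false, &tokens), None);
    }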
@@ -269,7 +271,7 @@ pub(crate) struct ChatCompletionChunk {
 pub(crate) struct ChatCompletionChoice {
     pub index: u32,
     pub delta: ChatCompletionDelta,
-    pub logprobs: Option<Vec<f32>>,
+    pub logprobs: Option<f32>,
     pub finish_reason: Option<String>,
 }
@@ -286,7 +288,7 @@ impl ChatCompletionChunk {
         delta: String,
         created: u64,
         index: u32,
-        logprobs: Option<Vec<f32>>,
+        logprobs: Option<f32>,
         finish_reason: Option<String>,
     ) -> Self {
         Self {
@@ -340,12 +342,10 @@ pub(crate) struct ChatRequest {
     #[serde(default)]
     pub logit_bias: Option<Vec<f32>>,

-    /// UNUSED
     /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
-    /// output token returned in the content of message. This option is currently not available on the gpt-4-vision-preview
-    /// model.
+    /// output token returned in the content of message.
     #[serde(default)]
-    pub logprobs: Option<u32>,
+    pub logprobs: Option<bool>,

     /// UNUSED
     /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
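
With `#[serde(default)]`, an omitted `logprobs` key deserializes to `None`; the handler (below) then lowers it to `false`. A small sketch of that round trip, assuming the `serde` derive and `serde_json` crates:

    use serde::Deserialize;

    #[derive(Deserialize)]
    struct ChatRequest {
        // A missing field becomes None rather than a deserialization error.
        #[serde(default)]
        logprobs: Option<bool>,
    }

    fn main() {
        let with: ChatRequest = serde_json::from_str(r#"{"logprobs": true}"#).unwrap();
        let without: ChatRequest = serde_json::from_str("{}").unwrap();
        assert_eq!(with.logprobs, Some(true));
        assert!(!without.logprobs.unwrap_or(false));
    }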

@@ -564,6 +564,7 @@ async fn chat_completions(
         .frequency_penalty
         // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
         .map(|x| x + 2.0);
+    let logprobs = req.logprobs.unwrap_or(false);

     // apply chat template to flatten the request into a single input
     let inputs = match infer.apply_chat_template(req) {
@@ -626,13 +627,16 @@ async fn chat_completions(
                 stream_token.token.text,
                 current_time,
                 stream_token.index,
-                None,
+                logprobs.then(|| stream_token.token.logprob),
                 stream_token.details.map(|d| d.finish_reason.to_string()),
             ))
-            .unwrap_or_else(|e| {
-                println!("Failed to serialize ChatCompletionChunk: {:?}", e);
-                Event::default()
-            })
+            .map_or_else(
+                |e| {
+                    println!("Failed to serialize ChatCompletionChunk: {:?}", e);
+                    Event::default()
+                },
+                |data| data,
+            )
         };
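
The `Event` construction switches from `unwrap_or_else` to `map_or_else`, whose first closure handles the error arm and whose second maps the success arm; with `|data| data` as the success closure the behavior is unchanged, and the serialization fallback still logs and emits a default event. A toy illustration of the two arms:

    fn describe(input: Result<i32, String>) -> String {
        // Error arm logs and substitutes a default; success arm transforms the value.
        input.map_or_else(
            |e| {
                println!("failed to serialize: {:?}", e);
                String::from("default")
            },
            |v| format!("value = {}", v),
        )
    }

    fn main() {
        assert_eq!(describe(Ok(7)), "value = 7");
        assert_eq!(describe(Err("boom".into())), "default");
    }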
let (headers, response_stream) = let (headers, response_stream) =
@@ -655,6 +659,7 @@ async fn chat_completions(
             system_fingerprint,
             current_time,
             generation.details.unwrap(),
+            logprobs,
         );

         // wrap generation inside a Vec to match api-inference