From 9a79c2f86724d2e13847d9753e3ca844c35f9e2f Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 9 Jan 2024 14:04:31 -0500
Subject: [PATCH] feat: support logprobs in streaming and non streaming chat

---
 router/src/lib.rs    | 28 ++++++++++++++--------------
 router/src/server.rs | 15 ++++++++++-----
 2 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index b07b9789..d5394f61 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -213,7 +213,7 @@ pub(crate) struct ChatCompletionComplete {
     pub index: u32,
     pub message: Message,
     pub logprobs: Option<Vec<f32>>,
-    pub finish_reason: Option<String>,
+    pub finish_reason: String,
 }

 #[derive(Clone, Deserialize, Serialize)]
@@ -227,24 +227,26 @@ impl ChatCompletion {
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
-        ouput: String,
+        output: String,
         created: u64,
         details: Details,
+        return_logprobs: bool,
     ) -> Self {
         Self {
-            id: "".to_string(),
-            object: "text_completion".to_string(),
+            id: String::new(),
+            object: "text_completion".into(),
             created,
             model,
             system_fingerprint,
             choices: vec![ChatCompletionComplete {
                 index: 0,
                 message: Message {
-                    role: "assistant".to_string(),
-                    content: ouput,
+                    role: "assistant".into(),
+                    content: output,
                 },
-                logprobs: None,
-                finish_reason: details.finish_reason.to_string().into(),
+                logprobs: return_logprobs
+                    .then(|| details.tokens.iter().map(|t| t.logprob).collect()),
+                finish_reason: details.finish_reason.to_string(),
             }],
             usage: Usage {
                 prompt_tokens: details.prompt_token_count,
@@ -269,7 +271,7 @@ pub(crate) struct ChatCompletionChunk {
 pub(crate) struct ChatCompletionChoice {
     pub index: u32,
     pub delta: ChatCompletionDelta,
-    pub logprobs: Option<Vec<f32>>,
+    pub logprobs: Option<f32>,
     pub finish_reason: Option<String>,
 }

@@ -286,7 +288,7 @@ impl ChatCompletionChunk {
         delta: String,
         created: u64,
         index: u32,
-        logprobs: Option<Vec<f32>>,
+        logprobs: Option<f32>,
         finish_reason: Option<String>,
     ) -> Self {
         Self {
@@ -340,12 +342,10 @@ pub(crate) struct ChatRequest {
     #[serde(default)]
     pub logit_bias: Option<Vec<f32>>,

-    /// UNUSED
     /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
-    /// output token returned in the content of message. This option is currently not available on the gpt-4-vision-preview
-    /// model.
+    /// output token returned in the content of message.
     #[serde(default)]
-    pub logprobs: Option,
+    pub logprobs: Option<bool>,

     /// UNUSED
     /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
diff --git a/router/src/server.rs b/router/src/server.rs
index 536be9d3..04b8121c 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -564,6 +564,7 @@ async fn chat_completions(
         .frequency_penalty
         // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
         .map(|x| x + 2.0);
+    let logprobs = req.logprobs.unwrap_or(false);

     // apply chat template to flatten the request into a single input
     let inputs = match infer.apply_chat_template(req) {
@@ -626,13 +627,16 @@ async fn chat_completions(
                 stream_token.token.text,
                 current_time,
                 stream_token.index,
-                None,
+                logprobs.then(|| stream_token.token.logprob),
                 stream_token.details.map(|d| d.finish_reason.to_string()),
             ))
-            .unwrap_or_else(|e| {
-                println!("Failed to serialize ChatCompletionChunk: {:?}", e);
-                Event::default()
-            })
+            .map_or_else(
+                |e| {
+                    println!("Failed to serialize ChatCompletionChunk: {:?}", e);
+                    Event::default()
+                },
+                |data| data,
+            )
         };

         let (headers, response_stream) =
@@ -655,6 +659,7 @@ async fn chat_completions(
             system_fingerprint,
             current_time,
             generation.details.unwrap(),
+            logprobs,
         );

         // wrap generation inside a Vec to match api-inference
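
Note: the snippet below is an illustrative, standalone sketch, not part of the patch. It shows the `bool::then` pattern the patch relies on: the request-level `logprobs` flag producing an `Option<Vec<f32>>` of per-token log probabilities (serialized as null when the client did not ask for them). The `Token` struct and `collect_logprobs` helper are hypothetical stand-ins, not types from the router.

// Standalone sketch of the `return_logprobs.then(...)` logic used in ChatCompletion::new.
// `Token` and `collect_logprobs` are illustrative names, not router types.
struct Token {
    logprob: f32,
}

fn collect_logprobs(tokens: &[Token], return_logprobs: bool) -> Option<Vec<f32>> {
    // `bool::then` yields Some(closure result) when the flag is true, None otherwise.
    return_logprobs.then(|| tokens.iter().map(|t| t.logprob).collect())
}

fn main() {
    let tokens = vec![Token { logprob: -0.25 }, Token { logprob: -1.5 }];
    assert_eq!(collect_logprobs(&tokens, true), Some(vec![-0.25, -1.5]));
    assert_eq!(collect_logprobs(&tokens, false), None);
}

In the streaming path the same flag gates a single value per chunk (`logprobs.then(|| stream_token.token.logprob)`), which is why `ChatCompletionChoice` narrows its `logprobs` field from `Option<Vec<f32>>` to `Option<f32>`.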