From ca8ad2dbee891278c0b705ae34a1ac59d0a3601e Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 31 Jul 2024 15:34:48 +0000 Subject: [PATCH] feat: prefer stop over eos_token to align with openai finish_reason --- router/src/lib.rs | 11 ++++++++++- router/src/server.rs | 4 ++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 14bb8270..1e6feeb7 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -619,7 +619,7 @@ impl ChatCompletion { message, logprobs: return_logprobs .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), - finish_reason: details.finish_reason.to_string(), + finish_reason: details.finish_reason.format(true), }], usage: Usage { prompt_tokens: details.prefill.len() as u32, @@ -1117,6 +1117,15 @@ impl std::fmt::Display for FinishReason { } } +impl FinishReason { + pub fn format(&self, use_stop: bool) -> String { + match self { + FinishReason::EndOfSequenceToken if use_stop => "stop".to_string(), + _ => self.to_string(), + } + } +} + #[derive(Serialize, ToSchema)] pub(crate) struct BestOfSequence { #[schema(example = "test")] diff --git a/router/src/server.rs b/router/src/server.rs index dcbaa2ad..c30f00c8 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -919,7 +919,7 @@ async fn completions( total_tokens += details.prefill.len() as u32 + details.generated_tokens; Ok(CompletionComplete { - finish_reason: details.finish_reason.to_string(), + finish_reason: details.finish_reason.format(true), index: index as u32, logprobs: None, text: generation.generated_text, @@ -1159,7 +1159,7 @@ async fn chat_completions( tool_calls, current_time, logprobs, - stream_token.details.map(|d| d.finish_reason.to_string()), + stream_token.details.map(|d| d.finish_reason.format(true)), ), )) .unwrap_or_else(|e| {