From 0fb864ef4427c32ceebe5bd1aa7aa84300eb6085 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 19 Feb 2024 17:40:46 +0000 Subject: [PATCH] feat: rebase latest changes --- router/src/lib.rs | 19 +++++++++++++++---- router/src/server.rs | 16 ++++++++-------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 494212c5..ae6e7be5 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -294,8 +294,8 @@ pub struct CompletionRequest { pub prompt: String, pub max_tokens: Option, - /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will - /// make the output more random, while lower values like 0.2 will make it more + /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + /// make the output more random, while lower values like 0.2 will make it more /// focused and deterministic. /// /// We generally recommend altering this or top_p but not both. @@ -307,6 +307,9 @@ pub struct CompletionRequest { pub stream: Option, pub seed: Option, pub suffix: Option, + + pub repetition_penalty: Option, + pub frequency_penalty: Option, } #[derive(Clone, Deserialize, Serialize, ToSchema, Default)] @@ -412,7 +415,7 @@ pub(crate) struct ChatCompletionTopLogprob { logprob: f32, } -#[derive(Clone, Deserialize, Serialize)] +#[derive(Clone, Deserialize, Serialize, Default)] pub(crate) struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, @@ -453,7 +456,15 @@ impl ChatCompletion { } } } - +#[derive(Clone, Deserialize, Serialize, ToSchema)] +pub(crate) struct CompletionCompleteChunk { + pub id: String, + pub object: String, + pub created: u64, + pub choices: Vec, + pub model: String, + pub system_fingerprint: String, +} #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, diff --git a/router/src/server.rs b/router/src/server.rs index f7ab4160..e9884781 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -4,10 +4,11 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::validation::ValidationError; use crate::{ BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta, - ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, - CompletionRequest, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, PrefillToken, - SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation, + ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, + CompletionCompleteChunk, CompletionRequest, Details, ErrorResponse, FinishReason, + GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, + Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, + TokenizeResponse, Usage, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -570,7 +571,6 @@ async fn completions( ) -> Result)> { metrics::increment_counter!("tgi_request_count"); - let repetition_penalty = 1.0; let max_new_tokens = req.max_tokens.or(Some(100)); let stream = req.stream.unwrap_or_default(); let seed = req.seed; @@ -582,7 +582,8 @@ async fn completions( parameters: GenerateParameters { best_of: None, temperature: req.temperature, - repetition_penalty: Some(repetition_penalty), + repetition_penalty: req.repetition_penalty, + frequency_penalty: req.frequency_penalty, top_k: None, top_p: req.top_p, typical_p: None, @@ -596,11 +597,10 @@ async fn completions( decoder_input_details: !stream, seed, top_n_tokens: None, + grammar: None, }, }; - - if stream { let on_message_callback = move |stream_token: StreamResponse| { let event = Event::default();