diff --git a/router/src/lib.rs b/router/src/lib.rs index 7d7461fd..f648f597 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -224,13 +224,19 @@ pub(crate) struct Usage { } impl ChatCompletion { - pub(crate) fn new(ouput: String, created: u64, details: Details) -> Self { + pub(crate) fn new( + model: String, + system_fingerprint: String, + ouput: String, + created: u64, + details: Details, + ) -> Self { Self { id: "".to_string(), object: "text_completion".to_string(), created, - model: "".to_string(), - system_fingerprint: "".to_string(), + model, + system_fingerprint, choices: vec![ChatCompletionComplete { index: 0, message: Message { diff --git a/router/src/server.rs b/router/src/server.rs index c44aeefe..5e2574d4 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -558,12 +558,12 @@ async fn chat_completions( ) -> Result)> { metrics::increment_counter!("tgi_request_count"); - // extract the values we need for the chat request let stream = req.stream; - let max_new_tokens = match req.max_tokens { - Some(max_new_tokens) => Some(max_new_tokens), - None => Some(100), - }; + let max_new_tokens = req.max_tokens.or(Some(100)); + let repetition_penalty = req + .frequency_penalty + // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0) + .map(|x| x + 2.0); // apply chat template to flatten the request into a single input let inputs = match infer.apply_chat_template(req) { @@ -587,11 +587,11 @@ async fn chat_completions( parameters: GenerateParameters { best_of: None, temperature: None, - repetition_penalty: None, + repetition_penalty, top_k: None, top_p: None, typical_p: None, - do_sample: false, + do_sample: true, max_new_tokens, return_full_text: None, stop: Vec::new(), @@ -604,11 +604,12 @@ async fn chat_completions( }, }; + // static values that will be returned in all cases + let model_id = info.model_id.clone(); + let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); + // switch on stream if stream { - let model_id = info.model_id.clone(); - let system_fingerprint = - format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); // pass this callback to the stream generation and build the required event structure let on_message_callback = move |stream_token: StreamResponse| { let event = Event::default(); @@ -650,6 +651,8 @@ async fn chat_completions( // build the complete response object with the full text let response = ChatCompletion::new( generation.generated_text, + model_id, + system_fingerprint, current_time, generation.details.unwrap(), );