feat: rebase latest changes

This commit is contained in:
drbh 2024-02-19 17:40:46 +00:00
parent fa9aad3ec4
commit 0fb864ef44
2 changed files with 23 additions and 12 deletions

View File

@ -307,6 +307,9 @@ pub struct CompletionRequest {
pub stream: Option<bool>, pub stream: Option<bool>,
pub seed: Option<u64>, pub seed: Option<u64>,
pub suffix: Option<String>, pub suffix: Option<String>,
pub repetition_penalty: Option<f32>,
pub frequency_penalty: Option<f32>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
@ -412,7 +415,7 @@ pub(crate) struct ChatCompletionTopLogprob {
logprob: f32, logprob: f32,
} }
#[derive(Clone, Deserialize, Serialize)] #[derive(Clone, Deserialize, Serialize, Default)]
pub(crate) struct Usage { pub(crate) struct Usage {
pub prompt_tokens: u32, pub prompt_tokens: u32,
pub completion_tokens: u32, pub completion_tokens: u32,
@ -453,7 +456,15 @@ impl ChatCompletion {
} }
} }
} }
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionCompleteChunk {
pub id: String,
pub object: String,
pub created: u64,
pub choices: Vec<CompletionComplete>,
pub model: String,
pub system_fingerprint: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk { pub(crate) struct ChatCompletionChunk {
pub id: String, pub id: String,

View File

@ -4,10 +4,11 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta, BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete,
CompletionRequest, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, CompletionCompleteChunk, CompletionRequest, Details, ErrorResponse, FinishReason,
GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, PrefillToken, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer,
SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token,
TokenizeResponse, Usage, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
@ -570,7 +571,6 @@ async fn completions(
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
let repetition_penalty = 1.0;
let max_new_tokens = req.max_tokens.or(Some(100)); let max_new_tokens = req.max_tokens.or(Some(100));
let stream = req.stream.unwrap_or_default(); let stream = req.stream.unwrap_or_default();
let seed = req.seed; let seed = req.seed;
@ -582,7 +582,8 @@ async fn completions(
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature: req.temperature, temperature: req.temperature,
repetition_penalty: Some(repetition_penalty), repetition_penalty: req.repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None, top_k: None,
top_p: req.top_p, top_p: req.top_p,
typical_p: None, typical_p: None,
@ -596,11 +597,10 @@ async fn completions(
decoder_input_details: !stream, decoder_input_details: !stream,
seed, seed,
top_n_tokens: None, top_n_tokens: None,
grammar: None,
}, },
}; };
if stream { if stream {
let on_message_callback = move |stream_token: StreamResponse| { let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default(); let event = Event::default();