feat: rebase latest changes

This commit is contained in:
drbh 2024-02-19 17:40:46 +00:00
parent fa9aad3ec4
commit 0fb864ef44
2 changed files with 23 additions and 12 deletions

View File

@ -307,6 +307,9 @@ pub struct CompletionRequest {
pub stream: Option<bool>,
pub seed: Option<u64>,
pub suffix: Option<String>,
pub repetition_penalty: Option<f32>,
pub frequency_penalty: Option<f32>,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
@ -412,7 +415,7 @@ pub(crate) struct ChatCompletionTopLogprob {
logprob: f32,
}
#[derive(Clone, Deserialize, Serialize)]
#[derive(Clone, Deserialize, Serialize, Default)]
pub(crate) struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
@ -453,7 +456,15 @@ impl ChatCompletion {
}
}
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionCompleteChunk {
pub id: String,
pub object: String,
pub created: u64,
pub choices: Vec<CompletionComplete>,
pub model: String,
pub system_fingerprint: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk {
pub id: String,

View File

@ -4,10 +4,11 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, PrefillToken,
SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation,
ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete,
CompletionCompleteChunk, CompletionRequest, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer,
Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token,
TokenizeResponse, Usage, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@ -570,7 +571,6 @@ async fn completions(
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
metrics::increment_counter!("tgi_request_count");
let repetition_penalty = 1.0;
let max_new_tokens = req.max_tokens.or(Some(100));
let stream = req.stream.unwrap_or_default();
let seed = req.seed;
@ -582,7 +582,8 @@ async fn completions(
parameters: GenerateParameters {
best_of: None,
temperature: req.temperature,
repetition_penalty: Some(repetition_penalty),
repetition_penalty: req.repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None,
top_p: req.top_p,
typical_p: None,
@ -596,11 +597,10 @@ async fn completions(
decoder_input_details: !stream,
seed,
top_n_tokens: None,
grammar: None,
},
};
if stream {
let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default();