mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

commit 0fb864ef44 (parent fa9aad3ec4)
feat: rebase latest changes
@@ -294,8 +294,8 @@ pub struct CompletionRequest {
     pub prompt: String,
     pub max_tokens: Option<u32>,
 
     /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
     /// make the output more random, while lower values like 0.2 will make it more
     /// focused and deterministic.
     ///
     /// We generally recommend altering this or top_p but not both.
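The doc comment above describes standard softmax temperature scaling: logits are divided by the temperature before sampling, so values below 1 sharpen the distribution toward the argmax and values above 1 flatten it. A minimal illustration of that formula (generic Rust, not TGI's actual sampler):

// Standard temperature scaling: softmax(logits / T).
// T < 1.0 sharpens the distribution; T > 1.0 flattens it.
fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|l| l / temperature).collect();
    // Subtract the max before exponentiating for numerical stability.
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let logits = [2.0, 1.0, 0.1];
    println!("T=0.2 {:?}", softmax_with_temperature(&logits, 0.2)); // near-argmax
    println!("T=0.8 {:?}", softmax_with_temperature(&logits, 0.8)); // more spread out
}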
@@ -307,6 +307,9 @@ pub struct CompletionRequest {
     pub stream: Option<bool>,
     pub seed: Option<u64>,
     pub suffix: Option<String>,
+
+    pub repetition_penalty: Option<f32>,
+    pub frequency_penalty: Option<f32>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
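Both new fields are plain Options, so requests from older clients that omit them still deserialize, with the penalties defaulting to None. A minimal sketch of that behavior, using a hypothetical trimmed-down stand-in for CompletionRequest (the real struct has more fields and schema attributes):

use serde::Deserialize;

// Illustrative subset of CompletionRequest; not the full TGI type.
#[derive(Deserialize, Debug)]
struct CompletionRequestSketch {
    prompt: String,
    max_tokens: Option<u32>,
    repetition_penalty: Option<f32>,
    frequency_penalty: Option<f32>,
}

fn main() {
    // A client that predates the new fields: both penalties come back as None.
    let old_style = r#"{"prompt": "What is Deep Learning?", "max_tokens": 32}"#;
    let req: CompletionRequestSketch = serde_json::from_str(old_style).unwrap();
    assert!(req.repetition_penalty.is_none() && req.frequency_penalty.is_none());

    // A client opting in to the new penalties.
    let new_style =
        r#"{"prompt": "What is Deep Learning?", "repetition_penalty": 1.2, "frequency_penalty": 0.5}"#;
    let req: CompletionRequestSketch = serde_json::from_str(new_style).unwrap();
    assert_eq!(req.repetition_penalty, Some(1.2));
}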
@@ -412,7 +415,7 @@ pub(crate) struct ChatCompletionTopLogprob {
     logprob: f32,
 }
 
-#[derive(Clone, Deserialize, Serialize)]
+#[derive(Clone, Deserialize, Serialize, Default)]
 pub(crate) struct Usage {
     pub prompt_tokens: u32,
     pub completion_tokens: u32,
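Since u32 implements Default (zero), adding Default to the derive list lets callers build an all-zeros Usage placeholder without spelling out each counter. A sketch with just the two fields visible in this hunk (the real struct may have more):

// Deriving Default on a struct of counters yields an all-zeros value,
// handy as a placeholder when token usage isn't being tracked.
#[derive(Clone, Debug, Default)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

fn main() {
    let usage = Usage::default();
    assert_eq!(usage.prompt_tokens, 0);
    assert_eq!(usage.completion_tokens, 0);
    println!("{usage:?}");
}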
@@ -453,7 +456,15 @@ impl ChatCompletion {
         }
     }
 }
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct CompletionCompleteChunk {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub choices: Vec<CompletionComplete>,
+    pub model: String,
+    pub system_fingerprint: String,
+}
 
 #[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChunk {
     pub id: String,
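CompletionCompleteChunk has the shape of an OpenAI-style streaming completion chunk. A hedged sketch of how such a struct serializes for server-sent events; ChoiceSketch stands in for CompletionComplete, whose definition is not part of this diff, and every value below is illustrative:

use serde::Serialize;

// Stand-in for CompletionComplete, which this diff doesn't show.
#[derive(Serialize)]
struct ChoiceSketch {
    index: u32,
    text: String,
}

#[derive(Serialize)]
struct CompletionCompleteChunk {
    id: String,
    object: String,
    created: u64,
    choices: Vec<ChoiceSketch>,
    model: String,
    system_fingerprint: String,
}

fn main() {
    let chunk = CompletionCompleteChunk {
        id: "dummy-id".to_string(),
        object: "text_completion".to_string(),
        created: 1_700_000_000,
        choices: vec![ChoiceSketch { index: 0, text: " deep".to_string() }],
        model: "dummy-model".to_string(),
        system_fingerprint: "dummy-fingerprint".to_string(),
    };
    // Each chunk is typically written as one SSE `data:` line.
    println!("data: {}", serde_json::to_string(&chunk).unwrap());
}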
@@ -4,10 +4,11 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
-    ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
-    CompletionRequest, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-    GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, PrefillToken,
-    SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation,
+    ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete,
+    CompletionCompleteChunk, CompletionRequest, Details, ErrorResponse, FinishReason,
+    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer,
+    Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token,
+    TokenizeResponse, Usage, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -570,7 +571,6 @@ async fn completions(
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     metrics::increment_counter!("tgi_request_count");
 
-    let repetition_penalty = 1.0;
     let max_new_tokens = req.max_tokens.or(Some(100));
     let stream = req.stream.unwrap_or_default();
     let seed = req.seed;
@@ -582,7 +582,8 @@ async fn completions(
         parameters: GenerateParameters {
             best_of: None,
             temperature: req.temperature,
-            repetition_penalty: Some(repetition_penalty),
+            repetition_penalty: req.repetition_penalty,
+            frequency_penalty: req.frequency_penalty,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
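With this hunk the handler forwards both penalties from the request rather than pinning repetition_penalty to the hardcoded 1.0 removed above. The two penalties are commonly defined differently: a repetition penalty rescales a repeated token's logit by a fixed factor (CTRL-style), while a frequency penalty subtracts an amount proportional to the token's occurrence count (OpenAI-style). A rough sketch of that distinction under those common definitions, not code from this repo:

use std::collections::HashMap;

// Sketch of the two penalties as commonly defined; TGI's actual
// implementation lives in its logits-processing code.
fn apply_penalties(
    logits: &mut [f32],
    counts: &HashMap<usize, u32>, // token id -> occurrences so far
    repetition_penalty: f32,
    frequency_penalty: f32,
) {
    for (&token, &count) in counts {
        let logit = &mut logits[token];
        // Repetition penalty: rescale the logit by a fixed factor.
        if *logit > 0.0 {
            *logit /= repetition_penalty;
        } else {
            *logit *= repetition_penalty;
        }
        // Frequency penalty: subtract proportionally to occurrence count.
        *logit -= frequency_penalty * count as f32;
    }
}

fn main() {
    let mut logits = vec![2.0, -1.0, 0.5];
    let counts = HashMap::from([(0usize, 3u32), (1, 1)]);
    apply_penalties(&mut logits, &counts, 1.2, 0.5);
    println!("{logits:?}"); // token 0 is penalized hardest: seen 3 times
}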
@@ -596,11 +597,10 @@ async fn completions(
             decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
+            grammar: None,
         },
     };
 
-
-
     if stream {
         let on_message_callback = move |stream_token: StreamResponse| {
             let event = Event::default();
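In the streaming branch, each generated token is turned into a server-sent event through a callback, as the last context lines show. A minimal sketch of that pattern with axum's Event builder (assuming axum with its default json feature enabled; the payload is a placeholder, not TGI's actual chunk):

use axum::response::sse::Event;

fn main() {
    // Each generated token becomes one SSE event; Event's builder emits
    // the `data: ...` wire format when the response body is written out.
    let on_message_callback = move |token_text: String| {
        Event::default()
            .json_data(serde_json::json!({ "text": token_text }))
            .unwrap()
    };
    let _event: Event = on_message_callback("hello".to_string());
}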