Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: accept legacy request format and response
parent c9f4c1af31
commit cade8dbc2b
@@ -288,6 +288,47 @@ fn default_parameters() -> GenerateParameters {
    }
}

#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: Option<u32>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
    /// make the output more random, while lower values like 0.2 will make it more
    /// focused and deterministic.
    ///
    /// We generally recommend altering this or top_p but not both.
    #[serde(default)]
    #[schema(nullable = true, example = 1.0)]
    pub temperature: Option<f32>,

    pub top_p: Option<f32>,
    pub stream: Option<bool>,
    pub seed: Option<u64>,
}

#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct Completion {
    pub id: String,
    pub object: String,
    #[schema(example = "1706270835")]
    pub created: u64,
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    pub model: String,
    pub system_fingerprint: String,
    pub choices: Vec<CompletionComplete>,
    pub usage: Usage,
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionComplete {
    pub index: u32,
    pub text: String,
    pub logprobs: Option<Vec<f32>>,
    pub finish_reason: String,
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion {
    pub id: String,
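For reference, a standalone sketch (not part of this commit) of how a legacy /v1/completions request body deserializes into this shape, assuming serde and serde_json as dependencies; LegacyCompletionRequest below is a trimmed, hypothetical copy of CompletionRequest used only for illustration:

// Trimmed, illustrative copy of the CompletionRequest fields shown above.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct LegacyCompletionRequest {
    model: String,
    prompt: String,
    max_tokens: Option<u32>,
    #[serde(default)]
    temperature: Option<f32>,
    top_p: Option<f32>,
    stream: Option<bool>,
    seed: Option<u64>,
}

fn main() -> Result<(), serde_json::Error> {
    // A typical legacy payload: only `model` and `prompt` are required;
    // every sampling field is optional and deserializes to None when absent.
    let body = r#"{
        "model": "mistralai/Mistral-7B-Instruct-v0.2",
        "prompt": "What is Deep Learning?",
        "max_tokens": 32,
        "temperature": 0.7
    }"#;
    let req: LegacyCompletionRequest = serde_json::from_str(body)?;
    assert_eq!(req.max_tokens, Some(32));
    assert!(req.stream.is_none() && req.seed.is_none());
    println!("{req:?}");
    Ok(())
}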
@@ -4,8 +4,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
    BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
    ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
    ChatRequest, CompatGenerateRequest, Completion, CompletionRequest, Details, ErrorResponse,
    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
};
@@ -532,6 +532,89 @@ async fn generate_stream_internal(
    (headers, stream)
}

/// Generate tokens
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/v1/completions",
    request_body = CompletionRequest,
    responses(
        (status = 200, description = "Generated Text", body = ChatCompletionChunk),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip_all,
    fields(
        // parameters = ? req.parameters,
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn completions(
    infer: Extension<Infer>,
    compute_type: Extension<ComputeType>,
    Extension(info): Extension<Info>,
    Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    metrics::increment_counter!("tgi_request_count");

    let repetition_penalty = 1.0;
    let max_new_tokens = req.max_tokens.or(Some(100));
    let stream = req.stream.unwrap_or_default();
    let seed = req.seed;

    // build the request passing some parameters
    let generate_request = GenerateRequest {
        inputs: req.prompt.to_string(),
        parameters: GenerateParameters {
            best_of: None,
            temperature: req.temperature,
            repetition_penalty: Some(repetition_penalty),
            top_k: None,
            top_p: req.top_p,
            typical_p: None,
            do_sample: true,
            max_new_tokens,
            return_full_text: None,
            stop: Vec::new(),
            truncate: None,
            watermark: false,
            details: true,
            decoder_input_details: !stream,
            seed,
            top_n_tokens: None,
        },
    };

    // switch on stream
    let response = if stream {
        Ok(
            generate_stream(infer, compute_type, Json(generate_request.into()))
                .await
                .into_response(),
        )
    } else {
        let (headers, Json(generation)) =
            generate(infer, compute_type, Json(generate_request.into())).await?;
        // wrap generation inside a Vec to match api-inference
        Ok((headers, Json(vec![generation])).into_response())
    };

    response
}

/// Generate tokens
#[utoipa::path(
    post,
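The handler only honours prompt, temperature, top_p, max_tokens, stream and seed from the legacy request and pins the remaining GenerateParameters fields. A standalone sketch of the two defaulting rules used above (legacy_defaults is a hypothetical helper for illustration, not TGI code):

// Standalone sketch of the defaulting applied in `completions` above:
// `max_tokens` falls back to 100 new tokens, `stream` to false.
fn legacy_defaults(max_tokens: Option<u32>, stream: Option<bool>) -> (Option<u32>, bool) {
    (max_tokens.or(Some(100)), stream.unwrap_or_default())
}

fn main() {
    assert_eq!(legacy_defaults(None, None), (Some(100), false));
    assert_eq!(legacy_defaults(Some(16), Some(true)), (Some(16), true));
}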
@@ -1071,6 +1154,7 @@ pub async fn run(
        .route("/generate_stream", post(generate_stream))
        .route("/v1/chat/completions", post(chat_completions))
        .route("/vertex", post(vertex_compatibility))
        .route("/v1/completions", post(completions))
        .route("/tokenize", post(tokenize))
        .route("/health", get(health))
        .route("/ping", get(health))
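For illustration, a minimal sketch of registering a POST route like /v1/completions on an axum Router, assuming axum 0.6-style serving and tokio; completions_stub stands in for the real handler and none of this is the TGI router setup itself:

use axum::{routing::post, Router};

// Stand-in for the real `completions` handler registered above.
async fn completions_stub() -> &'static str {
    "ok"
}

#[tokio::main]
async fn main() {
    let app = Router::new()
        // New legacy endpoint, registered next to the existing routes.
        .route("/v1/completions", post(completions_stub));

    // axum 0.6-style serving; adjust for newer axum versions.
    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}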