mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
feat: accept legacy request format and response

parent c9f4c1af31
commit cade8dbc2b
@@ -288,6 +288,47 @@ fn default_parameters() -> GenerateParameters {
     }
 }
 
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub struct CompletionRequest {
+    pub model: String,
+    pub prompt: String,
+    pub max_tokens: Option<u32>,
+
+    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+    /// make the output more random, while lower values like 0.2 will make it more
+    /// focused and deterministic.
+    ///
+    /// We generally recommend altering this or top_p but not both.
+    #[serde(default)]
+    #[schema(nullable = true, example = 1.0)]
+    pub temperature: Option<f32>,
+
+    pub top_p: Option<f32>,
+    pub stream: Option<bool>,
+    pub seed: Option<u64>,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
+pub(crate) struct Completion {
+    pub id: String,
+    pub object: String,
+    #[schema(example = "1706270835")]
+    pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
+    pub model: String,
+    pub system_fingerprint: String,
+    pub choices: Vec<CompletionComplete>,
+    pub usage: Usage,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct CompletionComplete {
+    pub index: u32,
+    pub text: String,
+    pub logprobs: Option<Vec<f32>>,
+    pub finish_reason: String,
+}
+
 #[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletion {
     pub id: String,
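As a quick sanity check, a legacy-style payload deserializes straight into the new CompletionRequest. A minimal standalone sketch, assuming serde (with derive) and serde_json are available; the struct here is a trimmed mirror of the one added above (utoipa/schema attributes dropped), and the model/prompt values are placeholders:

use serde::Deserialize;

// Trimmed mirror of the CompletionRequest added in the hunk above, so this
// compiles on its own with only serde + serde_json.
#[derive(Deserialize, Debug)]
struct CompletionRequest {
    model: String,
    prompt: String,
    max_tokens: Option<u32>,
    #[serde(default)]
    temperature: Option<f32>,
    top_p: Option<f32>,
    stream: Option<bool>,
    seed: Option<u64>,
}

fn main() -> Result<(), serde_json::Error> {
    let raw = r#"{
        "model": "mistralai/Mistral-7B-Instruct-v0.2",
        "prompt": "What is Deep Learning?",
        "max_tokens": 32,
        "temperature": 0.7
    }"#;
    let req: CompletionRequest = serde_json::from_str(raw)?;
    // Option fields that are absent in the payload come back as None.
    assert_eq!(req.max_tokens, Some(32));
    assert!(req.stream.is_none() && req.top_p.is_none());
    println!("{req:?}");
    Ok(())
}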
@@ -4,8 +4,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
-    ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
-    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
+    ChatRequest, CompatGenerateRequest, Completion, CompletionRequest, Details, ErrorResponse,
+    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
     StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
 };
@@ -532,6 +532,89 @@ async fn generate_stream_internal(
     (headers, stream)
 }
 
+/// Generate tokens
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v1/completions",
+    request_body = CompletionRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = ChatCompletionChunk),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+            example = json ! ({"error": "Incomplete generation"})),
+    )
+)]
+#[instrument(
+    skip_all,
+    fields(
+        // parameters = ? req.parameters,
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
+)]
+async fn completions(
+    infer: Extension<Infer>,
+    compute_type: Extension<ComputeType>,
+    Extension(info): Extension<Info>,
+    Json(req): Json<CompletionRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("tgi_request_count");
+
+    let repetition_penalty = 1.0;
+    let max_new_tokens = req.max_tokens.or(Some(100));
+    let stream = req.stream.unwrap_or_default();
+    let seed = req.seed;
+
+    // build the request passing some parameters
+    let generate_request = GenerateRequest {
+        inputs: req.prompt.to_string(),
+        parameters: GenerateParameters {
+            best_of: None,
+            temperature: req.temperature,
+            repetition_penalty: Some(repetition_penalty),
+            top_k: None,
+            top_p: req.top_p,
+            typical_p: None,
+            do_sample: true,
+            max_new_tokens,
+            return_full_text: None,
+            stop: Vec::new(),
+            truncate: None,
+            watermark: false,
+            details: true,
+            decoder_input_details: !stream,
+            seed,
+            top_n_tokens: None,
+        },
+    };
+
+    // switch on stream
+    let response = if stream {
+        Ok(
+            generate_stream(infer, compute_type, Json(generate_request.into()))
+                .await
+                .into_response(),
+        )
+    } else {
+        let (headers, Json(generation)) =
+            generate(infer, compute_type, Json(generate_request.into())).await?;
+        // wrap generation inside a Vec to match api-inference
+        Ok((headers, Json(vec![generation])).into_response())
+    };
+
+    response
+}
+
 /// Generate tokens
 #[utoipa::path(
     post,
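From the client side, the non-streaming path can be exercised with a plain HTTP POST. A hedged sketch, assuming reqwest (with the json feature), tokio, and serde_json; the URL, port, and prompt are placeholders, not anything mandated by this change:

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Field names match CompletionRequest; `model` is accepted but not used
    // when this handler builds the internal GenerateRequest.
    let body = json!({
        "model": "tgi",
        "prompt": "What is Deep Learning?",
        "max_tokens": 32,
        "temperature": 0.7,
        "stream": false
    });

    let resp = reqwest::Client::new()
        .post("http://localhost:3000/v1/completions")
        .json(&body)
        .send()
        .await?;

    // With stream == false the handler wraps the generation in a Vec to
    // match api-inference, so the body is a JSON array.
    println!("{}", resp.text().await?);
    Ok(())
}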
@@ -1071,6 +1154,7 @@ pub async fn run(
         .route("/generate_stream", post(generate_stream))
         .route("/v1/chat/completions", post(chat_completions))
         .route("/vertex", post(vertex_compatibility))
+        .route("/v1/completions", post(completions))
         .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))