feat: accept legacy request format and response

drbh 2024-02-02 09:57:31 -05:00
parent c9f4c1af31
commit cade8dbc2b
2 changed files with 127 additions and 2 deletions
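
For orientation, here is a minimal sketch (not part of the commit) of the legacy OpenAI-style payload the new /v1/completions route is meant to accept, built with serde_json so the keys line up with the CompletionRequest struct added below; the model id, prompt, and sampling values are purely illustrative.

// Illustrative only: a legacy-format request body for the new /v1/completions
// route. Field names mirror the CompletionRequest struct added in this commit;
// the concrete values are made-up examples.
use serde_json::json;

fn main() {
    let legacy_request = json!({
        "model": "mistralai/Mistral-7B-Instruct-v0.2",
        "prompt": "What is Deep Learning?",
        "max_tokens": 100,
        "temperature": 0.7,
        "top_p": 0.9,
        "stream": false,
        "seed": 42
    });
    println!("{}", serde_json::to_string_pretty(&legacy_request).unwrap());
}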


@@ -288,6 +288,47 @@ fn default_parameters() -> GenerateParameters {
    }
}

#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: Option<u32>,
    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
    /// make the output more random, while lower values like 0.2 will make it more
    /// focused and deterministic.
    ///
    /// We generally recommend altering this or top_p but not both.
    #[serde(default)]
    #[schema(nullable = true, example = 1.0)]
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub stream: Option<bool>,
    pub seed: Option<u64>,
}

#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct Completion {
    pub id: String,
    pub object: String,
    #[schema(example = "1706270835")]
    pub created: u64,
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    pub model: String,
    pub system_fingerprint: String,
    pub choices: Vec<CompletionComplete>,
    pub usage: Usage,
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionComplete {
    pub index: u32,
    pub text: String,
    pub logprobs: Option<Vec<f32>>,
    pub finish_reason: String,
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion {
    pub id: String,

@@ -4,8 +4,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
    BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
    ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
    ChatRequest, CompatGenerateRequest, Completion, CompletionRequest, Details, ErrorResponse,
    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
    StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
};
@@ -532,6 +532,89 @@ async fn generate_stream_internal(
    (headers, stream)
}

/// Generate tokens
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/v1/completions",
    request_body = CompletionRequest,
    responses(
        (status = 200, description = "Generated Text", body = ChatCompletionChunk),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip_all,
    fields(
        // parameters = ? req.parameters,
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn completions(
    infer: Extension<Infer>,
    compute_type: Extension<ComputeType>,
    Extension(info): Extension<Info>,
    Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    metrics::increment_counter!("tgi_request_count");

    let repetition_penalty = 1.0;
    let max_new_tokens = req.max_tokens.or(Some(100));
    let stream = req.stream.unwrap_or_default();
    let seed = req.seed;

    // build the request passing some parameters
    let generate_request = GenerateRequest {
        inputs: req.prompt.to_string(),
        parameters: GenerateParameters {
            best_of: None,
            temperature: req.temperature,
            repetition_penalty: Some(repetition_penalty),
            top_k: None,
            top_p: req.top_p,
            typical_p: None,
            do_sample: true,
            max_new_tokens,
            return_full_text: None,
            stop: Vec::new(),
            truncate: None,
            watermark: false,
            details: true,
            decoder_input_details: !stream,
            seed,
            top_n_tokens: None,
        },
    };

    // switch on stream
    let response = if stream {
        Ok(
            generate_stream(infer, compute_type, Json(generate_request.into()))
                .await
                .into_response(),
        )
    } else {
        let (headers, Json(generation)) =
            generate(infer, compute_type, Json(generate_request.into())).await?;
        // wrap generation inside a Vec to match api-inference
        Ok((headers, Json(vec![generation])).into_response())
    };

    response
}
/// Generate tokens
#[utoipa::path(
    post,
@@ -1071,6 +1154,7 @@ pub async fn run(
.route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat_completions))
.route("/vertex", post(vertex_compatibility))
.route("/v1/completions", post(completions))
.route("/tokenize", post(tokenize))
.route("/health", get(health))
.route("/ping", get(health))