From cade8dbc2be74d7912bd0decea94c7e749db64c3 Mon Sep 17 00:00:00 2001
From: drbh
Date: Fri, 2 Feb 2024 09:57:31 -0500
Subject: [PATCH] feat: accept legacy request format and response

---
 router/src/lib.rs    | 41 +++++++++++++++++++++
 router/src/server.rs | 88 +++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 127 insertions(+), 2 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index b7285e65..44eb6010 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -288,6 +288,47 @@ fn default_parameters() -> GenerateParameters {
     }
 }
 
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub struct CompletionRequest {
+    pub model: String,
+    pub prompt: String,
+    pub max_tokens: Option<u32>,
+
+    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+    /// make the output more random, while lower values like 0.2 will make it more
+    /// focused and deterministic.
+    ///
+    /// We generally recommend altering this or top_p but not both.
+    #[serde(default)]
+    #[schema(nullable = true, example = 1.0)]
+    pub temperature: Option<f32>,
+
+    pub top_p: Option<f32>,
+    pub stream: Option<bool>,
+    pub seed: Option<u64>,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
+pub(crate) struct Completion {
+    pub id: String,
+    pub object: String,
+    #[schema(example = "1706270835")]
+    pub created: u64,
+    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
+    pub model: String,
+    pub system_fingerprint: String,
+    pub choices: Vec<CompletionComplete>,
+    pub usage: Usage,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct CompletionComplete {
+    pub index: u32,
+    pub text: String,
+    pub logprobs: Option<Vec<f32>>,
+    pub finish_reason: String,
+}
+
 #[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletion {
     pub id: String,
diff --git a/router/src/server.rs b/router/src/server.rs
index 054ba5a2..89f9d740 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -4,8 +4,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
-    ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
-    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
+    ChatRequest, CompatGenerateRequest, Completion, CompletionRequest, Details, ErrorResponse,
+    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
     StreamResponse, Token, TokenizeResponse, Validation, VertexRequest, VertexResponse,
 };
@@ -532,6 +532,89 @@ async fn generate_stream_internal(
     (headers, stream)
 }
 
+/// Generate tokens
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/v1/completions",
+    request_body = CompletionRequest,
+    responses(
+    (status = 200, description = "Generated Text", body = ChatCompletionChunk),
+    (status = 424, description = "Generation Error", body = ErrorResponse,
+    example = json ! ({"error": "Request failed during generation"})),
+    (status = 429, description = "Model is overloaded", body = ErrorResponse,
+    example = json ! ({"error": "Model is overloaded"})),
+    (status = 422, description = "Input validation error", body = ErrorResponse,
+    example = json ! ({"error": "Input validation error"})),
+    (status = 500, description = "Incomplete generation", body = ErrorResponse,
+    example = json ! ({"error": "Incomplete generation"})),
+    )
+    )]
+#[instrument(
+    skip_all,
+    fields(
+        // parameters = ? req.parameters,
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token,
+        seed,
+    )
+    )]
+async fn completions(
+    infer: Extension<Infer>,
+    compute_type: Extension<ComputeType>,
+    Extension(info): Extension<Info>,
+    Json(req): Json<CompletionRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    metrics::increment_counter!("tgi_request_count");
+
+    let repetition_penalty = 1.0;
+    let max_new_tokens = req.max_tokens.or(Some(100));
+    let stream = req.stream.unwrap_or_default();
+    let seed = req.seed;
+
+    // build the request passing some parameters
+    let generate_request = GenerateRequest {
+        inputs: req.prompt.to_string(),
+        parameters: GenerateParameters {
+            best_of: None,
+            temperature: req.temperature,
+            repetition_penalty: Some(repetition_penalty),
+            top_k: None,
+            top_p: req.top_p,
+            typical_p: None,
+            do_sample: true,
+            max_new_tokens,
+            return_full_text: None,
+            stop: Vec::new(),
+            truncate: None,
+            watermark: false,
+            details: true,
+            decoder_input_details: !stream,
+            seed,
+            top_n_tokens: None,
+        },
+    };
+
+    // switch on stream
+    let response = if stream {
+        Ok(
+            generate_stream(infer, compute_type, Json(generate_request.into()))
+                .await
+                .into_response(),
+        )
+    } else {
+        let (headers, Json(generation)) =
+            generate(infer, compute_type, Json(generate_request.into())).await?;
+        // wrap generation inside a Vec to match api-inference
+        Ok((headers, Json(vec![generation])).into_response())
+    };
+
+    response
+}
+
 /// Generate tokens
 #[utoipa::path(
     post,
@@ -1071,6 +1154,7 @@ pub async fn run(
         .route("/generate_stream", post(generate_stream))
        .route("/v1/chat/completions", post(chat_completions))
         .route("/vertex", post(vertex_compatibility))
+        .route("/v1/completions", post(completions))
         .route("/tokenize", post(tokenize))
         .route("/health", get(health))
         .route("/ping", get(health))