From ca955b810ebe1285f48941c79a46a7bf4f774255 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 27 Feb 2023 10:31:56 +0100 Subject: [PATCH] feat(router): add legacy route for api-inference support --- router/src/lib.rs | 26 +++++++++++++++++++++++--- router/src/server.rs | 25 ++++++++++++++++++++++--- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 1f23bfd3..b4bb6784 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -47,7 +47,7 @@ pub(crate) struct GenerateParameters { #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] pub max_new_tokens: u32, #[serde(default)] - #[schema(inline, max_items = 4, example = json!(["photographer"]))] + #[schema(inline, max_items = 4, example = json ! (["photographer"]))] pub stop: Vec<String>, #[serde(default)] #[schema(default = "true")] @@ -86,13 +86,33 @@ pub(crate) struct GenerateRequest { pub parameters: GenerateParameters, } +#[derive(Clone, Debug, Deserialize, ToSchema)] +pub(crate) struct LegacyGenerateRequest { + #[schema(example = "My name is Olivier and I")] + pub inputs: String, + #[serde(default = "default_parameters")] + pub parameters: GenerateParameters, + #[serde(default)] + #[allow(dead_code)] + pub stream: bool, +} + +impl From<LegacyGenerateRequest> for GenerateRequest { + fn from(req: LegacyGenerateRequest) -> Self { + Self { + inputs: req.inputs, + parameters: req.parameters, + } + } +} + #[derive(Debug, Serialize, ToSchema)] pub struct PrefillToken { #[schema(example = 0)] id: u32, #[schema(example = "test")] text: String, - #[schema(nullable = true, example = -0.34)] + #[schema(nullable = true, example = - 0.34)] logprob: f32, } @@ -102,7 +122,7 @@ pub struct Token { id: u32, #[schema(example = "test")] text: String, - #[schema(nullable = true, example = -0.34)] + #[schema(nullable = true, example = - 0.34)] logprob: f32, #[schema(example = "false")] special: bool, diff --git 
a/router/src/server.rs b/router/src/server.rs index de96e397..ee11a1ec 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -2,7 +2,7 @@ use crate::infer::{InferError, InferStreamResponse}; use crate::{ Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, - Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation, + Infer, LegacyGenerateRequest, PrefillToken, StreamDetails, StreamResponse, Token, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -25,6 +25,25 @@ use tracing::{info_span, instrument, Instrument}; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; +/// Compatibility route with api-inference and AzureML +#[instrument(skip(infer))] +async fn legacy_generate( + infer: Extension<Infer>, + req: Json<LegacyGenerateRequest>, +) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> { + // switch on stream + let req = req.0; + if req.stream { + Ok(generate_stream(infer, Json(req.into())) + .await + .into_response()) + } else { + let (headers, generation) = generate(infer, Json(req.into())).await?; + // wrap generation inside a Vec to match api-inference + Ok((headers, Json(vec![generation.0])).into_response()) + } +} + /// Health check method #[instrument(skip(infer))] async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { @@ -84,7 +103,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe async fn generate( infer: Extension<Infer>, req: Json<GenerateRequest>, -) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> { +) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> { let span = tracing::Span::current(); let start_time = Instant::now(); @@ -404,7 +423,7 @@ pub async fn run( // Create router let app = Router::new() .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) - .route("/", post(generate)) + .route("/", post(legacy_generate)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/", get(health))