From 70483428eeffcd2d90b6b0075fd3a2a762bfdcfb Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 23 Oct 2024 11:59:41 +0200 Subject: [PATCH] update openAPI --- docs/openapi.json | 126 ++++++++++++++++++++++++++++++++++++++++ router/src/sagemaker.rs | 35 ++++++++--- router/src/server.rs | 7 ++- 3 files changed, 159 insertions(+), 9 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index d1b60f4d..5f1946b3 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -316,6 +316,93 @@ } } }, + "/invocations": { + "post": { + "tags": [ + "Text Generation Inference" + ], + "operationId": "sagemaker_compatibility", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SagemakerRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Generated Chat Completion", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SagemakerResponse" + } + }, + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/SagemakerStreamResponse" + } + } + } + }, + "422": { + "description": "Input validation error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Input validation error" + } + } + } + }, + "424": { + "description": "Generation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Request failed during generation" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation" + } + } + } + } + } + } + }, "/metrics": { "get": { "tags": [ @@ -1865,6 +1952,45 @@ "type": "string" } }, + "SagemakerRequest": { + "oneOf": [ + { + "$ref": "#/components/schemas/CompatGenerateRequest" + }, + { + "$ref": "#/components/schemas/ChatRequest" + }, + { + "$ref": "#/components/schemas/CompletionRequest" + } + ] + }, + "SagemakerResponse": { + "oneOf": [ + { + "$ref": "#/components/schemas/GenerateResponse" + }, + { + "$ref": "#/components/schemas/ChatCompletion" + }, + { + "$ref": "#/components/schemas/CompletionFinal" + } + ] + }, + "SagemakerStreamResponse": { + "oneOf": [ + { + "$ref": "#/components/schemas/StreamResponse" + }, + { + "$ref": "#/components/schemas/ChatCompletionChunk" + }, + { + "$ref": "#/components/schemas/Chunk" + } + ] + }, "SimpleToken": { "type": "object", "required": [ diff --git a/router/src/sagemaker.rs b/router/src/sagemaker.rs index efbecb69..1ba8cabe 100644 --- a/router/src/sagemaker.rs +++ b/router/src/sagemaker.rs @@ -1,11 +1,14 @@ use crate::infer::Infer; use crate::server::{chat_completions, compat_generate, completions, ComputeType}; -use crate::{ChatRequest, CompatGenerateRequest, CompletionRequest, ErrorResponse, Info}; +use crate::{ + ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest, + CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse, +}; use axum::extract::Extension; use axum::http::StatusCode; use axum::response::Response; use axum::Json; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use tracing::instrument; use utoipa::ToSchema; @@ -17,6 +20,26 @@ pub(crate) enum SagemakerRequest { Completion(CompletionRequest), } +/// Used for OpenAPI specs +#[allow(dead_code)] +#[derive(Serialize, ToSchema)] +#[serde(untagged)] +pub(crate) enum SagemakerResponse { + Generate(GenerateResponse), + Chat(ChatCompletion), + Completion(CompletionFinal), +} + +/// Used for OpenAPI specs +#[allow(dead_code)] +#[derive(Serialize, ToSchema)] +#[serde(untagged)] +pub(crate) enum SagemakerStreamResponse { + Generate(StreamResponse), + Chat(ChatCompletionChunk), + Completion(Chunk), +} + // Generate tokens from Sagemaker request #[utoipa::path( post, @@ -26,12 +49,8 @@ request_body = SagemakerRequest, responses( (status = 200, description = "Generated Chat Completion", content( -("application/json" = GenerateResponse), -("application/json" = ChatCompletion), -("application/json" = CompletionFinal), -("text/event-stream" = StreamResponse), -("text/event-stream" = ChatCompletionChunk), -("text/event-stream" = Chunk), +("application/json" = SagemakerResponse), +("text/event-stream" = SagemakerStreamResponse), )), (status = 424, description = "Generation Error", body = ErrorResponse, example = json ! ({"error": "Request failed during generation"})), diff --git a/router/src/server.rs b/router/src/server.rs index 2561dfd9..5abca058 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -7,7 +7,10 @@ use crate::kserve::{ kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kserve_model_metadata, kserve_model_metadata_ready, }; -use crate::sagemaker::{sagemaker_compatibility, SagemakerRequest, __path_sagemaker_compatibility}; +use crate::sagemaker::{ + sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse, + __path_sagemaker_compatibility, +}; use crate::validation::ValidationError; use crate::vertex::vertex_compatibility; use crate::ChatTokenizeResponse; @@ -1543,6 +1546,8 @@ ChatCompletionTopLogprob, ChatCompletion, CompletionRequest, CompletionComplete, +SagemakerResponse, +SagemakerStreamResponse, Chunk, Completion, CompletionFinal,