diff --git a/.redocly.lint-ignore.yaml b/.redocly.lint-ignore.yaml index 382c9ab6..13b80497 100644 --- a/.redocly.lint-ignore.yaml +++ b/.redocly.lint-ignore.yaml @@ -77,3 +77,4 @@ docs/openapi.json: - '#/paths/~1tokenize/post' - '#/paths/~1v1~1chat~1completions/post' - '#/paths/~1v1~1completions/post' + - '#/paths/~1v1~1models/get' diff --git a/docs/openapi.json b/docs/openapi.json index f8c62b33..791ef036 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -556,6 +556,37 @@ } } } + }, + "/v1/models": { + "get": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Get model info", + "operationId": "openai_get_model_info", + "responses": { + "200": { + "description": "Served model info", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ModelInfo" + } + } + } + }, + "404": { + "description": "Model not found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + } + } + } } }, "components": { @@ -1747,6 +1778,35 @@ } ] }, + "ModelInfo": { + "type": "object", + "required": [ + "id", + "object", + "created", + "owned_by" + ], + "properties": { + "created": { + "type": "integer", + "format": "int64", + "example": 1686935002, + "minimum": 0 + }, + "id": { + "type": "string", + "example": "gpt2" + }, + "object": { + "type": "string", + "example": "model" + }, + "owned_by": { + "type": "string", + "example": "openai" + } + } + }, "OutputMessage": { "oneOf": [ { diff --git a/router/src/lib.rs b/router/src/lib.rs index 979f6dd1..a1e1dadf 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1261,6 +1261,34 @@ pub(crate) struct ErrorResponse { pub error_type: String, } +#[derive(Serialize, Deserialize, ToSchema)] +pub(crate) struct ModelInfo { + #[schema(example = "gpt2")] + pub id: String, + #[schema(example = "model")] + pub object: String, + #[schema(example = 1686935002)] + pub created: u64, + #[schema(example = "openai")] + pub owned_by: String, +} + 
+#[derive(Serialize, Deserialize, ToSchema)] +pub(crate) struct ModelsInfo { +    #[schema(example = "list")] +    pub object: String, +    pub data: Vec<ModelInfo>, +} + +impl Default for ModelsInfo { +    fn default() -> Self { +        ModelsInfo { +            object: "list".to_string(), +            data: Vec::new(), +        } +    } +} + #[cfg(test)] mod tests { use super::*; diff --git a/router/src/server.rs b/router/src/server.rs index f273a786..d3d34215 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -24,6 +24,7 @@ use crate::{ VertexResponse, }; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -116,6 +117,29 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> { Json(info.0) } +#[utoipa::path( +get, +tag = "Text Generation Inference", +path = "/v1/models", +responses( +(status = 200, description = "Served model info", body = ModelInfo), +(status = 404, description = "Model not found", body = ErrorResponse), +) +)] +#[instrument(skip(info))] +/// Get model info +async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> { +    Json(ModelsInfo { +        data: vec![ModelInfo { +            id: info.0.model_id.clone(), +            object: "model".to_string(), +            created: 0, // TODO: determine how to get this +            owned_by: info.0.model_id.clone(), +        }], +        ..Default::default() +    }) +} + #[utoipa::path( post, tag = "Text Generation Inference", @@ -1505,6 +1529,7 @@ chat_completions, completions, tokenize, metrics, +openai_get_model_info, ), components( schemas( @@ -1557,6 +1582,7 @@ ToolCall, Function, FunctionDefinition, ToolChoice, +ModelInfo, ) ), tags( @@ -2250,7 +2276,8 @@ async fn start( .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) - .route("/metrics", get(metrics)); + .route("/metrics", get(metrics)) + .route("/v1/models", get(openai_get_model_info)); // Conditional AWS 
Sagemaker route let aws_sagemaker_route = if messages_api_enabled {