From 8398d4f436cffe142586484ab0623ead063e8803 Mon Sep 17 00:00:00 2001
From: drbh
Date: Mon, 19 Aug 2024 16:00:48 +0000
Subject: [PATCH] feat: add /v1/models endpoint

---
 router/src/lib.rs    | 28 ++++++++++++++++++++++++++++
 router/src/server.rs | 25 ++++++++++++++++++++++---
 2 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/router/src/lib.rs b/router/src/lib.rs
index ce4f7c46..d874c38b 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1240,6 +1240,34 @@ pub(crate) struct ErrorResponse {
     pub error_type: String,
 }
 
+#[derive(Serialize, Deserialize, ToSchema)]
+pub(crate) struct ModelInfo {
+    #[schema(example = "gpt2")]
+    pub id: String,
+    #[schema(example = "model")]
+    pub object: String,
+    #[schema(example = 1686935002)]
+    pub created: u64,
+    #[schema(example = "openai")]
+    pub owned_by: String,
+}
+
+#[derive(Serialize, Deserialize, ToSchema)]
+pub(crate) struct ModelsInfo {
+    #[schema(example = "list")]
+    pub object: String,
+    pub data: Vec<ModelInfo>,
+}
+
+impl Default for ModelsInfo {
+    fn default() -> Self {
+        ModelsInfo {
+            object: "list".to_string(),
+            data: Vec::new(),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/router/src/server.rs b/router/src/server.rs
index 8ebd1a33..88a2e9ce 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -23,7 +23,8 @@ use crate::{
     CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
     VertexRequest, VertexResponse,
 };
-use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
+use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
+use crate::{ModelInfo, ModelsInfo};
 use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
 use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@@ -116,6 +117,25 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     Json(info.0)
 }
 
+#[utoipa::path(
+get,
+tag = "Text Generation Inference",
+path = "/v1/models",
+responses((status = 200, description = "Served model info", body = ModelsInfo))
+)]
+#[instrument]
+async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
+    Json(ModelsInfo {
+        data: vec![ModelInfo {
+            id: info.0.model_id.clone(),
+            object: "model".to_string(),
+            created: 0, // TODO: determine how to get this
+            owned_by: info.0.model_id.clone(),
+        }],
+        ..Default::default()
+    })
+}
+
 #[utoipa::path(
 post,
 tag = "Text Generation Inference",
@@ -2246,7 +2266,8 @@ async fn start(
         .route("/info", get(get_model_info))
         .route("/health", get(health))
         .route("/ping", get(health))
-        .route("/metrics", get(metrics));
+        .route("/metrics", get(metrics))
+        .route("/v1/models", get(openai_get_model_info));
 
     // Conditional AWS Sagemaker route
     let aws_sagemaker_route = if messages_api_enabled {