feat: add /v1/models endpoint

This commit is contained in:
drbh 2024-08-19 16:00:48 +00:00
parent cfa73b5c99
commit 8398d4f436
2 changed files with 52 additions and 3 deletions

View File

@ -1240,6 +1240,34 @@ pub(crate) struct ErrorResponse {
pub error_type: String, pub error_type: String,
} }
#[derive(Serialize, Deserialize, ToSchema)]
pub(crate) struct ModelInfo {
#[schema(example = "gpt2")]
pub id: String,
#[schema(example = "model")]
pub object: String,
#[schema(example = 1686935002)]
pub created: u64,
#[schema(example = "openai")]
pub owned_by: String,
}
#[derive(Serialize, Deserialize, ToSchema)]
pub(crate) struct ModelsInfo {
#[schema(example = "list")]
pub object: String,
pub data: Vec<ModelInfo>,
}
impl Default for ModelsInfo {
fn default() -> Self {
ModelsInfo {
object: "list".to_string(),
data: Vec::new(),
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -23,7 +23,8 @@ use crate::{
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse, VertexResponse,
}; };
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@ -116,6 +117,25 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
Json(info.0) Json(info.0)
} }
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v1/models",
responses((status = 200, description = "Served model info", body = ModelInfo))
)]
#[instrument]
async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
Json(ModelsInfo {
data: vec![ModelInfo {
id: info.0.model_id.clone(),
object: "model".to_string(),
created: 0, // TODO: determine how to get this
owned_by: info.0.model_id.clone(),
}],
..Default::default()
})
}
#[utoipa::path( #[utoipa::path(
post, post,
tag = "Text Generation Inference", tag = "Text Generation Inference",
@ -2208,7 +2228,7 @@ async fn start(
// Define base and health routes // Define base and health routes
let mut base_routes = Router::new() let mut base_routes = Router::new()
.route("/", post(compat_generate)) .route("/", post(openai_get_model_info))
.route("/generate", post(generate)) .route("/generate", post(generate))
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
.route("/v1/chat/completions", post(chat_completions)) .route("/v1/chat/completions", post(chat_completions))
@ -2246,7 +2266,8 @@ async fn start(
.route("/info", get(get_model_info)) .route("/info", get(get_model_info))
.route("/health", get(health)) .route("/health", get(health))
.route("/ping", get(health)) .route("/ping", get(health))
.route("/metrics", get(metrics)); .route("/metrics", get(metrics))
.route("/v1/models", get(openai_get_model_info));
// Conditional AWS Sagemaker route // Conditional AWS Sagemaker route
let aws_sagemaker_route = if messages_api_enabled { let aws_sagemaker_route = if messages_api_enabled {