mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
feat: add /v1/models endpoint
This commit is contained in:
parent
38773453ae
commit
93194a5075
@ -1243,6 +1243,34 @@ pub(crate) struct ErrorResponse {
|
||||
pub error_type: String,
|
||||
}
|
||||
|
||||
/// A single model entry in the OpenAI-compatible `GET /v1/models` response.
///
/// Mirrors the shape of an OpenAI "model" object so OpenAI SDK clients
/// can list the model served by this router.
#[derive(Serialize, Deserialize, ToSchema)]
pub(crate) struct ModelInfo {
    /// Model identifier (the router's served `model_id`).
    #[schema(example = "gpt2")]
    pub id: String,
    /// Object type discriminator; always `"model"` for this struct.
    #[schema(example = "model")]
    pub object: String,
    /// Creation time, presumably a Unix timestamp in seconds (the example
    /// suggests so) — the handler currently fills this with 0 (see its TODO).
    #[schema(example = 1686935002)]
    pub created: u64,
    /// Owner of the model. NOTE(review): the handler populates this with the
    /// model id, not an org name like the `"openai"` example — confirm intent.
    #[schema(example = "openai")]
    pub owned_by: String,
}
|
||||
|
||||
/// OpenAI-compatible list envelope returned by `GET /v1/models`.
#[derive(Serialize, Deserialize, ToSchema)]
pub(crate) struct ModelsInfo {
    /// Object type discriminator; always `"list"` for this struct.
    #[schema(example = "list")]
    pub object: String,
    /// The models served by this router (a single entry in practice).
    pub data: Vec<ModelInfo>,
}
|
||||
|
||||
impl Default for ModelsInfo {
|
||||
fn default() -> Self {
|
||||
ModelsInfo {
|
||||
object: "list".to_string(),
|
||||
data: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -24,6 +24,7 @@ use crate::{
|
||||
VertexResponse,
|
||||
};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
|
||||
use crate::{ModelInfo, ModelsInfo};
|
||||
use async_stream::__private::AsyncStream;
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
||||
@ -116,6 +117,25 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
|
||||
Json(info.0)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/v1/models",
|
||||
responses((status = 200, description = "Served model info", body = ModelInfo))
|
||||
)]
|
||||
#[instrument]
|
||||
async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
|
||||
Json(ModelsInfo {
|
||||
data: vec![ModelInfo {
|
||||
id: info.0.model_id.clone(),
|
||||
object: "model".to_string(),
|
||||
created: 0, // TODO: determine how to get this
|
||||
owned_by: info.0.model_id.clone(),
|
||||
}],
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
@ -2206,7 +2226,7 @@ async fn start(
|
||||
|
||||
// Define base and health routes
|
||||
let mut base_routes = Router::new()
|
||||
.route("/", post(compat_generate))
|
||||
.route("/", post(openai_get_model_info))
|
||||
.route("/generate", post(generate))
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
@ -2244,7 +2264,8 @@ async fn start(
|
||||
.route("/info", get(get_model_info))
|
||||
.route("/health", get(health))
|
||||
.route("/ping", get(health))
|
||||
.route("/metrics", get(metrics));
|
||||
.route("/metrics", get(metrics))
|
||||
.route("/v1/models", get(openai_get_model_info));
|
||||
|
||||
// Conditional AWS Sagemaker route
|
||||
let aws_sagemaker_route = if messages_api_enabled {
|
||||
|
Loading…
Reference in New Issue
Block a user