add new openapi schema

This commit is contained in:
OlivierDehaene 2023-05-09 13:17:34 +02:00
parent 3225fed42e
commit 4fb69b9d1c
3 changed files with 107 additions and 16 deletions

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "0.6.0" "version": "0.7.0-dev"
}, },
"paths": { "paths": {
"/": { "/": {
@ -33,7 +33,19 @@
}, },
"responses": { "responses": {
"200": { "200": {
"description": "See /generate or /generate_stream" "description": "Generated Text",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/GenerateResponse"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/StreamResponse"
}
}
}
}, },
"422": { "422": {
"description": "Input validation error", "description": "Input validation error",
@ -584,11 +596,73 @@
"type": "object", "type": "object",
"required": [ "required": [
"model_id", "model_id",
"model_dtype",
"model_device_type",
"max_concurrent_requests",
"max_best_of",
"max_stop_sequences",
"max_input_length",
"max_total_tokens",
"waiting_served_ratio",
"max_batch_total_tokens",
"max_waiting_tokens",
"validation_workers",
"version" "version"
], ],
"properties": { "properties": {
"docker_label": {
"type": "string",
"example": "null",
"nullable": true
},
"max_batch_total_tokens": {
"type": "integer",
"format": "int32",
"example": "32000",
"minimum": 0.0
},
"max_best_of": {
"type": "integer",
"example": "2",
"minimum": 0.0
},
"max_concurrent_requests": {
"type": "integer",
"description": "Router Parameters",
"example": "128",
"minimum": 0.0
},
"max_input_length": {
"type": "integer",
"example": "1024",
"minimum": 0.0
},
"max_stop_sequences": {
"type": "integer",
"example": "4",
"minimum": 0.0
},
"max_total_tokens": {
"type": "integer",
"example": "2048",
"minimum": 0.0
},
"max_waiting_tokens": {
"type": "integer",
"example": "20",
"minimum": 0.0
},
"model_device_type": {
"type": "string",
"example": "cuda"
},
"model_dtype": {
"type": "string",
"example": "torch.float16"
},
"model_id": { "model_id": {
"type": "string", "type": "string",
"description": "Model info",
"example": "bigscience/blomm-560m" "example": "bigscience/blomm-560m"
}, },
"model_pipeline_tag": { "model_pipeline_tag": {
@ -606,9 +680,20 @@
"example": "null", "example": "null",
"nullable": true "nullable": true
}, },
"validation_workers": {
"type": "integer",
"example": "2",
"minimum": 0.0
},
"version": { "version": {
"type": "string", "type": "string",
"description": "Router Info",
"example": "0.5.0" "example": "0.5.0"
},
"waiting_served_ratio": {
"type": "number",
"format": "float",
"example": "1.2"
} }
} }
}, },

View File

@ -9,6 +9,7 @@ use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig; use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path; use std::path::Path;
use std::time::Duration;
use text_generation_client::ShardedClient; use text_generation_client::ShardedClient;
use text_generation_router::{server, HubModelInfo}; use text_generation_router::{server, HubModelInfo};
use tokenizers::{FromPretrainedParameters, Tokenizer}; use tokenizers::{FromPretrainedParameters, Tokenizer};
@ -145,7 +146,10 @@ fn main() -> Result<(), std::io::Error> {
sha: None, sha: None,
pipeline_tag: None, pipeline_tag: None,
}, },
false => get_model_info(&tokenizer_name, &revision, authorization_token).await, false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
}),
}; };
// if pipeline-tag == text-generation we default to return_full_text = true // if pipeline-tag == text-generation we default to return_full_text = true
@ -256,22 +260,24 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
} }
/// get model info from the Huggingface Hub /// get model info from the Huggingface Hub
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo { pub async fn get_model_info(
model_id: &str,
revision: &str,
token: Option<String>,
) -> Option<HubModelInfo> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
// Poor man's urlencode // Poor man's urlencode
let revision = revision.replace('/', "%2F"); let revision = revision.replace('/', "%2F");
let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}"); let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
let mut builder = client.get(url); let mut builder = client.get(url).timeout(Duration::from_secs(5));
if let Some(token) = token { if let Some(token) = token {
builder = builder.bearer_auth(token); builder = builder.bearer_auth(token);
} }
let model_info = builder let response = builder.send().await.ok()?;
.send()
.await if response.status().is_success() {
.expect("Could not connect to hf.co") return serde_json::from_str(&response.text().await.ok()?).ok();
.text() }
.await None
.expect("error when retrieving model info from hf.co");
serde_json::from_str(&model_info).expect("unable to parse model info")
} }

View File

@ -37,7 +37,7 @@ use utoipa_swagger_ui::SwaggerUi;
path = "/", path = "/",
request_body = CompatGenerateRequest, request_body = CompatGenerateRequest,
responses( responses(
(status = 200, description = "See /generate or /generate_stream", (status = 200, description = "Generated Text",
content( content(
("application/json" = GenerateResponse), ("application/json" = GenerateResponse),
("text/event-stream" = StreamResponse), ("text/event-stream" = StreamResponse),