mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
add new openapi schema
This commit is contained in:
parent
3225fed42e
commit
4fb69b9d1c
@ -10,7 +10,7 @@
|
|||||||
"name": "Apache 2.0",
|
"name": "Apache 2.0",
|
||||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||||
},
|
},
|
||||||
"version": "0.6.0"
|
"version": "0.7.0-dev"
|
||||||
},
|
},
|
||||||
"paths": {
|
"paths": {
|
||||||
"/": {
|
"/": {
|
||||||
@ -33,7 +33,19 @@
|
|||||||
},
|
},
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
"description": "See /generate or /generate_stream"
|
"description": "Generated Text",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/GenerateResponse"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"text/event-stream": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/StreamResponse"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"422": {
|
"422": {
|
||||||
"description": "Input validation error",
|
"description": "Input validation error",
|
||||||
@ -584,11 +596,73 @@
|
|||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
"model_id",
|
"model_id",
|
||||||
|
"model_dtype",
|
||||||
|
"model_device_type",
|
||||||
|
"max_concurrent_requests",
|
||||||
|
"max_best_of",
|
||||||
|
"max_stop_sequences",
|
||||||
|
"max_input_length",
|
||||||
|
"max_total_tokens",
|
||||||
|
"waiting_served_ratio",
|
||||||
|
"max_batch_total_tokens",
|
||||||
|
"max_waiting_tokens",
|
||||||
|
"validation_workers",
|
||||||
"version"
|
"version"
|
||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
|
"docker_label": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "null",
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
|
"max_batch_total_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"format": "int32",
|
||||||
|
"example": "32000",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_best_of": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "2",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_concurrent_requests": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Router Parameters",
|
||||||
|
"example": "128",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_input_length": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "1024",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_stop_sequences": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "4",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_total_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "2048",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_waiting_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "20",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"model_device_type": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "cuda"
|
||||||
|
},
|
||||||
|
"model_dtype": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "torch.float16"
|
||||||
|
},
|
||||||
"model_id": {
|
"model_id": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
"description": "Model info",
|
||||||
"example": "bigscience/blomm-560m"
|
"example": "bigscience/blomm-560m"
|
||||||
},
|
},
|
||||||
"model_pipeline_tag": {
|
"model_pipeline_tag": {
|
||||||
@ -606,9 +680,20 @@
|
|||||||
"example": "null",
|
"example": "null",
|
||||||
"nullable": true
|
"nullable": true
|
||||||
},
|
},
|
||||||
|
"validation_workers": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "2",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
"version": {
|
"version": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
"description": "Router Info",
|
||||||
"example": "0.5.0"
|
"example": "0.5.0"
|
||||||
|
},
|
||||||
|
"waiting_served_ratio": {
|
||||||
|
"type": "number",
|
||||||
|
"format": "float",
|
||||||
|
"example": "1.2"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -9,6 +9,7 @@ use opentelemetry::{global, KeyValue};
|
|||||||
use opentelemetry_otlp::WithExportConfig;
|
use opentelemetry_otlp::WithExportConfig;
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::time::Duration;
|
||||||
use text_generation_client::ShardedClient;
|
use text_generation_client::ShardedClient;
|
||||||
use text_generation_router::{server, HubModelInfo};
|
use text_generation_router::{server, HubModelInfo};
|
||||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||||
@ -145,7 +146,10 @@ fn main() -> Result<(), std::io::Error> {
|
|||||||
sha: None,
|
sha: None,
|
||||||
pipeline_tag: None,
|
pipeline_tag: None,
|
||||||
},
|
},
|
||||||
false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
|
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
|
||||||
|
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||||
|
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
|
||||||
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||||
@ -256,22 +260,24 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// get model info from the Huggingface Hub
|
/// get model info from the Huggingface Hub
|
||||||
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
|
pub async fn get_model_info(
|
||||||
|
model_id: &str,
|
||||||
|
revision: &str,
|
||||||
|
token: Option<String>,
|
||||||
|
) -> Option<HubModelInfo> {
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
// Poor man's urlencode
|
// Poor man's urlencode
|
||||||
let revision = revision.replace('/', "%2F");
|
let revision = revision.replace('/', "%2F");
|
||||||
let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
|
let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
|
||||||
let mut builder = client.get(url);
|
let mut builder = client.get(url).timeout(Duration::from_secs(5));
|
||||||
if let Some(token) = token {
|
if let Some(token) = token {
|
||||||
builder = builder.bearer_auth(token);
|
builder = builder.bearer_auth(token);
|
||||||
}
|
}
|
||||||
|
|
||||||
let model_info = builder
|
let response = builder.send().await.ok()?;
|
||||||
.send()
|
|
||||||
.await
|
if response.status().is_success() {
|
||||||
.expect("Could not connect to hf.co")
|
return serde_json::from_str(&response.text().await.ok()?).ok();
|
||||||
.text()
|
}
|
||||||
.await
|
None
|
||||||
.expect("error when retrieving model info from hf.co");
|
|
||||||
serde_json::from_str(&model_info).expect("unable to parse model info")
|
|
||||||
}
|
}
|
||||||
|
@ -37,7 +37,7 @@ use utoipa_swagger_ui::SwaggerUi;
|
|||||||
path = "/",
|
path = "/",
|
||||||
request_body = CompatGenerateRequest,
|
request_body = CompatGenerateRequest,
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "See /generate or /generate_stream",
|
(status = 200, description = "Generated Text",
|
||||||
content(
|
content(
|
||||||
("application/json" = GenerateResponse),
|
("application/json" = GenerateResponse),
|
||||||
("text/event-stream" = StreamResponse),
|
("text/event-stream" = StreamResponse),
|
||||||
|
Loading…
Reference in New Issue
Block a user