mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
add new openapi schema
commit 4fb69b9d1c (parent 3225fed42e)
@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "0.6.0"
+    "version": "0.7.0-dev"
   },
   "paths": {
     "/": {
@@ -33,7 +33,19 @@
         },
         "responses": {
           "200": {
-            "description": "See /generate or /generate_stream"
+            "description": "Generated Text",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/GenerateResponse"
+                }
+              },
+              "text/event-stream": {
+                "schema": {
+                  "$ref": "#/components/schemas/StreamResponse"
+                }
+              }
+            }
           },
           "422": {
             "description": "Input validation error",
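Not part of the commit, but for orientation: the compat route on "/" now documents both payload shapes of its 200 response, and which one you get is driven by the stream flag of CompatGenerateRequest. A minimal non-streaming client sketch, assuming a router listening on localhost:3000 and the reqwest, serde_json, and tokio crates:

use serde_json::{json, Value};

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // `stream: false` selects the application/json half of the 200 response;
    // `stream: true` would switch it to text/event-stream chunks instead.
    let body = json!({
        "inputs": "What is Deep Learning?",
        "parameters": { "max_new_tokens": 20 },
        "stream": false
    });
    let response: Value = reqwest::Client::new()
        .post("http://localhost:3000/") // assumed local router address
        .json(&body)
        .send()
        .await?
        .json()
        .await?;
    // The body conforms to the GenerateResponse schema referenced above.
    println!("{response}");
    Ok(())
}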
@@ -584,11 +596,73 @@
        "type": "object",
        "required": [
          "model_id",
+         "model_dtype",
+         "model_device_type",
+         "max_concurrent_requests",
+         "max_best_of",
+         "max_stop_sequences",
+         "max_input_length",
+         "max_total_tokens",
+         "waiting_served_ratio",
+         "max_batch_total_tokens",
+         "max_waiting_tokens",
+         "validation_workers",
          "version"
        ],
        "properties": {
+         "docker_label": {
+           "type": "string",
+           "example": "null",
+           "nullable": true
+         },
+         "max_batch_total_tokens": {
+           "type": "integer",
+           "format": "int32",
+           "example": "32000",
+           "minimum": 0.0
+         },
+         "max_best_of": {
+           "type": "integer",
+           "example": "2",
+           "minimum": 0.0
+         },
+         "max_concurrent_requests": {
+           "type": "integer",
+           "description": "Router Parameters",
+           "example": "128",
+           "minimum": 0.0
+         },
+         "max_input_length": {
+           "type": "integer",
+           "example": "1024",
+           "minimum": 0.0
+         },
+         "max_stop_sequences": {
+           "type": "integer",
+           "example": "4",
+           "minimum": 0.0
+         },
+         "max_total_tokens": {
+           "type": "integer",
+           "example": "2048",
+           "minimum": 0.0
+         },
+         "max_waiting_tokens": {
+           "type": "integer",
+           "example": "20",
+           "minimum": 0.0
+         },
+         "model_device_type": {
+           "type": "string",
+           "example": "cuda"
+         },
+         "model_dtype": {
+           "type": "string",
+           "example": "torch.float16"
+         },
          "model_id": {
            "type": "string",
            "description": "Model info",
            "example": "bigscience/blomm-560m"
          },
+         "model_pipeline_tag": {
@@ -606,9 +680,20 @@
            "example": "null",
            "nullable": true
          },
+         "validation_workers": {
+           "type": "integer",
+           "example": "2",
+           "minimum": 0.0
+         },
          "version": {
            "type": "string",
+           "description": "Router Info",
            "example": "0.5.0"
          },
+         "waiting_served_ratio": {
+           "type": "number",
+           "format": "float",
+           "example": "1.2"
+         }
        }
      },
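For reference (not in the diff), here is what a GET /info payload satisfying the expanded Info schema could look like, built purely from the example values the schema itself carries above:

use serde_json::json;

fn main() {
    // Illustrative /info response assembled from the schema's example values.
    let info = json!({
        "model_id": "bigscience/blomm-560m",
        "model_dtype": "torch.float16",
        "model_device_type": "cuda",
        "max_concurrent_requests": 128,
        "max_best_of": 2,
        "max_stop_sequences": 4,
        "max_input_length": 1024,
        "max_total_tokens": 2048,
        "waiting_served_ratio": 1.2,
        "max_batch_total_tokens": 32000,
        "max_waiting_tokens": 20,
        "validation_workers": 2,
        "version": "0.5.0",
        "docker_label": null
    });
    // Every field named in `required` above is present, so this object conforms.
    println!("{}", serde_json::to_string_pretty(&info).unwrap());
}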
@@ -9,6 +9,7 @@ use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
+use std::time::Duration;
 use text_generation_client::ShardedClient;
 use text_generation_router::{server, HubModelInfo};
 use tokenizers::{FromPretrainedParameters, Tokenizer};
@@ -125,7 +126,7 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);

-            if let Some(max_batch_size) = max_batch_size{
+            if let Some(max_batch_size) = max_batch_size {
                 tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
                 max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
                 tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
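A quick concrete check of that deprecation fallback, with hypothetical flag values (not taken from the commit):

fn main() {
    // Hypothetical CLI values, for illustration only.
    let max_batch_size: usize = 16;
    let max_total_tokens: usize = 2048;
    // Mirrors the override in main.rs: the token budget is simply the product.
    let max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
    assert_eq!(max_batch_total_tokens, 32768);
}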
@@ -145,7 +146,10 @@ fn main() -> Result<(), std::io::Error> {
             sha: None,
             pipeline_tag: None,
         },
-        false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
+        false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
+            tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+            HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
+        }),
     };

     // if pipeline-tag == text-generation we default to return_full_text = true
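One subtlety in that fallback, worth knowing when reading the router's logs: unwrap_or takes its argument by value, so the braced block, tracing::warn! included, is evaluated even when the hub lookup succeeds; unwrap_or_else with a closure would defer it. A tiny standalone demonstration of the semantics:

fn main() {
    // `unwrap_or` computes its fallback eagerly: this prints although
    // the Option is Some. `unwrap_or_else(|| ...)` would stay silent here.
    let value = Some(1).unwrap_or({
        println!("fallback block evaluated");
        0
    });
    assert_eq!(value, 1);
}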
@@ -256,22 +260,24 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 }

 /// get model info from the Huggingface Hub
-pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
+pub async fn get_model_info(
+    model_id: &str,
+    revision: &str,
+    token: Option<String>,
+) -> Option<HubModelInfo> {
     let client = reqwest::Client::new();
     // Poor man's urlencode
     let revision = revision.replace('/', "%2F");
     let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
-    let mut builder = client.get(url);
+    let mut builder = client.get(url).timeout(Duration::from_secs(5));
     if let Some(token) = token {
         builder = builder.bearer_auth(token);
     }

-    let model_info = builder
-        .send()
-        .await
-        .expect("Could not connect to hf.co")
-        .text()
-        .await
-        .expect("error when retrieving model info from hf.co");
-    serde_json::from_str(&model_info).expect("unable to parse model info")
+    let response = builder.send().await.ok()?;
+
+    if response.status().is_success() {
+        return serde_json::from_str(&response.text().await.ok()?).ok();
+    }
+    None
 }
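The "poor man's urlencode" above escapes only slashes, which is the one character that would otherwise split the revision across URL path segments; Hub PR refs such as refs/pr/1 are the usual case. A quick check of the behaviour:

fn main() {
    // Branch names pass through untouched...
    assert_eq!("main".replace('/', "%2F"), "main");
    // ...while ref-style revisions stay inside a single path segment.
    assert_eq!("refs/pr/1".replace('/', "%2F"), "refs%2Fpr%2F1");
}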
@@ -37,7 +37,7 @@ use utoipa_swagger_ui::SwaggerUi;
     path = "/",
     request_body = CompatGenerateRequest,
     responses(
-    (status = 200, description = "See /generate or /generate_stream",
+    (status = 200, description = "Generated Text",
+    content(
+        ("application/json" = GenerateResponse),
+        ("text/event-stream" = StreamResponse),