diff --git a/docs/openapi.json b/docs/openapi.json
index e7b7b140..fbac7786 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "0.6.0"
+    "version": "0.7.0-dev"
   },
   "paths": {
     "/": {
@@ -33,7 +33,19 @@
         },
         "responses": {
           "200": {
-            "description": "See /generate or /generate_stream"
+            "description": "Generated Text",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/GenerateResponse"
+                }
+              },
+              "text/event-stream": {
+                "schema": {
+                  "$ref": "#/components/schemas/StreamResponse"
+                }
+              }
+            }
           },
           "422": {
             "description": "Input validation error",
@@ -584,11 +596,73 @@
         "type": "object",
         "required": [
           "model_id",
+          "model_dtype",
+          "model_device_type",
+          "max_concurrent_requests",
+          "max_best_of",
+          "max_stop_sequences",
+          "max_input_length",
+          "max_total_tokens",
+          "waiting_served_ratio",
+          "max_batch_total_tokens",
+          "max_waiting_tokens",
+          "validation_workers",
           "version"
         ],
         "properties": {
+          "docker_label": {
+            "type": "string",
+            "example": "null",
+            "nullable": true
+          },
+          "max_batch_total_tokens": {
+            "type": "integer",
+            "format": "int32",
+            "example": "32000",
+            "minimum": 0.0
+          },
+          "max_best_of": {
+            "type": "integer",
+            "example": "2",
+            "minimum": 0.0
+          },
+          "max_concurrent_requests": {
+            "type": "integer",
+            "description": "Router Parameters",
+            "example": "128",
+            "minimum": 0.0
+          },
+          "max_input_length": {
+            "type": "integer",
+            "example": "1024",
+            "minimum": 0.0
+          },
+          "max_stop_sequences": {
+            "type": "integer",
+            "example": "4",
+            "minimum": 0.0
+          },
+          "max_total_tokens": {
+            "type": "integer",
+            "example": "2048",
+            "minimum": 0.0
+          },
+          "max_waiting_tokens": {
+            "type": "integer",
+            "example": "20",
+            "minimum": 0.0
+          },
+          "model_device_type": {
+            "type": "string",
+            "example": "cuda"
+          },
+          "model_dtype": {
+            "type": "string",
+            "example": "torch.float16"
+          },
           "model_id": {
             "type": "string",
+            "description": "Model info",
             "example": "bigscience/bloom-560m"
           },
           "model_pipeline_tag": {
@@ -606,9 +680,20 @@
             "example": "null",
             "nullable": true
           },
+          "validation_workers": {
+            "type": "integer",
+            "example": "2",
+            "minimum": 0.0
+          },
           "version": {
             "type": "string",
+            "description": "Router Info",
             "example": "0.5.0"
+          },
+          "waiting_served_ratio": {
+            "type": "number",
+            "format": "float",
+            "example": "1.2"
           }
         }
       },
diff --git a/router/src/main.rs b/router/src/main.rs
index 5788897d..5ad49003 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -9,6 +9,7 @@ use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
+use std::time::Duration;
 use text_generation_client::ShardedClient;
 use text_generation_router::{server, HubModelInfo};
 use tokenizers::{FromPretrainedParameters, Tokenizer};
@@ -125,7 +126,7 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
 
-            if let Some(max_batch_size) = max_batch_size{
+            if let Some(max_batch_size) = max_batch_size {
                 tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
                 max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
                 tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
@@ -145,7 +146,10 @@ fn main() -> Result<(), std::io::Error> {
                 sha: None,
                 pipeline_tag: None,
             },
-            false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
+            false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
+                tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+                HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
+            }),
         };
 
             // if pipeline-tag == text-generation we default to return_full_text = true
@@ -195,7 +199,7 @@ fn main() -> Result<(), std::io::Error> {
                 addr,
                 cors_allow_origin,
             )
-        .await;
+            .await;
             Ok(())
         })
 }
diff --git a/router/src/main.rs b/router/src/main.rs
@@ -256,22 +260,24 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 }
 
 /// get model info from the Hugging Face Hub
-pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
+pub async fn get_model_info(
+    model_id: &str,
+    revision: &str,
+    token: Option<String>,
+) -> Option<HubModelInfo> {
     let client = reqwest::Client::new();
     // Poor man's urlencode
     let revision = revision.replace('/', "%2F");
     let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
 
-    let mut builder = client.get(url);
+    let mut builder = client.get(url).timeout(Duration::from_secs(5));
     if let Some(token) = token {
         builder = builder.bearer_auth(token);
     }
 
-    let model_info = builder
-        .send()
-        .await
-        .expect("Could not connect to hf.co")
-        .text()
-        .await
-        .expect("error when retrieving model info from hf.co");
-    serde_json::from_str(&model_info).expect("unable to parse model info")
+    let response = builder.send().await.ok()?;
+
+    if response.status().is_success() {
+        return serde_json::from_str(&response.text().await.ok()?).ok();
+    }
+    None
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index a9baf288..162b6fd9 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -37,7 +37,7 @@ use utoipa_swagger_ui::SwaggerUi;
     path = "/",
     request_body = CompatGenerateRequest,
     responses(
-        (status = 200, description = "See /generate or /generate_stream",
+        (status = 200, description = "Generated Text",
         content(
             ("application/json" = GenerateResponse),
             ("text/event-stream" = StreamResponse),
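A note on the `get_model_info` change above, with a runnable sketch. The function now returns `Option<HubModelInfo>` and caps the Hub request at five seconds, so an unreachable Hub degrades into a warning plus a local fallback at startup instead of a panic. The sketch reproduces that pattern as a standalone binary; the Cargo dependencies named in the comment, the `#[serde(rename = "id")]` attribute, and the `eprintln!` stand-in for `tracing::warn!` are assumptions for the demo, not code from this diff.

```rust
// Standalone sketch of the fault-tolerant model-info fetch. Assumed
// dependencies: tokio = { version = "1", features = ["full"] },
// reqwest = "0.11", serde = { version = "1", features = ["derive"] },
// serde_json = "1".
use std::time::Duration;

use serde::Deserialize;

// Mirrors `text_generation_router::HubModelInfo`; the Hub API returns the
// identifier under the key "id", hence the rename (an assumption here).
#[derive(Debug, Deserialize)]
struct HubModelInfo {
    #[serde(rename = "id")]
    model_id: String,
    sha: Option<String>,
    pipeline_tag: Option<String>,
}

/// Any connection, timeout, HTTP-status, or parse failure maps to `None`,
/// where the old `expect` chain would have panicked.
async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> Option<HubModelInfo> {
    // Poor man's urlencode, as in the router.
    let revision = revision.replace('/', "%2F");
    let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");

    let mut builder = reqwest::Client::new()
        .get(url)
        // Bound how long an unreachable Hub can stall startup.
        .timeout(Duration::from_secs(5));
    if let Some(token) = token {
        builder = builder.bearer_auth(token);
    }

    let response = builder.send().await.ok()?;
    if response.status().is_success() {
        return serde_json::from_str(&response.text().await.ok()?).ok();
    }
    None
}

#[tokio::main]
async fn main() {
    let model_id = "bigscience/bloom-560m";
    let info = get_model_info(model_id, "main", None)
        .await
        .unwrap_or_else(|| {
            eprintln!("Could not retrieve model info from the Hugging Face hub.");
            HubModelInfo {
                model_id: model_id.to_string(),
                sha: None,
                pipeline_tag: None,
            }
        });
    println!("{info:?}");
}
```

Note that the fallback goes through `unwrap_or_else` rather than `unwrap_or`: `unwrap_or` evaluates its argument eagerly, so the "could not retrieve" warning would be logged even when the Hub call succeeds.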
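The expanded `Info` schema can also be exercised from a client. The struct below is assembled from the schema's `required` list and property types in the diff; it assumes the router exposes the schema at `GET /info` on `127.0.0.1:3000` (neither the route nor the address is part of this diff) and that reqwest is built with its `json` feature.

```rust
use serde::Deserialize;

// Client-side mirror of the `Info` schema from openapi.json. Nullable
// properties become `Option`s; properties not shown in the diff (e.g. a
// possible `model_sha`) are ignored by serde's default handling of
// unknown keys.
#[derive(Debug, Deserialize)]
struct Info {
    // Model info
    model_id: String,
    model_dtype: String,
    model_device_type: String,
    model_pipeline_tag: Option<String>,
    // Router parameters
    max_concurrent_requests: usize,
    max_best_of: usize,
    max_stop_sequences: usize,
    max_input_length: usize,
    max_total_tokens: usize,
    waiting_served_ratio: f32,
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    validation_workers: usize,
    // Router info
    version: String,
    docker_label: Option<String>,
}

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Assumed router address; adjust to the deployment.
    let info: Info = reqwest::get("http://127.0.0.1:3000/info")
        .await?
        .json()
        .await?;
    println!("{info:#?}");
    Ok(())
}
```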