diff --git a/Cargo.lock b/Cargo.lock
index 438b9603..60d47a6f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2268,6 +2268,7 @@ dependencies = [
  "opentelemetry-otlp",
  "parking_lot",
  "rand",
+ "reqwest",
  "serde",
  "serde_json",
  "text-generation-client",
diff --git a/router/Cargo.toml b/router/Cargo.toml
index fb424fbf..6e673af0 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -26,6 +26,7 @@ opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
 parking_lot = "0.12.1"
 rand = "0.8.5"
+reqwest = { version = "0.11.14", features = [] }
 serde = "1.0.152"
 serde_json = "1.0.93"
 thiserror = "1.0.38"
diff --git a/router/src/lib.rs b/router/src/lib.rs
index d7cfa4c7..78f9efd1 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -40,13 +40,16 @@ pub(crate) struct GenerateParameters {
         example = 0.95
     )]
     pub top_p: Option<f32>,
-    #[serde(default = "default_do_sample")]
+    #[serde(default)]
     #[schema(default = "false", example = true)]
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
     #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
     pub max_new_tokens: u32,
     #[serde(default)]
+    #[schema(default = "None", example = false)]
+    pub return_full_text: Option<bool>,
+    #[serde(default)]
     #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
     pub stop: Vec<String>,
     #[serde(default)]
@@ -56,10 +59,6 @@ pub(crate) struct GenerateParameters {
     pub seed: Option<u64>,
 }
 
-fn default_do_sample() -> bool {
-    false
-}
-
 fn default_max_new_tokens() -> u32 {
     20
 }
@@ -70,8 +69,9 @@ fn default_parameters() -> GenerateParameters {
         repetition_penalty: None,
         top_k: None,
         top_p: None,
-        do_sample: default_do_sample(),
+        do_sample: false,
         max_new_tokens: default_max_new_tokens(),
+        return_full_text: None,
         stop: vec![],
         details: false,
         seed: None,
diff --git a/router/src/main.rs b/router/src/main.rs
index f1cf09a0..2baf9e72 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -87,7 +87,7 @@ fn main() -> Result<(), std::io::Error> {
     // This will only be used to validate payloads
     //
     // We need to download it outside of the Tokio runtime
-    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
+    let tokenizer = Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap();
 
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
@@ -97,6 +97,27 @@
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
 
+            // Get pipeline tag
+            let model_info = reqwest::get(format!(
+                "https://huggingface.co/api/models/{tokenizer_name}"
+            ))
+            .await
+            .expect("Could not connect to hf.co")
+            .text()
+            .await
+            .expect("error when retrieving model info from hf.co");
+            let model_info: serde_json::Value =
+                serde_json::from_str(&model_info).expect("unable to parse model info");
+
+            // if pipeline-tag == text-generation we default to return_full_text = true
+            let compat_return_full_text = match model_info.get("pipeline_tag") {
+                None => {
+                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
+                    false
+                }
+                Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"),
+            };
+
             // Instantiate sharded client from the master unix socket
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
@@ -113,6 +134,7 @@
 
             // Run server
             server::run(
+                compat_return_full_text,
                 max_concurrent_requests,
                 max_stop_sequences,
                 max_input_length,
diff --git a/router/src/server.rs b/router/src/server.rs
index 83b0297e..75f84ba5 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -29,11 +29,18 @@ use utoipa_swagger_ui::SwaggerUi;
 /// Compatibility route with api-inference and AzureML
 #[instrument(skip(infer))]
 async fn compat_generate(
+    default_return_full_text: Extension<bool>,
     infer: Extension<Infer>,
     req: Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let mut req = req.0;
+
+    // default return_full_text given the pipeline_tag
+    if req.parameters.return_full_text.is_none() {
+        req.parameters.return_full_text = Some(default_return_full_text.0)
+    }
+
     // switch on stream
-    let req = req.0;
     if req.stream {
         Ok(generate_stream(infer, Json(req.into()))
             .await
@@ -63,6 +70,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
                             Err(err) => {
                                 error = true;
-                                yield Ok(Event::from(err))
+                                yield Ok(Event::from(err));
+                                break;
                             }
                         }
                     }
@@ -315,7 +347,7 @@ async fn generate_stream(
                 // yield error
                 Err(err) => {
                     error = true;
-                    yield Ok(Event::from(err))
+                    yield Ok(Event::from(err));
                 }
             }
             // Check if generation reached the end
@@ -324,7 +356,7 @@
                 let err = InferError::IncompleteGeneration;
                 metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
                 tracing::error!("{err}");
-                yield Ok(Event::from(err))
+                yield Ok(Event::from(err));
             }
     };
 
@@ -345,6 +377,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
+    compat_return_full_text: bool,
     max_concurrent_requests: usize,
     max_stop_sequences: usize,
     max_input_length: usize,
@@ -429,8 +462,9 @@ pub async fn run(
         .route("/generate_stream", post(generate_stream))
         .route("/", get(health))
         .route("/health", get(health))
-        .layer(Extension(infer))
         .route("/metrics", get(metrics))
+        .layer(Extension(compat_return_full_text))
+        .layer(Extension(infer))
         .layer(Extension(prom_handle))
         .layer(opentelemetry_tracing_layer())
         .layer(cors_layer);
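
For context on how the new parameter behaves end to end: the router resolves `compat_return_full_text` once at startup from the model's `pipeline_tag` on hf.co, and `compat_generate` only applies it when a request leaves `return_full_text` unset, so clients can still override it per request. Below is a minimal client-side sketch (not part of this patch) illustrating that override. It assumes a router listening on `localhost:3000`, the compat handler mounted at `POST /`, the api-inference-style `inputs`/`parameters` payload shape, and a client built with `tokio` plus `reqwest` with its `json` feature enabled.

```rust
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();

    // Leaving `return_full_text` out would let the router fall back to the
    // pipeline-tag default (true when pipeline_tag == "text-generation").
    // Setting it explicitly, as here, overrides that default.
    let body = json!({
        "inputs": "My name is Olivier and I",
        "parameters": {
            "max_new_tokens": 20,
            "return_full_text": false
        }
    });

    // Send the request to the compat route and print the raw JSON response
    let generated = client
        .post("http://localhost:3000/")
        .json(&body)
        .send()
        .await?
        .text()
        .await?;

    println!("{generated}");
    Ok(())
}
```

Setting `"stream": true` in the same payload would instead route the request through `generate_stream` and return Server-Sent Events, as the `if req.stream` branch in `compat_generate` shows.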