diff --git a/Cargo.lock b/Cargo.lock
index 438b9603..60d47a6f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2268,6 +2268,7 @@ dependencies = [
  "opentelemetry-otlp",
  "parking_lot",
  "rand",
+ "reqwest",
  "serde",
  "serde_json",
  "text-generation-client",
diff --git a/router/Cargo.toml b/router/Cargo.toml
index fb424fbf..6e673af0 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -26,6 +26,7 @@ opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
 parking_lot = "0.12.1"
 rand = "0.8.5"
+reqwest = { version = "0.11.14", features = [] }
 serde = "1.0.152"
 serde_json = "1.0.93"
 thiserror = "1.0.38"
diff --git a/router/src/main.rs b/router/src/main.rs
index f1cf09a0..6b652af3 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -87,7 +87,7 @@ fn main() -> Result<(), std::io::Error> {
     // This will only be used to validate payloads
     //
     // We need to download it outside of the Tokio runtime
-    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
+    let tokenizer = Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap();

     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
@@ -97,6 +97,36 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);

+            // Get pipeline tag
+            let model_info = reqwest::get(format!(
+                "https://api-inference.huggingface.co/models/{tokenizer_name}"
+            ))
+            .await
+            .expect("Could not connect to hf.co")
+            .text()
+            .await
+            .expect("error when retrieving model info from hf.co");
+            let model_info: serde_json::Value =
+                serde_json::from_str(&model_info).expect("unable to parse model info");
+
+            // if pipeline-tag == text-generation we return prompt + generated_text from the / route
+            let compat_return_full_text = match model_info["pipeline_tag"].as_str() {
+                None => {
+                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
+                    tracing::warn!("returning only generated_text from the compat route");
+                    false
+                }
+                Some(pipeline_tag) => {
+                    if pipeline_tag == "text-generation" {
+                        tracing::info!("returning prompt + generated_text from the compat route");
+                        true
+                    } else {
+                        tracing::info!("returning only generated_text from the compat route");
+                        false
+                    }
+                }
+            };
+
             // Instantiate sharded client from the master unix socket
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
@@ -113,6 +143,7 @@ fn main() -> Result<(), std::io::Error> {

             // Run server
             server::run(
+                compat_return_full_text,
                 max_concurrent_requests,
                 max_stop_sequences,
                 max_input_length,
diff --git a/router/src/server.rs b/router/src/server.rs
index 83b0297e..e33b0e2d 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -29,6 +29,7 @@ use utoipa_swagger_ui::SwaggerUi;
 /// Compatibility route with api-inference and AzureML
 #[instrument(skip(infer))]
 async fn compat_generate(
+    return_full_text: Extension<bool>,
     infer: Extension<Infer>,
     req: Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
@@ -39,9 +40,20 @@ async fn compat_generate(
         .await
         .into_response())
     } else {
+        let mut add_prompt = None;
+        if return_full_text.0 {
+            add_prompt = Some(req.inputs.clone());
+        }
+
         let (headers, generation) = generate(infer, Json(req.into())).await?;
+
+        let mut generation = generation.0;
+        if let Some(prompt) = add_prompt {
+            generation.generated_text = prompt + &generation.generated_text;
+        };
+
         // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation.0])).into_response())
+        Ok((headers, Json(vec![generation])).into_response())
     }
 }
@@ -345,6 +357,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
+    compat_return_full_text: bool,
     max_concurrent_requests: usize,
     max_stop_sequences: usize,
     max_input_length: usize,
@@ -429,8 +442,9 @@ pub async fn run(
         .route("/generate_stream", post(generate_stream))
         .route("/", get(health))
         .route("/health", get(health))
-        .layer(Extension(infer))
         .route("/metrics", get(metrics))
+        .layer(Extension(compat_return_full_text))
+        .layer(Extension(infer))
         .layer(Extension(prom_handle))
         .layer(opentelemetry_tracing_layer())
         .layer(cors_layer);
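For reference, a minimal client-side sketch (not part of the diff) of what the new compat-route behavior looks like, assuming a router listening on localhost:3000, a model whose pipeline_tag is "text-generation", and a reqwest built with its "json" feature (the diff itself enables no features):

use serde_json::{json, Value};

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // POST to the compat route ("/"); the response is wrapped in a Vec
    // to match api-inference.
    let response: Value = reqwest::Client::new()
        .post("http://localhost:3000/")
        .json(&json!({ "inputs": "The capital of France is" }))
        .send()
        .await?
        .json()
        .await?;

    // With compat_return_full_text == true, generated_text now begins with
    // the prompt, e.g. [{"generated_text": "The capital of France is ..."}];
    // for any other pipeline_tag only the completion is returned.
    println!("{response}");
    Ok(())
}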