diff --git a/router/src/lib.rs b/router/src/lib.rs
index 35c22763..78f9efd1 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -47,8 +47,8 @@ pub(crate) struct GenerateParameters {
     #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
     pub max_new_tokens: u32,
     #[serde(default)]
-    #[schema(default = "false", example = false)]
-    pub return_full_text: bool,
+    #[schema(default = "None", example = false)]
+    pub return_full_text: Option<bool>,
     #[serde(default)]
     #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
     pub stop: Vec<String>,
@@ -71,7 +71,7 @@ fn default_parameters() -> GenerateParameters {
         top_p: None,
         do_sample: false,
         max_new_tokens: default_max_new_tokens(),
-        return_full_text: false,
+        return_full_text: None,
         stop: vec![],
         details: false,
         seed: None,
diff --git a/router/src/server.rs b/router/src/server.rs
index 289d1899..af50c4c0 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -29,31 +29,27 @@ use utoipa_swagger_ui::SwaggerUi;
 /// Compatibility route with api-inference and AzureML
 #[instrument(skip(infer))]
 async fn compat_generate(
-    return_full_text: Extension<bool>,
+    default_return_full_text: Extension<bool>,
     infer: Extension<Infer>,
     req: Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     // switch on stream
-    let req = req.0;
-
+    let mut req = req.0;
+    if req.parameters.return_full_text.is_none() {
+        req.parameters.return_full_text = Some(default_return_full_text.0)
+    }
+
     if req.stream {
-        Ok(generate_stream(infer, Json(req.into()))
-            .await
-            .into_response())
+        Ok(
+            generate_stream(infer, Json(req.into()))
+                .await
+                .into_response(),
+        )
     } else {
-        let mut add_prompt = None;
-        if return_full_text.0 {
-            add_prompt = Some(req.inputs.clone());
-        }
-
-        let (headers, generation) = generate(infer, Json(req.into())).await?;
-
-        let mut generation = generation.0;
-        if let Some(prompt) = add_prompt {
-            generation.generated_text = prompt + &generation.generated_text;
-        };
+        let (headers, generation) =
+            generate(infer, Json(req.into())).await?;
 
         // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation])).into_response())
+        Ok((headers, Json(vec![generation.0])).into_response())
     }
 }
@@ -75,7 +71,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
                 top_p: None,
                 do_sample: false,
                 max_new_tokens: 1,
-                return_full_text: false,
+                return_full_text: None,
                 stop: Vec::new(),
                 details: false,
                 seed: None,
@@ -122,7 +118,12 @@ async fn generate(
     let start_time = Instant::now();
 
     let mut add_prompt = None;
-    if req.0.parameters.return_full_text {
+    if req
+        .0
+        .parameters
+        .return_full_text
+        .unwrap_or(false)
+    {
         add_prompt = Some(req.0.inputs.clone());
     }
 
@@ -209,42 +210,42 @@ async fn generate(
 
 /// Generate a stream of token using Server-Sent Events
 #[utoipa::path(
-    post,
-    tag = "Text Generation Inference",
-    path = "/generate_stream",
-    request_body = GenerateRequest,
-    responses(
-        (status = 200, description = "Generated Text", body = StreamResponse,
-        content_type="text/event-stream"),
-        (status = 424, description = "Generation Error", body = ErrorResponse,
-        example = json!({"error": "Request failed during generation"}),
-        content_type="text/event-stream"),
-        (status = 429, description = "Model is overloaded", body = ErrorResponse,
-        example = json!({"error": "Model is overloaded"}),
-        content_type="text/event-stream"),
-        (status = 422, description = "Input validation error", body = ErrorResponse,
-        example = json!({"error": "Input validation error"}),
-        content_type="text/event-stream"),
-        (status = 500, description = "Incomplete generation", body = ErrorResponse,
-        example = json!({"error": "Incomplete generation"}),
-        content_type="text/event-stream"),
-    )
+post,
+tag = "Text Generation Inference",
+path = "/generate_stream",
+request_body = GenerateRequest,
+responses(
+(status = 200, description = "Generated Text", body = StreamResponse,
+content_type = "text/event-stream"),
+(status = 424, description = "Generation Error", body = ErrorResponse,
+example = json ! ({"error": "Request failed during generation"}),
+content_type = "text/event-stream"),
+(status = 429, description = "Model is overloaded", body = ErrorResponse,
+example = json ! ({"error": "Model is overloaded"}),
+content_type = "text/event-stream"),
+(status = 422, description = "Input validation error", body = ErrorResponse,
+example = json ! ({"error": "Input validation error"}),
+content_type = "text/event-stream"),
+(status = 500, description = "Incomplete generation", body = ErrorResponse,
+example = json ! ({"error": "Incomplete generation"}),
+content_type = "text/event-stream"),
+)
 )]
 #[instrument(
-    skip(infer),
-    fields(
-        total_time,
-        validation_time,
-        queue_time,
-        inference_time,
-        time_per_token,
-        seed,
-    )
+skip(infer),
+fields(
+total_time,
+validation_time,
+queue_time,
+inference_time,
+time_per_token,
+seed,
+)
 )]
 async fn generate_stream(
     infer: Extension<Infer>,
     req: Json<GenerateRequest>,
-) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
+) -> Sse<impl Stream<Item=Result<Event, Infallible>>> {
     let span = tracing::Span::current();
     let start_time = Instant::now();
 
@@ -254,7 +255,7 @@ async fn generate_stream(
     let mut error = false;
 
     let mut add_prompt = None;
-    if req.0.parameters.return_full_text {
+    if req.0.parameters.return_full_text.unwrap_or(false) {
         add_prompt = Some(req.0.inputs.clone());
     }
     let details = req.0.parameters.details;
@@ -370,10 +371,10 @@
 
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
-    get,
-    tag = "Text Generation Inference",
-    path = "/metrics",
-    responses((status = 200, description = "Prometheus Metrics", body = String))
+get,
+tag = "Text Generation Inference",
+path = "/metrics",
+responses((status = 200, description = "Prometheus Metrics", body = String))
 )]
 async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
@@ -398,35 +399,35 @@ pub async fn run(
     // OpenAPI documentation
     #[derive(OpenApi)]
    #[openapi(
-        paths(
-            generate,
-            generate_stream,
-            metrics,
-        ),
-        components(
-            schemas(
-                GenerateRequest,
-                GenerateParameters,
-                PrefillToken,
-                Token,
-                GenerateResponse,
-                Details,
-                FinishReason,
-                StreamResponse,
-                StreamDetails,
-                ErrorResponse,
-            )
-        ),
-        tags(
-            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
-        ),
-        info(
-            title = "Text Generation Inference",
-            license(
-                name = "Apache 2.0",
-                url = "https://www.apache.org/licenses/LICENSE-2.0"
-            )
-        )
+    paths(
+    generate,
+    generate_stream,
+    metrics,
+    ),
+    components(
+    schemas(
+    GenerateRequest,
+    GenerateParameters,
+    PrefillToken,
+    Token,
+    GenerateResponse,
+    Details,
+    FinishReason,
+    StreamResponse,
+    StreamDetails,
+    ErrorResponse,
+    )
+    ),
+    tags(
+    (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
+    ),
+    info(
+    title = "Text Generation Inference",
+    license(
+    name = "Apache 2.0",
+    url = "https://www.apache.org/licenses/LICENSE-2.0"
+    )
+    )
     )]
     struct ApiDoc;
 
@@ -492,7 +493,7 @@ async fn shutdown_signal() {
     };
 
     #[cfg(unix)]
-    let terminate = async {
+        let terminate = async {
         signal::unix::signal(signal::unix::SignalKind::terminate())
             .expect("failed to install signal handler")
             .recv()
@@ -500,7 +501,7 @@ async fn shutdown_signal() {
     };
 
     #[cfg(not(unix))]
-    let terminate = std::future::pending::<()>();
+        let terminate = std::future::pending::<()>();
 
     tokio::select! {
         _ = ctrl_c => {},
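A minimal standalone sketch (not part of the patch) of the defaulting rule this diff introduces, using hypothetical `Params` and `resolve` stand-ins for `GenerateParameters` and `compat_generate`: `return_full_text: None` means "unset by the client", and only then does the router-level default apply; the `/generate` and `/generate_stream` handlers themselves treat `None` as `false` via `unwrap_or(false)`.

fn main() {
    // `None` = the client did not include the field in the request body.
    struct Params {
        return_full_text: Option<bool>,
    }

    // Mirrors the new compat-route logic: the router-wide default is only
    // applied when the client left the field unset.
    fn resolve(mut params: Params, router_default: bool) -> Params {
        if params.return_full_text.is_none() {
            params.return_full_text = Some(router_default);
        }
        params
    }

    // Field omitted by the client -> the router default (true here) wins.
    let defaulted = resolve(Params { return_full_text: None }, true);
    assert_eq!(defaulted.return_full_text, Some(true));

    // Field set explicitly -> the client's value is kept, default ignored.
    let explicit = resolve(Params { return_full_text: Some(false) }, true);
    assert_eq!(explicit.return_full_text, Some(false));
}

With the old `bool` field, a deserialized request could not distinguish "client sent false" from "client sent nothing", so the compat route's prompt-prepending default could never be overridden to off; `Option<bool>` makes that tri-state explicit.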