diff --git a/router/src/server.rs b/router/src/server.rs
index 9af94951b..e609821c8 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -54,15 +54,13 @@ example = json ! ({"error": "Incomplete generation"})),
 )]
 #[instrument(skip(infer, req))]
 async fn compat_generate(
-    default_return_full_text: Extension<bool>,
+    Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
-    req: Json<CompatGenerateRequest>,
+    Json(mut req): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
-    let mut req = req.0;
-
     // default return_full_text given the pipeline_tag
     if req.parameters.return_full_text.is_none() {
-        req.parameters.return_full_text = Some(default_return_full_text.0)
+        req.parameters.return_full_text = Some(default_return_full_text)
     }
 
     // switch on stream
@@ -71,9 +69,9 @@ async fn compat_generate(
             .await
             .into_response())
     } else {
-        let (headers, generation) = generate(infer, Json(req.into())).await?;
+        let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
         // wrap generation inside a Vec to match api-inference
-        Ok((headers, Json(vec![generation.0])).into_response())
+        Ok((headers, Json(vec![generation])).into_response())
     }
 }
 
@@ -135,7 +133,7 @@ example = json ! ({"error": "Incomplete generation"})),
 #[instrument(
 skip_all,
 fields(
-parameters = ? req.0.parameters,
+parameters = ? req.parameters,
 total_time,
 validation_time,
 queue_time,
@@ -146,29 +144,29 @@ seed,
 )]
 async fn generate(
     infer: Extension<Infer>,
-    req: Json<GenerateRequest>,
+    Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
 
-    tracing::debug!("Input: {}", req.0.inputs);
+    tracing::debug!("Input: {}", req.inputs);
 
-    let compute_characters = req.0.inputs.chars().count();
+    let compute_characters = req.inputs.chars().count();
     let mut add_prompt = None;
-    if req.0.parameters.return_full_text.unwrap_or(false) {
-        add_prompt = Some(req.0.inputs.clone());
+    if req.parameters.return_full_text.unwrap_or(false) {
+        add_prompt = Some(req.inputs.clone());
     }
 
-    let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
+    let details = req.parameters.details || req.parameters.decoder_input_details;
 
     // Inference
-    let (response, best_of_responses) = match req.0.parameters.best_of {
+    let (response, best_of_responses) = match req.parameters.best_of {
         Some(best_of) if best_of > 1 => {
-            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+            let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
             (response, Some(best_of_responses))
         }
-        _ => (infer.generate(req.0).await?, None),
+        _ => (infer.generate(req).await?, None),
     };
 
     // Token details
@@ -321,7 +319,7 @@ content_type = "text/event-stream"),
 #[instrument(
 skip_all,
 fields(
-parameters = ? req.0.parameters,
+parameters = ? req.parameters,
 total_time,
 validation_time,
 queue_time,
@@ -331,8 +329,8 @@ seed,
 )
 )]
 async fn generate_stream(
-    infer: Extension<Infer>,
-    req: Json<GenerateRequest>,
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<GenerateRequest>,
 ) -> (
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
@@ -341,9 +339,9 @@ async fn generate_stream(
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
 
-    tracing::debug!("Input: {}", req.0.inputs);
+    tracing::debug!("Input: {}", req.inputs);
 
-    let compute_characters = req.0.inputs.chars().count();
+    let compute_characters = req.inputs.chars().count();
 
     let mut headers = HeaderMap::new();
     headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
@@ -359,24 +357,24 @@ async fn generate_stream(
         let mut error = false;
 
         let mut add_prompt = None;
-        if req.0.parameters.return_full_text.unwrap_or(false) {
-            add_prompt = Some(req.0.inputs.clone());
+        if req.parameters.return_full_text.unwrap_or(false) {
+            add_prompt = Some(req.inputs.clone());
         }
-        let details = req.0.parameters.details;
+        let details = req.parameters.details;
 
-        let best_of = req.0.parameters.best_of.unwrap_or(1);
+        let best_of = req.parameters.best_of.unwrap_or(1);
         if best_of != 1 {
             let err = InferError::from(ValidationError::BestOfStream);
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
-        } else if req.0.parameters.decoder_input_details {
+        } else if req.parameters.decoder_input_details {
             let err = InferError::from(ValidationError::PrefillDetailsStream);
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
             yield Ok(Event::from(err));
         } else {
-            match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
+            match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
                 Ok((_permit, mut response_stream)) => {
                     // Server-Sent Event stream
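
What follows is a minimal standalone sketch (not part of the patch above) of the axum pattern the diff adopts: destructuring Extension(..) and Json(..) directly in the handler signature, instead of taking the wrapper and reaching into it with .0 in the body. The GenerateRequest/GenerateResponse types and the String extension are made-up placeholders rather than TGI's real types, and the startup code assumes an axum 0.6-era API.

// Sketch only: axum handlers can pattern-match extractors in the signature,
// so the body sees the inner value directly (no more `.0`).
// GenerateRequest/GenerateResponse and the String extension are placeholders.
use axum::{routing::post, Extension, Json, Router};
use serde::{Deserialize, Serialize};

#[derive(Deserialize)]
struct GenerateRequest {
    inputs: String,
}

#[derive(Serialize)]
struct GenerateResponse {
    generated_text: String,
}

// Before: `req: Json<GenerateRequest>` forces `req.0.inputs` at every use site.
// After: `Json(req)` unwraps once, and the body works with plain `req`.
async fn generate(
    Extension(prefix): Extension<String>,
    Json(req): Json<GenerateRequest>,
) -> Json<GenerateResponse> {
    Json(GenerateResponse {
        generated_text: format!("{prefix}{}", req.inputs),
    })
}

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/generate", post(generate))
        .layer(Extension("echo: ".to_string()));

    // axum 0.6-style startup; newer axum versions use axum::serve instead.
    axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}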