diff --git a/router/src/infer.rs b/router/src/infer.rs
index 8078cee7..fdf0ae77 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -101,6 +101,7 @@ impl Infer {
     ) -> Result<
         (
             OwnedSemaphorePermit,
+            u32,
             UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
         ),
         InferError,
@@ -125,6 +126,7 @@ impl Infer {
         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();
+        let input_length = valid_request.input_length;

         // Append the request to the queue
         self.queue.append(Entry {
@@ -141,7 +143,11 @@ impl Infer {
         self.shared.batching_task.notify_one();

         // Return stream
-        Ok((permit, UnboundedReceiverStream::new(response_rx)))
+        Ok((
+            permit,
+            input_length,
+            UnboundedReceiverStream::new(response_rx),
+        ))
     }

     /// Add a new request to the queue and return a InferResponse
@@ -153,7 +159,7 @@ impl Infer {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);

         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, mut stream) = self.generate_stream(request).await?;
+        let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;

         // Return values
         let mut result_prefill = Vec::new();
@@ -207,6 +213,7 @@ impl Infer {
         {
             Ok(InferResponse {
                 prefill: result_prefill,
+                _input_length,
                 tokens: result_tokens,
                 generated_text,
                 queued,
@@ -647,6 +654,10 @@ pub(crate) enum InferStreamResponse {

 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    /// input_length is the input length as perceived by the Rust tokenizer in
+    /// the validation pathway. It is redundant with prefill.len(), but prefill
+    /// has data only if the user asked for it. This will always be filled.
+    pub(crate) _input_length: u32,
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
diff --git a/router/src/server.rs b/router/src/server.rs
index 78e2af3b..1ec45563 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -172,6 +172,7 @@ async fn generate(
     };

     // Token details
+    let input_length = response._input_length;
     let details = match details {
         true => {
             // convert best_of_responses
@@ -259,6 +260,11 @@ async fn generate(
         "x-time-per-token",
         time_per_token.as_millis().to_string().parse().unwrap(),
     );
+    headers.insert("x-prompt-tokens", input_length.into());
+    headers.insert(
+        "x-generated-tokens",
+        response.generated_text.generated_tokens.into(),
+    );

     // Metrics
     metrics::increment_counter!("tgi_request_success");
@@ -380,7 +386,7 @@ async fn generate_stream(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, mut response_stream)) => {
+            Ok((_permit, _input_length, mut response_stream)) => {
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
                     match response {
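
For reference, a minimal sketch of how a client could consume the two new headers. This is not part of the patch: it assumes the `reqwest` crate (with the `blocking` and `json` features) and `serde_json`, and a router listening on `localhost:3000`; adapt the URL and payload to your deployment.

```rust
// Hypothetical client sketch: send a /generate request and read back the
// x-prompt-tokens / x-generated-tokens headers added by this diff.
use reqwest::blocking::Client;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let response = Client::new()
        .post("http://localhost:3000/generate")
        .json(&serde_json::json!({
            "inputs": "What is Deep Learning?",
            "parameters": { "max_new_tokens": 20 }
        }))
        .send()?;

    // Both headers carry plain integers; parse them back into u32.
    let header_u32 = |name: &str| {
        response
            .headers()
            .get(name)
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<u32>().ok())
    };

    println!(
        "prompt tokens: {:?}, generated tokens: {:?}",
        header_u32("x-prompt-tokens"),
        header_u32("x-generated-tokens"),
    );
    Ok(())
}
```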