From 12cfc7930bc7795f99828e9c0f19a31b4fccf2c7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 11 Jan 2024 19:01:43 +0100 Subject: [PATCH] Return prompt vs generated tokens. (#1436) # What does this PR do? Fixes #637 ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. 
--- router/src/infer.rs | 15 +++++++++++++-- router/src/server.rs | 8 +++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 8078cee7..fdf0ae77 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -101,6 +101,7 @@ impl Infer { ) -> Result< ( OwnedSemaphorePermit, + u32, UnboundedReceiverStream>, ), InferError, @@ -125,6 +126,7 @@ impl Infer { // MPSC channel to communicate with the background batching task let (response_tx, response_rx) = mpsc::unbounded_channel(); + let input_length = valid_request.input_length; // Append the request to the queue self.queue.append(Entry { @@ -141,7 +143,11 @@ impl Infer { self.shared.batching_task.notify_one(); // Return stream - Ok((permit, UnboundedReceiverStream::new(response_rx))) + Ok(( + permit, + input_length, + UnboundedReceiverStream::new(response_rx), + )) } /// Add a new request to the queue and return a InferResponse @@ -153,7 +159,7 @@ impl Infer { let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); // Create stream and keep semaphore permit as long as generate lives - let (_permit, mut stream) = self.generate_stream(request).await?; + let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; // Return values let mut result_prefill = Vec::new(); @@ -207,6 +213,7 @@ impl Infer { { Ok(InferResponse { prefill: result_prefill, + _input_length, tokens: result_tokens, generated_text, queued, @@ -647,6 +654,10 @@ pub(crate) enum InferStreamResponse { #[derive(Debug)] pub(crate) struct InferResponse { + /// input_length is the input as perceived by the rust tokenizer in the + /// validation pathway. It is redundant with prefill.len() but prefill + /// has data only if the user asked for it. This will always be filled. 
+ pub(crate) _input_length: u32, pub(crate) prefill: Vec, pub(crate) tokens: Vec, pub(crate) generated_text: GeneratedText, diff --git a/router/src/server.rs b/router/src/server.rs index 78e2af3b..1ec45563 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -172,6 +172,7 @@ async fn generate( }; // Token details + let input_length = response._input_length; let details = match details { true => { // convert best_of_responses @@ -259,6 +260,11 @@ async fn generate( "x-time-per-token", time_per_token.as_millis().to_string().parse().unwrap(), ); + headers.insert("x-prompt-tokens", input_length.into()); + headers.insert( + "x-generated-tokens", + response.generated_text.generated_tokens.into(), + ); // Metrics metrics::increment_counter!("tgi_request_success"); @@ -380,7 +386,7 @@ async fn generate_stream( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, mut response_stream)) => { + Ok((_permit, _input_length, mut response_stream)) => { // Server-Sent Event stream while let Some(response) = response_stream.next().await { match response {