From 8e0c538a18f1e99471b07fe65e1cee029f38ccea Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 11 Jan 2024 15:18:58 +0000 Subject: [PATCH] Fix by using the actual real value as outputted by the validation workers. --- router/src/infer.rs | 12 ++++++++++-- router/src/server.rs | 6 +++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index bf5920da..c770c5f9 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -90,6 +90,7 @@ impl Infer { ) -> Result< ( OwnedSemaphorePermit, + u32, UnboundedReceiverStream>, ), InferError, @@ -114,6 +115,7 @@ impl Infer { // MPSC channel to communicate with the background batching task let (response_tx, response_rx) = mpsc::unbounded_channel(); + let input_length = valid_request.input_length; // Append the request to the queue self.queue.append(Entry { @@ -130,7 +132,11 @@ impl Infer { self.shared.batching_task.notify_one(); // Return stream - Ok((permit, UnboundedReceiverStream::new(response_rx))) + Ok(( + permit, + input_length, + UnboundedReceiverStream::new(response_rx), + )) } /// Add a new request to the queue and return a InferResponse @@ -142,7 +148,7 @@ impl Infer { let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); // Create stream and keep semaphore permit as long as generate lives - let (_permit, mut stream) = self.generate_stream(request).await?; + let (_permit, input_length, mut stream) = self.generate_stream(request).await?; // Return values let mut result_prefill = Vec::new(); @@ -196,6 +202,7 @@ impl Infer { { Ok(InferResponse { prefill: result_prefill, + input_length, tokens: result_tokens, generated_text, queued, @@ -636,6 +643,7 @@ pub(crate) enum InferStreamResponse { #[derive(Debug)] pub(crate) struct InferResponse { + pub(crate) input_length: u32, pub(crate) prefill: Vec, pub(crate) tokens: Vec, pub(crate) generated_text: GeneratedText, diff --git a/router/src/server.rs b/router/src/server.rs index ef1f1cea..035626dd 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -170,7 +170,7 @@ async fn generate( }; // Token details - let prompt_tokens = response.prefill.len(); + let input_length = response.input_length; let details = match details { true => { // convert best_of_responses @@ -258,7 +258,7 @@ async fn generate( "x-time-per-token", time_per_token.as_millis().to_string().parse().unwrap(), ); - headers.insert("x-prompt-tokens", prompt_tokens.into()); + headers.insert("x-prompt-tokens", input_length.into()); headers.insert( "x-generated-tokens", response.generated_text.generated_tokens.into(), @@ -384,7 +384,7 @@ async fn generate_stream( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, mut response_stream)) => { + Ok((_permit, _input_length, mut response_stream)) => { // Server-Sent Event stream while let Some(response) = response_stream.next().await { match response {