diff --git a/router/src/infer.rs b/router/src/infer.rs
index ec91b54e..0c058f12 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -1,7 +1,7 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::HubTokenizerConfig;
-use crate::{ChatRequest, GenerateRequest, PrefillToken};
+use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
 use crate::{Entry, Queue, Token};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
@@ -14,7 +14,7 @@ use text_generation_client::{
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -92,13 +92,7 @@ impl Infer {
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
-    ) -> Result<
-        (
-            OwnedSemaphorePermit,
-            UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-        ),
-        InferError,
-    > {
+    ) -> Result<GenerateStreamResponse, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let permit = self
             .clone()
@@ -122,7 +116,7 @@ impl Infer {
 
         // Append the request to the queue
         self.queue.append(Entry {
-            request: valid_request,
+            request: valid_request.clone(),
             response_tx,
             span: Span::current(),
             temp_span: None,
@@ -135,7 +129,11 @@ impl Infer {
         self.shared.batching_task.notify_one();
 
         // Return stream
-        Ok((permit, UnboundedReceiverStream::new(response_rx)))
+        Ok((
+            permit,
+            valid_request,
+            UnboundedReceiverStream::new(response_rx),
+        ))
     }
 
     /// Apply the chat template to the chat request
@@ -155,7 +153,7 @@ impl Infer {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
 
         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, mut stream) = self.generate_stream(request).await?;
+        let (_permit, valid_request, mut stream) = self.generate_stream(request).await?;
 
         // Return values
         let mut result_prefill = Vec::new();
@@ -208,6 +206,7 @@ impl Infer {
             (result_generated_text, result_queued, result_start)
         {
             Ok(InferResponse {
+                prompt_token_count: valid_request.input_length,
                 prefill: result_prefill,
                 tokens: result_tokens,
                 generated_text,
@@ -649,6 +648,7 @@ pub(crate) enum InferStreamResponse {
 
 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) prompt_token_count: u32,
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 3ffc1275..7d7461fd 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -5,12 +5,22 @@ mod queue;
 pub mod server;
 mod validation;
 
-use infer::Infer;
+use crate::validation::ValidGenerateRequest;
+use infer::{Infer, InferError, InferStreamResponse};
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
+use tokio::sync::OwnedSemaphorePermit;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use utoipa::ToSchema;
 use validation::Validation;
 
+/// Type alias for generation responses
+pub(crate) type GenerateStreamResponse = (
+    OwnedSemaphorePermit,
+    ValidGenerateRequest,
+    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
+);
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
@@ -214,12 +224,7 @@ pub(crate) struct Usage {
 }
 
 impl ChatCompletion {
-    pub(crate) fn new(
-        ouput: String,
-        created: u64,
-        details: Details,
-        prompt_character_count: u32,
-    ) -> Self {
+    pub(crate) fn new(ouput: String, created: u64, details: Details) -> Self {
         Self {
             id: "".to_string(),
             object: "text_completion".to_string(),
@@ -236,9 +241,9 @@ impl ChatCompletion {
                 finish_reason: None,
             }],
             usage: Usage {
-                prompt_tokens: prompt_character_count,
+                prompt_tokens: details.prompt_token_count,
                 completion_tokens: details.generated_tokens,
-                total_tokens: prompt_character_count + details.generated_tokens,
+                total_tokens: details.prompt_token_count + details.generated_tokens,
             },
         }
     }
@@ -463,6 +468,8 @@ pub(crate) struct Details {
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
+    #[schema(example = 1)]
+    pub prompt_token_count: u32,
 }
 
 #[derive(Serialize, ToSchema)]
diff --git a/router/src/server.rs b/router/src/server.rs
index 47880678..c44aeefe 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -207,6 +207,7 @@ async fn generate(
                     seed: response.generated_text.seed,
                     best_of_sequences,
                     top_tokens: response.top_tokens,
+                    prompt_token_count: response.prompt_token_count,
                 })
             }
             false => None,
@@ -395,7 +396,7 @@ async fn generate_stream_internal(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, mut response_stream)) => {
+            Ok((_permit, _valid_request, mut response_stream)) => {
                 let mut index = 0;
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
@@ -580,9 +581,6 @@ async fn chat_completions(
         }
     };
 
-    // poor man's token count (assumes that each character is a token)
-    let prompt_character_count: u32 = inputs.chars().count().try_into().unwrap_or_default();
-
     // build the request passing some parameters
     let generate_request = GenerateRequest {
         inputs: inputs.to_string(),
@@ -654,7 +652,6 @@ async fn chat_completions(
             generation.generated_text,
             current_time,
             generation.details.unwrap(),
-            prompt_character_count,
         );
 
         // wrap generation inside a Vec to match api-inference
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 64f25c82..370e9588 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -376,7 +376,7 @@ type TokenizerRequest = (
     Span,
 );
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
     pub input_length: u32,