Fix by using the actual value as output by the validation workers.
Nicolas Patry 2024-01-11 15:18:58 +00:00
parent 5c8cc964fa
commit 8e0c538a18
2 changed files with 13 additions and 5 deletions

First changed file:

@@ -90,6 +90,7 @@ impl Infer {
     ) -> Result<
         (
             OwnedSemaphorePermit,
+            u32,
             UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
         ),
         InferError,
@@ -114,6 +115,7 @@ impl Infer {
 
         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();
+        let input_length = valid_request.input_length;
 
         // Append the request to the queue
         self.queue.append(Entry {
@@ -130,7 +132,11 @@ impl Infer {
         self.shared.batching_task.notify_one();
 
         // Return stream
-        Ok((permit, UnboundedReceiverStream::new(response_rx)))
+        Ok((
+            permit,
+            input_length,
+            UnboundedReceiverStream::new(response_rx),
+        ))
     }
 
     /// Add a new request to the queue and return a InferResponse
@@ -142,7 +148,7 @@ impl Infer {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
 
         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, mut stream) = self.generate_stream(request).await?;
+        let (_permit, input_length, mut stream) = self.generate_stream(request).await?;
 
         // Return values
         let mut result_prefill = Vec::new();
@@ -196,6 +202,7 @@ impl Infer {
         {
             Ok(InferResponse {
                 prefill: result_prefill,
+                input_length,
                 tokens: result_tokens,
                 generated_text,
                 queued,
@@ -636,6 +643,7 @@ pub(crate) enum InferStreamResponse {
 
 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) input_length: u32,
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
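
Taken together, this file's changes thread the input length computed by the validation workers through generate_stream and into InferResponse. A minimal, self-contained sketch of the new return shape (simplified, assumed types; generate_stream_sketch and the String stream are stand-ins, not the repository's actual code):

    use std::sync::Arc;
    use tokio::sync::{mpsc, OwnedSemaphorePermit, Semaphore};
    use tokio_stream::wrappers::UnboundedReceiverStream;

    // Stand-in for Infer::generate_stream: the validated input length now
    // travels with the permit and the response stream instead of being
    // re-derived from the prefill tokens later.
    async fn generate_stream_sketch(
        semaphore: Arc<Semaphore>,
        validated_input_length: u32,
    ) -> Result<(OwnedSemaphorePermit, u32, UnboundedReceiverStream<String>), &'static str> {
        // Keep the concurrency permit alive for as long as the stream is consumed.
        let permit = semaphore
            .acquire_owned()
            .await
            .map_err(|_| "semaphore closed")?;
        let (_response_tx, response_rx) = mpsc::unbounded_channel();
        Ok((
            permit,
            validated_input_length,
            UnboundedReceiverStream::new(response_rx),
        ))
    }

    #[tokio::main]
    async fn main() {
        let semaphore = Arc::new(Semaphore::new(1));
        let (_permit, input_length, _stream) =
            generate_stream_sketch(semaphore, 42).await.unwrap();
        println!("validated input length: {input_length}");
    }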

Second changed file:

@@ -170,7 +170,7 @@ async fn generate(
     };
 
     // Token details
-    let prompt_tokens = response.prefill.len();
+    let input_length = response.input_length;
     let details = match details {
         true => {
             // convert best_of_responses
@@ -258,7 +258,7 @@ async fn generate(
         "x-time-per-token",
         time_per_token.as_millis().to_string().parse().unwrap(),
     );
-    headers.insert("x-prompt-tokens", prompt_tokens.into());
+    headers.insert("x-prompt-tokens", input_length.into());
     headers.insert(
         "x-generated-tokens",
         response.generated_text.generated_tokens.into(),
@@ -384,7 +384,7 @@ async fn generate_stream(
         } else {
             match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
-                Ok((_permit, mut response_stream)) => {
+                Ok((_permit, _input_length, mut response_stream)) => {
                     // Server-Sent Event stream
                     while let Some(response) = response_stream.next().await {
                         match response {
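
On the handler side, the substantive change is that the x-prompt-tokens header is now filled from response.input_length rather than from the prefill token count, while the streaming path simply discards the extra tuple element. A small sketch of that header write, assuming the http crate's HeaderMap/HeaderValue (the handler's real types are not shown in this diff); the .into() in the diff works the same way because HeaderValue implements From<u32>:

    use http::{HeaderMap, HeaderValue};

    // Numeric types such as u32 convert into HeaderValue directly, so the
    // validated input length can be reported without stringifying it first.
    fn prompt_tokens_header(input_length: u32) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("x-prompt-tokens", HeaderValue::from(input_length));
        headers
    }

    fn main() {
        let headers = prompt_tokens_header(42);
        println!("{:?}", headers.get("x-prompt-tokens"));
    }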