fix: add prompt_token_count to InferResponse for chat completions

drbh 2024-01-09 13:04:29 -05:00
parent adad67e3d0
commit fba1953eb6
4 changed files with 31 additions and 27 deletions

View File

@@ -1,7 +1,7 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::HubTokenizerConfig;
-use crate::{ChatRequest, GenerateRequest, PrefillToken};
+use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
 use crate::{Entry, Queue, Token};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
@@ -14,7 +14,7 @@ use text_generation_client::{
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -92,13 +92,7 @@ impl Infer {
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
-    ) -> Result<
-        (
-            OwnedSemaphorePermit,
-            UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-        ),
-        InferError,
-    > {
+    ) -> Result<GenerateStreamResponse, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let permit = self
             .clone()
@@ -122,7 +116,7 @@ impl Infer {

         // Append the request to the queue
         self.queue.append(Entry {
-            request: valid_request,
+            request: valid_request.clone(),
             response_tx,
             span: Span::current(),
             temp_span: None,
@@ -135,7 +129,11 @@ impl Infer {
         self.shared.batching_task.notify_one();

         // Return stream
-        Ok((permit, UnboundedReceiverStream::new(response_rx)))
+        Ok((
+            permit,
+            valid_request,
+            UnboundedReceiverStream::new(response_rx),
+        ))
     }

     /// Apply the chat template to the chat request
@@ -155,7 +153,7 @@ impl Infer {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);

         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, mut stream) = self.generate_stream(request).await?;
+        let (_permit, valid_request, mut stream) = self.generate_stream(request).await?;

         // Return values
         let mut result_prefill = Vec::new();
@@ -208,6 +206,7 @@ impl Infer {
             (result_generated_text, result_queued, result_start)
         {
             Ok(InferResponse {
+                prompt_token_count: valid_request.input_length,
                 prefill: result_prefill,
                 tokens: result_tokens,
                 generated_text,
@@ -649,6 +648,7 @@ pub(crate) enum InferStreamResponse {

 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) prompt_token_count: u32,
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,

View File

@@ -5,12 +5,22 @@ mod queue;
 pub mod server;
 mod validation;

-use infer::Infer;
+use crate::validation::ValidGenerateRequest;
+use infer::{Infer, InferError, InferStreamResponse};
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
+use tokio::sync::OwnedSemaphorePermit;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use utoipa::ToSchema;
 use validation::Validation;

+/// Type alias for generation responses
+pub(crate) type GenerateStreamResponse = (
+    OwnedSemaphorePermit,
+    ValidGenerateRequest,
+    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
+);
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
@@ -214,12 +224,7 @@ pub(crate) struct Usage {
 }

 impl ChatCompletion {
-    pub(crate) fn new(
-        ouput: String,
-        created: u64,
-        details: Details,
-        prompt_character_count: u32,
-    ) -> Self {
+    pub(crate) fn new(ouput: String, created: u64, details: Details) -> Self {
         Self {
             id: "".to_string(),
             object: "text_completion".to_string(),
@@ -236,9 +241,9 @@ impl ChatCompletion {
                 finish_reason: None,
             }],
             usage: Usage {
-                prompt_tokens: prompt_character_count,
+                prompt_tokens: details.prompt_token_count,
                 completion_tokens: details.generated_tokens,
-                total_tokens: prompt_character_count + details.generated_tokens,
+                total_tokens: details.prompt_token_count + details.generated_tokens,
             },
         }
     }
@@ -463,6 +468,8 @@ pub(crate) struct Details {
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
+    #[schema(example = 1)]
+    pub prompt_token_count: u32,
 }

 #[derive(Serialize, ToSchema)]

View File

@@ -207,6 +207,7 @@ async fn generate(
                 seed: response.generated_text.seed,
                 best_of_sequences,
                 top_tokens: response.top_tokens,
+                prompt_token_count: response.prompt_token_count,
             })
         }
         false => None,
@@ -395,7 +396,7 @@ async fn generate_stream_internal(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, mut response_stream)) => {
+            Ok((_permit, _valid_request, mut response_stream)) => {
                 let mut index = 0;
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
@@ -580,9 +581,6 @@ async fn chat_completions(
         }
     };

-    // poor man's token count (assumes that each character is a token)
-    let prompt_character_count: u32 = inputs.chars().count().try_into().unwrap_or_default();
-
     // build the request passing some parameters
     let generate_request = GenerateRequest {
         inputs: inputs.to_string(),
@@ -654,7 +652,6 @@ async fn chat_completions(
         generation.generated_text,
         current_time,
         generation.details.unwrap(),
-        prompt_character_count,
     );

     // wrap generation inside a Vec to match api-inference

View File

@@ -376,7 +376,7 @@ type TokenizerRequest = (
     Span,
 );

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
     pub input_length: u32,
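
Taken together, the diff replaces the old character-count approximation with the tokenizer-derived prompt length carried on ValidGenerateRequest.input_length and surfaced as prompt_token_count. The sketch below (not part of the commit) shows the resulting usage accounting in isolation; the structs are simplified stand-ins that keep only the fields relevant here, and the helper usage_from is introduced purely for illustration.

// Simplified stand-ins for the router types touched by this commit.
struct ValidGenerateRequest {
    input_length: u32, // prompt token count computed during validation
}

struct Details {
    generated_tokens: u32,
    prompt_token_count: u32, // new field, filled from ValidGenerateRequest
}

struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

// Usage is now derived from real token counts instead of the removed
// "poor man's token count" based on input characters.
fn usage_from(details: &Details) -> Usage {
    Usage {
        prompt_tokens: details.prompt_token_count,
        completion_tokens: details.generated_tokens,
        total_tokens: details.prompt_token_count + details.generated_tokens,
    }
}

fn main() {
    let valid_request = ValidGenerateRequest { input_length: 12 };
    let details = Details {
        prompt_token_count: valid_request.input_length,
        generated_tokens: 30,
    };
    let usage = usage_from(&details);
    assert_eq!(usage.prompt_tokens, 12);
    assert_eq!(usage.total_tokens, 42);
}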