fix: add prompt_token_count to InferResponse for chat completions

drbh 2024-01-09 13:04:29 -05:00
parent adad67e3d0
commit fba1953eb6
4 changed files with 31 additions and 27 deletions

View File

@@ -1,7 +1,7 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::HubTokenizerConfig;
-use crate::{ChatRequest, GenerateRequest, PrefillToken};
+use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
use crate::{Entry, Queue, Token};
use futures::future::try_join_all;
use nohash_hasher::IntMap;
@@ -14,7 +14,7 @@ use text_generation_client::{
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
@@ -92,13 +92,7 @@ impl Infer {
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
-) -> Result<
-(
-OwnedSemaphorePermit,
-UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-),
-InferError,
-> {
+) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
.clone()
@@ -122,7 +116,7 @@ impl Infer {
// Append the request to the queue
self.queue.append(Entry {
-request: valid_request,
+request: valid_request.clone(),
response_tx,
span: Span::current(),
temp_span: None,
@@ -135,7 +129,11 @@ impl Infer {
self.shared.batching_task.notify_one();
// Return stream
-Ok((permit, UnboundedReceiverStream::new(response_rx)))
+Ok((
+permit,
+valid_request,
+UnboundedReceiverStream::new(response_rx),
+))
}
/// Apply the chat template to the chat request
@@ -155,7 +153,7 @@ impl Infer {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives
-let (_permit, mut stream) = self.generate_stream(request).await?;
+let (_permit, valid_request, mut stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
@@ -208,6 +206,7 @@ impl Infer {
(result_generated_text, result_queued, result_start)
{
Ok(InferResponse {
+prompt_token_count: valid_request.input_length,
prefill: result_prefill,
tokens: result_tokens,
generated_text,
@@ -649,6 +648,7 @@ pub(crate) enum InferStreamResponse {
#[derive(Debug)]
pub(crate) struct InferResponse {
+pub(crate) prompt_token_count: u32,
pub(crate) prefill: Vec<PrefillToken>,
pub(crate) tokens: Vec<Token>,
pub(crate) generated_text: GeneratedText,

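For context on the hunks above: generate_stream now returns the GenerateStreamResponse alias (defined in the next file's diff), a three-element tuple of the semaphore permit, the validated request, and the response stream. Below is a minimal, self-contained sketch of that shape; the stand-in structs and the generate_stream_sketch function are hypothetical simplifications rather than the crate's real internals, and the example assumes the tokio and tokio-stream crates. Only the tuple layout and the names GenerateStreamResponse, ValidGenerateRequest, and input_length come from the diff.

```rust
use std::sync::Arc;

use tokio::sync::{mpsc, OwnedSemaphorePermit, Semaphore};
use tokio_stream::wrappers::UnboundedReceiverStream;

// Hypothetical stand-ins, trimmed to the fields the diff touches.
#[derive(Clone, Debug)]
struct ValidGenerateRequest {
    inputs: String,
    input_length: u32,
}

#[derive(Debug)]
enum InferStreamResponse {
    Token(String),
}

#[derive(Debug)]
enum InferError {
    Overloaded,
}

// Mirrors the shape of the GenerateStreamResponse alias added in the next file's diff.
type GenerateStreamResponse = (
    OwnedSemaphorePermit,
    ValidGenerateRequest,
    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
);

async fn generate_stream_sketch(
    semaphore: Arc<Semaphore>,
) -> Result<GenerateStreamResponse, InferError> {
    // Keep the owned permit alive for as long as the caller holds the stream,
    // mirroring the concurrency limit in the real generate_stream.
    let permit = semaphore
        .acquire_owned()
        .await
        .map_err(|_| InferError::Overloaded)?;

    let valid_request = ValidGenerateRequest {
        inputs: "Hello".to_string(),
        input_length: 1,
    };

    let (response_tx, response_rx) = mpsc::unbounded_channel();
    response_tx
        .send(Ok(InferStreamResponse::Token(" world".to_string())))
        .ok();

    Ok((permit, valid_request, UnboundedReceiverStream::new(response_rx)))
}

#[tokio::main]
async fn main() -> Result<(), InferError> {
    let semaphore = Arc::new(Semaphore::new(1));
    // Callers now destructure three elements; input_length is what ends up in
    // InferResponse::prompt_token_count further downstream.
    let (_permit, valid_request, _stream) = generate_stream_sketch(semaphore).await?;
    println!(
        "inputs = {:?}, prompt_token_count = {}",
        valid_request.inputs, valid_request.input_length
    );
    Ok(())
}
```

Keeping the OwnedSemaphorePermit inside the returned tuple is what holds the concurrency slot for as long as the caller consumes the stream, which is why the permit travels in the return value instead of being dropped inside generate_stream.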
View File

@@ -5,12 +5,22 @@ mod queue;
pub mod server;
mod validation;
-use infer::Infer;
+use crate::validation::ValidGenerateRequest;
+use infer::{Infer, InferError, InferStreamResponse};
use queue::{Entry, Queue};
use serde::{Deserialize, Serialize};
+use tokio::sync::OwnedSemaphorePermit;
+use tokio_stream::wrappers::UnboundedReceiverStream;
use utoipa::ToSchema;
use validation::Validation;
+/// Type alias for generation responses
+pub(crate) type GenerateStreamResponse = (
+OwnedSemaphorePermit,
+ValidGenerateRequest,
+UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
+);
/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {
@@ -214,12 +224,7 @@ pub(crate) struct Usage {
}
impl ChatCompletion {
-pub(crate) fn new(
-ouput: String,
-created: u64,
-details: Details,
-prompt_character_count: u32,
-) -> Self {
+pub(crate) fn new(ouput: String, created: u64, details: Details) -> Self {
Self {
id: "".to_string(),
object: "text_completion".to_string(),
@@ -236,9 +241,9 @@ impl ChatCompletion {
finish_reason: None,
}],
usage: Usage {
-prompt_tokens: prompt_character_count,
+prompt_tokens: details.prompt_token_count,
completion_tokens: details.generated_tokens,
-total_tokens: prompt_character_count + details.generated_tokens,
+total_tokens: details.prompt_token_count + details.generated_tokens,
},
}
}
@@ -463,6 +468,8 @@ pub(crate) struct Details {
pub best_of_sequences: Option<Vec<BestOfSequence>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
+#[schema(example = 1)]
+pub prompt_token_count: u32,
}
#[derive(Serialize, ToSchema)]

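The Usage hunk above is the heart of the commit: prompt_tokens is now read from the tokenizer-derived prompt_token_count carried on Details, instead of a separately passed character count. Below is a self-contained sketch of that accounting, with simplified stand-ins for Details and Usage; the field names follow the diff, the numbers are made up.

```rust
// Simplified stand-ins for the router's Details and Usage structs.
struct Details {
    generated_tokens: u32,
    prompt_token_count: u32,
}

struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

fn usage_from_details(details: &Details) -> Usage {
    Usage {
        // Previously a character-count approximation was passed in separately;
        // now the tokenizer-derived count travels on Details itself.
        prompt_tokens: details.prompt_token_count,
        completion_tokens: details.generated_tokens,
        total_tokens: details.prompt_token_count + details.generated_tokens,
    }
}

fn main() {
    let details = Details {
        generated_tokens: 12,
        prompt_token_count: 7,
    };
    let usage = usage_from_details(&details);
    assert_eq!(usage.total_tokens, 19);
    println!(
        "prompt={} completion={} total={}",
        usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
    );
}
```

Because the count rides on Details, the same number surfaces both in the generate endpoint's details object and in the chat completion's usage block, as the server changes in the next file show.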
View File

@@ -207,6 +207,7 @@ async fn generate(
seed: response.generated_text.seed,
best_of_sequences,
top_tokens: response.top_tokens,
+prompt_token_count: response.prompt_token_count,
})
}
false => None,
@@ -395,7 +396,7 @@ async fn generate_stream_internal(
} else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
-Ok((_permit, mut response_stream)) => {
+Ok((_permit, _valid_request, mut response_stream)) => {
let mut index = 0;
// Server-Sent Event stream
while let Some(response) = response_stream.next().await {
@@ -580,9 +581,6 @@ async fn chat_completions(
}
};
-// poor man's token count (assumes that each character is a token)
-let prompt_character_count: u32 = inputs.chars().count().try_into().unwrap_or_default();
// build the request passing some parameters
let generate_request = GenerateRequest {
inputs: inputs.to_string(),
@@ -654,7 +652,6 @@ async fn chat_completions(
generation.generated_text,
current_time,
generation.details.unwrap(),
-prompt_character_count,
);
// wrap generation inside a Vec to match api-inference

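The lines removed above were the character-count approximation; after this commit the prompt count comes from the validated request's input_length, which the tokenizer computes during validation. A small illustration of why the heuristic was worth replacing follows; the figure of roughly 7 tokens is an assumption about a typical BPE tokenizer, not something measured here.

```rust
fn main() {
    let inputs = "Hello, how are you today?";

    // Old heuristic from the removed lines: one "token" per character.
    let prompt_character_count: u32 = inputs.chars().count().try_into().unwrap_or_default();
    assert_eq!(prompt_character_count, 25);

    // A BPE tokenizer would usually split this prompt into roughly 7 tokens
    // (an assumed figure, not measured here); the commit switches to the
    // validated request's input_length, which the real tokenizer produces.
    println!("character-count heuristic: {prompt_character_count}");
}
```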
View File

@@ -376,7 +376,7 @@ type TokenizerRequest = (
Span,
);
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
pub inputs: String,
pub input_length: u32,
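Finally, the Clone derive added above exists because the validated request now has two consumers: one copy is moved into the queue Entry (the valid_request.clone() hunk in the first file) while the original is returned to the caller so its input_length can populate prompt_token_count. A minimal sketch of that ownership pattern, using a hypothetical trimmed-down ValidGenerateRequest and Entry:

```rust
// Hypothetical, trimmed-down stand-ins: just enough to show why
// ValidGenerateRequest now needs Clone.
#[derive(Clone)]
struct ValidGenerateRequest {
    input_length: u32,
}

struct Entry {
    request: ValidGenerateRequest,
}

fn main() {
    let valid_request = ValidGenerateRequest { input_length: 42 };

    // One copy is moved into the queue entry...
    let entry = Entry {
        request: valid_request.clone(),
    };

    // ...while the original is handed back to the caller, whose input_length
    // becomes InferResponse::prompt_token_count.
    let prompt_token_count = valid_request.input_length;

    assert_eq!(entry.request.input_length, prompt_token_count);
    println!("prompt_token_count = {prompt_token_count}");
}
```

The clone is small next to the cost of inference, and it avoids re-tokenizing the prompt just to report a count.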