Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
fix: add prompt_token_count to InferResponse for chat completions
parent adad67e3d0
commit fba1953eb6
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -1,7 +1,7 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::HubTokenizerConfig;
-use crate::{ChatRequest, GenerateRequest, PrefillToken};
+use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
 use crate::{Entry, Queue, Token};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
@@ -14,7 +14,7 @@ use text_generation_client::{
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -92,13 +92,7 @@ impl Infer {
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
-    ) -> Result<
-        (
-            OwnedSemaphorePermit,
-            UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-        ),
-        InferError,
-    > {
+    ) -> Result<GenerateStreamResponse, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let permit = self
             .clone()
@@ -122,7 +116,7 @@ impl Infer {
 
         // Append the request to the queue
         self.queue.append(Entry {
-            request: valid_request,
+            request: valid_request.clone(),
             response_tx,
             span: Span::current(),
             temp_span: None,
@@ -135,7 +129,11 @@ impl Infer {
         self.shared.batching_task.notify_one();
 
         // Return stream
-        Ok((permit, UnboundedReceiverStream::new(response_rx)))
+        Ok((
+            permit,
+            valid_request,
+            UnboundedReceiverStream::new(response_rx),
+        ))
     }
 
     /// Apply the chat template to the chat request
@@ -155,7 +153,7 @@ impl Infer {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
 
         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, mut stream) = self.generate_stream(request).await?;
+        let (_permit, valid_request, mut stream) = self.generate_stream(request).await?;
 
         // Return values
         let mut result_prefill = Vec::new();
@@ -208,6 +206,7 @@ impl Infer {
             (result_generated_text, result_queued, result_start)
         {
             Ok(InferResponse {
+                prompt_token_count: valid_request.input_length,
                 prefill: result_prefill,
                 tokens: result_tokens,
                 generated_text,
@@ -649,6 +648,7 @@ pub(crate) enum InferStreamResponse {
 
 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) prompt_token_count: u32,
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
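Below is a standalone sketch, not part of the commit, of the pattern the infer.rs change introduces: generate_stream now returns the validated request alongside the semaphore permit and the response stream, so the caller can read the tokenized prompt length without re-tokenizing. All names and types here are simplified stand-ins for the router's real ones, and the whitespace "tokenizer" is purely illustrative.

use std::sync::Arc;
use tokio::sync::{mpsc, OwnedSemaphorePermit, Semaphore};
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

// Simplified stand-in for the router's validated request.
#[derive(Debug, Clone)]
struct ValidGenerateRequest {
    inputs: String,
    input_length: u32,
}

// Simplified stand-in for the new GenerateStreamResponse alias:
// permit, validated request, and the token stream travel together.
type GenerateStreamResponse = (
    OwnedSemaphorePermit,
    ValidGenerateRequest,
    UnboundedReceiverStream<Result<String, String>>,
);

async fn generate_stream(semaphore: Arc<Semaphore>, inputs: &str) -> GenerateStreamResponse {
    // Limit concurrency, as the real router does with its semaphore.
    let permit = semaphore.acquire_owned().await.expect("semaphore closed");
    let valid_request = ValidGenerateRequest {
        inputs: inputs.to_string(),
        input_length: inputs.split_whitespace().count() as u32, // fake "tokenizer"
    };
    let (response_tx, response_rx) = mpsc::unbounded_channel();
    // The queue keeps its own copy (hence the Clone derive on the request);
    // the caller receives the original back in the return tuple.
    let queued_copy = valid_request.clone();
    response_tx
        .send(Ok(format!("generated text for: {}", queued_copy.inputs)))
        .unwrap();
    (permit, valid_request, UnboundedReceiverStream::new(response_rx))
}

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(1));
    let (_permit, valid_request, mut stream) =
        generate_stream(semaphore, "What is Deep Learning?").await;
    // The prompt token count is now available to the caller directly.
    println!("prompt_token_count = {}", valid_request.input_length);
    while let Some(item) = stream.next().await {
        println!("{item:?}");
    }
}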
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -5,12 +5,22 @@ mod queue;
 pub mod server;
 mod validation;
 
-use infer::Infer;
+use crate::validation::ValidGenerateRequest;
+use infer::{Infer, InferError, InferStreamResponse};
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
+use tokio::sync::OwnedSemaphorePermit;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use utoipa::ToSchema;
 use validation::Validation;
 
+/// Type alias for generation responses
+pub(crate) type GenerateStreamResponse = (
+    OwnedSemaphorePermit,
+    ValidGenerateRequest,
+    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
+);
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
@@ -214,12 +224,7 @@ pub(crate) struct Usage {
 }
 
 impl ChatCompletion {
-    pub(crate) fn new(
-        ouput: String,
-        created: u64,
-        details: Details,
-        prompt_character_count: u32,
-    ) -> Self {
+    pub(crate) fn new(ouput: String, created: u64, details: Details) -> Self {
         Self {
             id: "".to_string(),
             object: "text_completion".to_string(),
@@ -236,9 +241,9 @@ impl ChatCompletion {
                 finish_reason: None,
             }],
             usage: Usage {
-                prompt_tokens: prompt_character_count,
+                prompt_tokens: details.prompt_token_count,
                 completion_tokens: details.generated_tokens,
-                total_tokens: prompt_character_count + details.generated_tokens,
+                total_tokens: details.prompt_token_count + details.generated_tokens,
             },
         }
     }
@@ -463,6 +468,8 @@ pub(crate) struct Details {
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
+    #[schema(example = 1)]
+    pub prompt_token_count: u32,
 }
 
 #[derive(Serialize, ToSchema)]
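A minimal standalone sketch (field subsets only, not the router's actual structs) of the usage arithmetic the lib.rs change switches to: the OpenAI-style usage block is now derived from the real prompt token count carried in Details, rather than from the per-character estimate that was previously passed in.

// Simplified stand-ins for the fields involved in the usage calculation.
struct Details {
    prompt_token_count: u32,
    generated_tokens: u32,
}

struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

fn usage_from(details: &Details) -> Usage {
    Usage {
        // Tokenizer-based prompt count instead of a character count.
        prompt_tokens: details.prompt_token_count,
        completion_tokens: details.generated_tokens,
        total_tokens: details.prompt_token_count + details.generated_tokens,
    }
}

fn main() {
    let details = Details { prompt_token_count: 12, generated_tokens: 30 };
    let usage = usage_from(&details);
    assert_eq!(usage.prompt_tokens + usage.completion_tokens, usage.total_tokens);
    println!(
        "prompt={} completion={} total={}",
        usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
    );
}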
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -207,6 +207,7 @@ async fn generate(
                 seed: response.generated_text.seed,
                 best_of_sequences,
                 top_tokens: response.top_tokens,
+                prompt_token_count: response.prompt_token_count,
             })
         }
         false => None,
@@ -395,7 +396,7 @@ async fn generate_stream_internal(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, mut response_stream)) => {
+            Ok((_permit, _valid_request, mut response_stream)) => {
                 let mut index = 0;
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
@@ -580,9 +581,6 @@ async fn chat_completions(
        }
    };
 
-    // poor man's token count (assumes that each character is a token)
-    let prompt_character_count: u32 = inputs.chars().count().try_into().unwrap_or_default();
-
     // build the request passing some parameters
     let generate_request = GenerateRequest {
         inputs: inputs.to_string(),
@@ -654,7 +652,6 @@ async fn chat_completions(
            generation.generated_text,
            current_time,
            generation.details.unwrap(),
-            prompt_character_count,
        );
 
        // wrap generation inside a Vec to match api-inference
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -376,7 +376,7 @@ type TokenizerRequest = (
     Span,
 );
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
     pub input_length: u32,
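A short standalone sketch of why ValidGenerateRequest now derives Clone: one copy of the validated request is stored in the queue entry, while the original travels back to the caller so it can report input_length as the prompt token count. The Entry and queue here are simplified stand-ins, not the router's batching queue.

// Simplified stand-ins; the real Entry and queue carry more state.
#[derive(Debug, Clone)]
struct ValidGenerateRequest {
    inputs: String,
    input_length: u32,
}

struct Entry {
    request: ValidGenerateRequest,
}

fn main() {
    let valid_request = ValidGenerateRequest {
        inputs: "What is Deep Learning?".to_string(),
        input_length: 6, // illustrative count, not a real tokenizer result
    };

    let mut queue: Vec<Entry> = Vec::new();
    // Without the Clone derive, this push would move `valid_request`,
    // and it could not also be handed back to the caller afterwards.
    queue.push(Entry { request: valid_request.clone() });

    println!("queued entries: {}", queue.len());
    println!("prompt_token_count reported to caller: {}", valid_request.input_length);
}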