mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00

fix: add prompt_token_count to InferResponse for chat completions

parent adad67e3d0
commit fba1953eb6
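
In short: the chat completions endpoint previously estimated prompt_tokens with a character count; this commit threads the validated request's input_length (the token count produced during request validation) through generate_stream into InferResponse and Details, so Usage reports real numbers. Below is a minimal, self-contained sketch of that flow; the stub types are illustrative stand-ins, not the router's real definitions.

    // Sketch of the data flow this commit introduces. Stub types only;
    // in the real router, input_length is computed by the tokenizer
    // during request validation.
    struct ValidGenerateRequest {
        input_length: u32,
    }

    struct Details {
        prompt_token_count: u32,
        generated_tokens: u32,
    }

    struct Usage {
        prompt_tokens: u32,
        completion_tokens: u32,
        total_tokens: u32,
    }

    fn usage_from(details: &Details) -> Usage {
        // Exact counts replace the old inputs.chars().count() heuristic.
        Usage {
            prompt_tokens: details.prompt_token_count,
            completion_tokens: details.generated_tokens,
            total_tokens: details.prompt_token_count + details.generated_tokens,
        }
    }

    fn main() {
        let valid_request = ValidGenerateRequest { input_length: 9 };
        let details = Details {
            prompt_token_count: valid_request.input_length, // as in InferResponse
            generated_tokens: 42,
        };
        let usage = usage_from(&details);
        assert_eq!(usage.total_tokens, 51);
    }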
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -1,7 +1,7 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::HubTokenizerConfig;
-use crate::{ChatRequest, GenerateRequest, PrefillToken};
+use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
 use crate::{Entry, Queue, Token};
 use futures::future::try_join_all;
 use nohash_hasher::IntMap;
@@ -14,7 +14,7 @@ use text_generation_client::{
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -92,13 +92,7 @@ impl Infer {
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
-    ) -> Result<
-        (
-            OwnedSemaphorePermit,
-            UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-        ),
-        InferError,
-    > {
+    ) -> Result<GenerateStreamResponse, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let permit = self
             .clone()
@@ -122,7 +116,7 @@ impl Infer {

         // Append the request to the queue
         self.queue.append(Entry {
-            request: valid_request,
+            request: valid_request.clone(),
             response_tx,
             span: Span::current(),
             temp_span: None,
@@ -135,7 +129,11 @@ impl Infer {
         self.shared.batching_task.notify_one();

         // Return stream
-        Ok((permit, UnboundedReceiverStream::new(response_rx)))
+        Ok((
+            permit,
+            valid_request,
+            UnboundedReceiverStream::new(response_rx),
+        ))
     }

     /// Apply the chat template to the chat request
@@ -155,7 +153,7 @@ impl Infer {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);

         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, mut stream) = self.generate_stream(request).await?;
+        let (_permit, valid_request, mut stream) = self.generate_stream(request).await?;

         // Return values
         let mut result_prefill = Vec::new();
@@ -208,6 +206,7 @@ impl Infer {
             (result_generated_text, result_queued, result_start)
         {
             Ok(InferResponse {
+                prompt_token_count: valid_request.input_length,
                 prefill: result_prefill,
                 tokens: result_tokens,
                 generated_text,
@@ -649,6 +648,7 @@ pub(crate) enum InferStreamResponse {

 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) prompt_token_count: u32,
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
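Aside: generate_stream now returns the three-element GenerateStreamResponse tuple rather than an inline type. The stream half of that tuple is ordinary tokio plumbing; here is a minimal, runnable sketch of the same pattern (assuming tokio with the "full" feature set and tokio-stream as dependencies, with u32 standing in for InferStreamResponse).

    // Self-contained sketch: an unbounded mpsc channel wrapped in an
    // UnboundedReceiverStream that the caller consumes with next(),
    // mirroring how generate_stream hands response_rx to its caller.
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_stream::StreamExt;

    #[tokio::main]
    async fn main() {
        let (response_tx, response_rx) = tokio::sync::mpsc::unbounded_channel::<u32>();

        // The batching task would send tokens here as they are generated.
        response_tx.send(1).unwrap();
        response_tx.send(2).unwrap();
        drop(response_tx); // closing the sender ends the stream

        let mut stream = UnboundedReceiverStream::new(response_rx);
        while let Some(token) = stream.next().await {
            println!("got token {token}");
        }
    }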
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -5,12 +5,22 @@ mod queue;
 pub mod server;
 mod validation;

-use infer::Infer;
+use crate::validation::ValidGenerateRequest;
+use infer::{Infer, InferError, InferStreamResponse};
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
+use tokio::sync::OwnedSemaphorePermit;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use utoipa::ToSchema;
 use validation::Validation;

+/// Type alias for generation responses
+pub(crate) type GenerateStreamResponse = (
+    OwnedSemaphorePermit,
+    ValidGenerateRequest,
+    UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
+);
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
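The alias added above exists to keep signatures readable now that the return tuple has grown to three elements. An illustrative sketch of the same pattern with stand-in types (none of these are the router's real definitions):

    // Illustrative only: stand-in types, not the router's real ones.
    type Permit = ();
    type Stream = Vec<u8>;

    #[derive(Debug, Clone)]
    struct ValidRequest {
        input_length: u32,
    }

    // Without the alias, this tuple would be spelled out at every call site.
    type GenerateStreamResponse = (Permit, ValidRequest, Stream);

    fn generate_stream() -> Result<GenerateStreamResponse, String> {
        Ok(((), ValidRequest { input_length: 9 }, vec![1, 2, 3]))
    }

    fn main() {
        // Callers that don't need the request ignore it with a `_` binding,
        // exactly as generate_stream_internal does further below.
        let (_permit, valid_request, stream) = generate_stream().unwrap();
        assert_eq!(valid_request.input_length, 9);
        assert_eq!(stream.len(), 3);
    }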
@@ -214,12 +224,7 @@ pub(crate) struct Usage {
 }

 impl ChatCompletion {
-    pub(crate) fn new(
-        ouput: String,
-        created: u64,
-        details: Details,
-        prompt_character_count: u32,
-    ) -> Self {
+    pub(crate) fn new(ouput: String, created: u64, details: Details) -> Self {
         Self {
             id: "".to_string(),
             object: "text_completion".to_string(),
@@ -236,9 +241,9 @@ impl ChatCompletion {
                 finish_reason: None,
             }],
             usage: Usage {
-                prompt_tokens: prompt_character_count,
+                prompt_tokens: details.prompt_token_count,
                 completion_tokens: details.generated_tokens,
-                total_tokens: prompt_character_count + details.generated_tokens,
+                total_tokens: details.prompt_token_count + details.generated_tokens,
             },
         }
     }
@@ -463,6 +468,8 @@ pub(crate) struct Details {
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
+    #[schema(example = 1)]
+    pub prompt_token_count: u32,
 }

 #[derive(Serialize, ToSchema)]
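With prompt_token_count flowing into Usage, the user-visible usage object in a chat completion now carries real token counts. An illustrative serialization, assuming serde (with the derive feature) and serde_json as dependencies; the field values are made up:

    use serde::Serialize;

    #[derive(Serialize)]
    struct Usage {
        prompt_tokens: u32,
        completion_tokens: u32,
        total_tokens: u32,
    }

    fn main() {
        let usage = Usage {
            prompt_tokens: 9,
            completion_tokens: 42,
            total_tokens: 51,
        };
        println!("{}", serde_json::to_string_pretty(&usage).unwrap());
        // {
        //   "prompt_tokens": 9,
        //   "completion_tokens": 42,
        //   "total_tokens": 51
        // }
    }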
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -207,6 +207,7 @@ async fn generate(
                 seed: response.generated_text.seed,
                 best_of_sequences,
                 top_tokens: response.top_tokens,
+                prompt_token_count: response.prompt_token_count,
             })
         }
         false => None,
@@ -395,7 +396,7 @@ async fn generate_stream_internal(
         } else {
             match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
-                Ok((_permit, mut response_stream)) => {
+                Ok((_permit, _valid_request, mut response_stream)) => {
                     let mut index = 0;
                     // Server-Sent Event stream
                     while let Some(response) = response_stream.next().await {
@@ -580,9 +581,6 @@ async fn chat_completions(
         }
     };

-    // poor man's token count (assumes that each character is a token)
-    let prompt_character_count: u32 = inputs.chars().count().try_into().unwrap_or_default();
-
     // build the request passing some parameters
     let generate_request = GenerateRequest {
         inputs: inputs.to_string(),
@@ -654,7 +652,6 @@ async fn chat_completions(
         generation.generated_text,
         current_time,
         generation.details.unwrap(),
-        prompt_character_count,
     );

     // wrap generation inside a Vec to match api-inference
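The removed "poor man's token count" shows why the change matters: characters and tokens diverge, often badly. A small sketch contrasting the two approaches (the token figure in the comment is illustrative; the real number comes from the tokenizer during validation):

    fn main() {
        let inputs = "Hello, how are you today?";

        // Old approach, removed by this commit: assume one character is one token.
        let prompt_character_count: u32 = inputs.chars().count().try_into().unwrap_or_default();
        println!("character count: {prompt_character_count}"); // 25

        // New approach (sketched): validation runs the actual tokenizer and
        // records input_length; a subword tokenizer would typically land
        // well below the character count here, e.g. around 7-8 tokens.
    }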
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -376,7 +376,7 @@ type TokenizerRequest = (
     Span,
 );

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
     pub input_length: u32,
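Finally, the new Clone derive is what lets infer.rs above push valid_request.clone() into the queue while handing the original back through the return tuple. A stub sketch of that ownership pattern (a plain Vec stands in for the router's queue):

    #[derive(Debug, Clone)]
    struct ValidGenerateRequest {
        inputs: String,
        input_length: u32,
    }

    fn main() {
        let valid_request = ValidGenerateRequest {
            inputs: "Hello".to_string(),
            input_length: 3,
        };

        let mut queue: Vec<ValidGenerateRequest> = Vec::new();
        queue.push(valid_request.clone()); // the queue keeps its own copy

        // The original stays available for the response metadata.
        assert_eq!(valid_request.input_length, 3);
        println!("{:?}", queue[0]);
    }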