Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
Fix the reported prompt token count by using the actual input length output by the validation workers.
parent 5c8cc964fa
commit 8e0c538a18
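The diff below threads the validated input length from `Infer::generate_stream` (which now returns it alongside the semaphore permit and the response stream) through `InferResponse` and into the `x-prompt-tokens` response header. For orientation, a minimal sketch of the new return shape, using hypothetical stand-in types in place of the real `OwnedSemaphorePermit` and `UnboundedReceiverStream`:

    // Sketch only: Permit and ResponseStream are hypothetical stand-ins for
    // tokio's OwnedSemaphorePermit and UnboundedReceiverStream<...>.
    struct Permit;
    struct ResponseStream;

    struct ValidRequest {
        input_length: u32, // measured by the validation workers
    }

    // Before: -> (Permit, ResponseStream)
    // After:  -> (Permit, u32, ResponseStream)
    fn generate_stream(valid_request: ValidRequest) -> (Permit, u32, ResponseStream) {
        let input_length = valid_request.input_length;
        (Permit, input_length, ResponseStream)
    }

    fn main() {
        let (_permit, input_length, _stream) =
            generate_stream(ValidRequest { input_length: 42 });
        // The router now reports this value in x-prompt-tokens instead of
        // counting prefill tokens after generation.
        assert_eq!(input_length, 42);
    }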
@@ -90,6 +90,7 @@ impl Infer {
     ) -> Result<
         (
             OwnedSemaphorePermit,
+            u32,
             UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
         ),
         InferError,
@@ -114,6 +115,7 @@ impl Infer {
 
         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();
+        let input_length = valid_request.input_length;
 
         // Append the request to the queue
         self.queue.append(Entry {
@@ -130,7 +132,11 @@ impl Infer {
         self.shared.batching_task.notify_one();
 
         // Return stream
-        Ok((permit, UnboundedReceiverStream::new(response_rx)))
+        Ok((
+            permit,
+            input_length,
+            UnboundedReceiverStream::new(response_rx),
+        ))
     }
 
     /// Add a new request to the queue and return a InferResponse
@@ -142,7 +148,7 @@ impl Infer {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
 
         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, mut stream) = self.generate_stream(request).await?;
+        let (_permit, input_length, mut stream) = self.generate_stream(request).await?;
 
         // Return values
         let mut result_prefill = Vec::new();
@@ -196,6 +202,7 @@ impl Infer {
         {
             Ok(InferResponse {
                 prefill: result_prefill,
+                input_length,
                 tokens: result_tokens,
                 generated_text,
                 queued,
@@ -636,6 +643,7 @@ pub(crate) enum InferStreamResponse {
 
 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) input_length: u32,
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
@@ -170,7 +170,7 @@ async fn generate(
     };
 
     // Token details
-    let prompt_tokens = response.prefill.len();
+    let input_length = response.input_length;
     let details = match details {
         true => {
             // convert best_of_responses
@@ -258,7 +258,7 @@ async fn generate(
         "x-time-per-token",
         time_per_token.as_millis().to_string().parse().unwrap(),
     );
-    headers.insert("x-prompt-tokens", prompt_tokens.into());
+    headers.insert("x-prompt-tokens", input_length.into());
     headers.insert(
         "x-generated-tokens",
         response.generated_text.generated_tokens.into(),
@@ -384,7 +384,7 @@ async fn generate_stream(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, mut response_stream)) => {
+            Ok((_permit, _input_length, mut response_stream)) => {
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
                     match response {
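One practical effect: `x-prompt-tokens` now reports the input length measured by the validation workers rather than `response.prefill.len()`, which can undercount when prefill token details are not returned. A hedged smoke test, assuming a TGI router listening on localhost:3000 (adjust the URL for your deployment) and the reqwest, tokio, and serde_json crates:

    use serde_json::json;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let client = reqwest::Client::new();
        let resp = client
            .post("http://localhost:3000/generate") // assumed local router address
            .json(&json!({
                "inputs": "What is Deep Learning?",
                "parameters": { "max_new_tokens": 16 }
            }))
            .send()
            .await?;

        // After this commit, the header carries the validated input length.
        if let Some(value) = resp.headers().get("x-prompt-tokens") {
            println!("x-prompt-tokens: {}", value.to_str()?);
        }
        Ok(())
    }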