Mirror of https://github.com/huggingface/text-generation-inference.git
Fix by using the actual value as output by the validation workers.
commit 8e0c538a18
parent 5c8cc964fa
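In outline: `generate_stream` now also returns the `input_length` computed by the validation workers, `InferResponse` carries it, and the `generate` handler reports it in the `x-prompt-tokens` header instead of `response.prefill.len()` (which, per the commit message, was not the actual value). A minimal, self-contained sketch of that flow follows; `ValidRequest`, `Response`, and `prompt_tokens_header` are simplified stand-ins for illustration, not the real router types or functions.

// Sketch only: hypothetical stand-ins for the router structs, not the real API.
struct ValidRequest {
    // Prompt length in tokens, as computed by the validation workers.
    input_length: u32,
}

struct Response {
    // Threaded through generate_stream -> InferResponse by this commit.
    input_length: u32,
    // Prefill token details; per the commit, not the value to report.
    prefill: Vec<String>,
}

fn prompt_tokens_header(response: &Response) -> u32 {
    // Before: response.prefill.len() was reported.
    // After: the validated length is reported.
    response.input_length
}

fn main() {
    let valid = ValidRequest { input_length: 11 };
    let response = Response {
        input_length: valid.input_length,
        prefill: Vec::new(),
    };
    println!("prefill.len() = {} (old header source)", response.prefill.len());
    println!("x-prompt-tokens: {}", prompt_tokens_header(&response));
    assert_eq!(prompt_tokens_header(&response), 11);
}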
@@ -90,6 +90,7 @@ impl Infer {
     ) -> Result<
         (
             OwnedSemaphorePermit,
+            u32,
             UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
         ),
         InferError,
@@ -114,6 +115,7 @@ impl Infer {
 
         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();
+        let input_length = valid_request.input_length;
 
         // Append the request to the queue
         self.queue.append(Entry {
@@ -130,7 +132,11 @@ impl Infer {
         self.shared.batching_task.notify_one();
 
         // Return stream
-        Ok((permit, UnboundedReceiverStream::new(response_rx)))
+        Ok((
+            permit,
+            input_length,
+            UnboundedReceiverStream::new(response_rx),
+        ))
     }
 
     /// Add a new request to the queue and return a InferResponse
@@ -142,7 +148,7 @@ impl Infer {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
 
         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, mut stream) = self.generate_stream(request).await?;
+        let (_permit, input_length, mut stream) = self.generate_stream(request).await?;
 
         // Return values
         let mut result_prefill = Vec::new();
@@ -196,6 +202,7 @@ impl Infer {
         {
             Ok(InferResponse {
                 prefill: result_prefill,
+                input_length,
                 tokens: result_tokens,
                 generated_text,
                 queued,
@@ -636,6 +643,7 @@ pub(crate) enum InferStreamResponse {
 
 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) input_length: u32,
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,

@@ -170,7 +170,7 @@ async fn generate(
     };
 
     // Token details
-    let prompt_tokens = response.prefill.len();
+    let input_length = response.input_length;
     let details = match details {
         true => {
             // convert best_of_responses
@@ -258,7 +258,7 @@ async fn generate(
         "x-time-per-token",
         time_per_token.as_millis().to_string().parse().unwrap(),
     );
-    headers.insert("x-prompt-tokens", prompt_tokens.into());
+    headers.insert("x-prompt-tokens", input_length.into());
     headers.insert(
         "x-generated-tokens",
         response.generated_text.generated_tokens.into(),
@@ -384,7 +384,7 @@ async fn generate_stream(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, mut response_stream)) => {
+            Ok((_permit, _input_length, mut response_stream)) => {
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
                     match response {
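For completeness, a small sketch of reading the corrected header back, assuming the `http` crate's `HeaderMap`/`HeaderValue` (presumably the types behind the router's `headers.insert` calls); the value is made up:

use http::{HeaderMap, HeaderValue};

fn main() {
    // Stand-in for the value produced by the validation workers.
    let input_length: u32 = 11;

    let mut headers = HeaderMap::new();
    // Mirrors the diff's `headers.insert("x-prompt-tokens", input_length.into());`
    headers.insert("x-prompt-tokens", HeaderValue::from(input_length));

    // A client or test reading the header back as an integer.
    let prompt_tokens: u32 = headers["x-prompt-tokens"]
        .to_str()
        .expect("ASCII header value")
        .parse()
        .expect("integer header value");
    assert_eq!(prompt_tokens, input_length);
}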