Fix by using the actual value as output by the validation workers.
commit 8e0c538a18
parent 5c8cc964fa
Author: Nicolas Patry
Date: 2024-01-11 15:18:58 +00:00

2 changed files with 13 additions and 5 deletions
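
The change threads the prompt length measured by the validation workers from `Infer::generate_stream` through `InferResponse` into the HTTP response headers, instead of deriving it from `response.prefill.len()`. A minimal, hedged sketch of the bug being fixed, using simplified stand-in types rather than the router's real ones:

// Hedged illustration of the fix (stand-in types, not the router's):
// the `x-prompt-tokens` header was derived from `prefill.len()`, which
// can disagree with the length the validation workers measured, e.g.
// when no prefill tokens are returned for the request.
struct InferResponse {
    input_length: u32, // measured during validation
    prefill: Vec<u32>, // may be empty depending on the request
}

fn prompt_tokens_before(resp: &InferResponse) -> usize {
    resp.prefill.len() // what the header used to report
}

fn prompt_tokens_after(resp: &InferResponse) -> u32 {
    resp.input_length // the validated value this commit switches to
}

fn main() {
    let resp = InferResponse { input_length: 5, prefill: Vec::new() };
    assert_eq!(prompt_tokens_before(&resp), 0); // undercounts
    assert_eq!(prompt_tokens_after(&resp), 5); // correct
}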

File 1 of 2

@@ -90,6 +90,7 @@ impl Infer {
     ) -> Result<
         (
             OwnedSemaphorePermit,
+            u32,
             UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
         ),
         InferError,
@@ -114,6 +115,7 @@ impl Infer {
         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();
+        let input_length = valid_request.input_length;

         // Append the request to the queue
         self.queue.append(Entry {
@@ -130,7 +132,11 @@ impl Infer {
         self.shared.batching_task.notify_one();

         // Return stream
-        Ok((permit, UnboundedReceiverStream::new(response_rx)))
+        Ok((
+            permit,
+            input_length,
+            UnboundedReceiverStream::new(response_rx),
+        ))
     }

     /// Add a new request to the queue and return a InferResponse
@@ -142,7 +148,7 @@ impl Infer {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);

         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, mut stream) = self.generate_stream(request).await?;
+        let (_permit, input_length, mut stream) = self.generate_stream(request).await?;

         // Return values
         let mut result_prefill = Vec::new();
@@ -196,6 +202,7 @@ impl Infer {
         {
             Ok(InferResponse {
                 prefill: result_prefill,
+                input_length,
                 tokens: result_tokens,
                 generated_text,
                 queued,
@@ -636,6 +643,7 @@ pub(crate) enum InferStreamResponse {
 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) input_length: u32,
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
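
On the `Infer` side, `generate_stream` now returns the validated length as an extra tuple element and `generate` copies it onto `InferResponse`. A self-contained sketch of that data flow under simplified, assumed types (the real stream carries `InferStreamResponse` values alongside a semaphore permit):

// Sketch of the new data flow with stand-in types: the validated
// input_length travels alongside the stream handle and lands on the
// final response.
struct ValidRequest {
    input_length: u32,
}

struct InferResponse {
    input_length: u32,
    generated_text: String,
}

// Mirrors the extra tuple element added by this commit.
fn generate_stream(req: ValidRequest) -> (u32, Vec<String>) {
    (req.input_length, vec!["Hello".into(), " world".into()])
}

fn generate(req: ValidRequest) -> InferResponse {
    let (input_length, tokens) = generate_stream(req);
    InferResponse {
        input_length,
        generated_text: tokens.concat(),
    }
}

fn main() {
    let resp = generate(ValidRequest { input_length: 1 });
    println!("{} prompt tokens -> {:?}", resp.input_length, resp.generated_text);
}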

File 2 of 2

@@ -170,7 +170,7 @@ async fn generate(
     };

     // Token details
-    let prompt_tokens = response.prefill.len();
+    let input_length = response.input_length;
     let details = match details {
         true => {
             // convert best_of_responses
@@ -258,7 +258,7 @@ async fn generate(
         "x-time-per-token",
         time_per_token.as_millis().to_string().parse().unwrap(),
     );
-    headers.insert("x-prompt-tokens", prompt_tokens.into());
+    headers.insert("x-prompt-tokens", input_length.into());
     headers.insert(
         "x-generated-tokens",
         response.generated_text.generated_tokens.into(),
@@ -384,7 +384,7 @@ async fn generate_stream(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, mut response_stream)) => {
+            Ok((_permit, _input_length, mut response_stream)) => {
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
                     match response {
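
With the fix, `x-prompt-tokens` reflects the validated count on non-streaming responses; the streaming path still discards the value as `_input_length`. A hedged client-side check of the corrected header; the endpoint URL and payload are placeholders, and this assumes the `reqwest` (with the `json` feature), `tokio`, and `serde_json` crates:

// Hypothetical client reading the corrected header; the endpoint and
// request body are placeholders, not part of this commit.
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let resp = reqwest::Client::new()
        .post("http://localhost:3000/generate")
        .json(&serde_json::json!({ "inputs": "Hello" }))
        .send()
        .await?;
    if let Some(value) = resp.headers().get("x-prompt-tokens") {
        // Now backed by the validation workers' input_length
        println!("prompt tokens: {}", value.to_str().unwrap_or("?"));
    }
    Ok(())
}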