fix: prefer only input_length over full ValidGenerateRequest in GenerateStreamResponse

drbh 2024-01-11 10:46:55 -05:00
parent 62e6661616
commit c63551fad7
3 changed files with 5 additions and 7 deletions

router/src/infer.rs

@@ -223,7 +223,6 @@ impl Infer {
             (result_generated_text, result_queued, result_start)
         {
             Ok(InferResponse {
-                prompt_token_count: valid_request.input_length,
                 prefill: result_prefill,
                 _input_length,
                 tokens: result_tokens,
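Before this change, InferResponse carried the prompt length twice: prompt_token_count duplicated the value already stored in _input_length. A minimal sketch of the after-state, using hypothetical trimmed-down types rather than the real TGI structs:

struct InferResponse {
    _input_length: u32, // the single remaining copy of the prompt length
    generated_tokens: u32,
}

impl InferResponse {
    // Hypothetical helper: consumers read the one field instead of a duplicate.
    fn prompt_tokens(&self) -> u32 {
        self._input_length
    }
}

fn main() {
    let resp = InferResponse { _input_length: 12, generated_tokens: 30 };
    assert_eq!(resp.prompt_tokens(), 12);
    assert_eq!(resp.prompt_tokens() + resp.generated_tokens, 42);
}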

router/src/lib.rs

@@ -5,7 +5,6 @@ mod queue;
 pub mod server;
 mod validation;
 
-use crate::validation::ValidGenerateRequest;
 use infer::{Infer, InferError, InferStreamResponse};
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
@@ -17,7 +16,7 @@ use validation::Validation;
 
 /// Type alias for generation responses
 pub(crate) type GenerateStreamResponse = (
     OwnedSemaphorePermit,
-    ValidGenerateRequest,
+    u32, // input_length
     UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
 );
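With the alias slimmed down, callers that used to receive the whole ValidGenerateRequest now get only the prompt length. A hedged sketch of constructing and destructuring the new tuple (InferStreamResponse/InferError are replaced by String/std::io::Error stand-ins; requires the tokio and tokio-stream crates):

use std::sync::Arc;
use tokio::sync::{mpsc, OwnedSemaphorePermit, Semaphore};
use tokio_stream::wrappers::UnboundedReceiverStream;

// After this commit the tuple carries a plain u32 instead of the request:
type GenerateStreamResponse = (
    OwnedSemaphorePermit,
    u32, // input_length
    UnboundedReceiverStream<Result<String, std::io::Error>>,
);

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(1));
    let permit = semaphore.acquire_owned().await.unwrap();
    let (tx, rx) = mpsc::unbounded_channel::<Result<String, std::io::Error>>();
    tx.send(Ok("token".to_string())).unwrap();
    drop(tx);

    let response: GenerateStreamResponse = (permit, 12, UnboundedReceiverStream::new(rx));
    let (_permit, input_length, _stream) = response;
    assert_eq!(input_length, 12); // only the length travels, not the request
}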
@@ -233,9 +232,9 @@ impl ChatCompletion {
                 finish_reason: details.finish_reason.to_string(),
             }],
             usage: Usage {
-                prompt_tokens: details.prompt_token_count,
+                prompt_tokens: details.input_length,
                 completion_tokens: details.generated_tokens,
-                total_tokens: details.prompt_token_count + details.generated_tokens,
+                total_tokens: details.input_length + details.generated_tokens,
             },
         }
     }
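For the OpenAI-style usage block the arithmetic is untouched; only the source field is renamed. A self-contained sketch with hypothetical stand-ins for Usage and Details:

#[derive(Debug, PartialEq)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

struct Details {
    input_length: u32, // was: prompt_token_count
    generated_tokens: u32,
}

fn usage(details: &Details) -> Usage {
    Usage {
        prompt_tokens: details.input_length,
        completion_tokens: details.generated_tokens,
        total_tokens: details.input_length + details.generated_tokens,
    }
}

fn main() {
    let d = Details { input_length: 9, generated_tokens: 21 };
    assert_eq!(
        usage(&d),
        Usage { prompt_tokens: 9, completion_tokens: 21, total_tokens: 30 }
    );
}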
@@ -471,7 +470,7 @@ pub(crate) struct Details {
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
     #[schema(example = 1)]
-    pub prompt_token_count: u32,
+    pub input_length: u32,
 }
 
 #[derive(Serialize, ToSchema)]
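API clients will see the renamed field in the JSON payload. A minimal serde round-trip sketch (trimmed hypothetical struct; the real one carries utoipa schema attributes and more fields; requires serde with the derive feature plus serde_json):

use serde::Serialize;

#[derive(Serialize)]
struct Details {
    #[serde(skip_serializing_if = "Vec::is_empty")]
    top_tokens: Vec<Vec<String>>, // stand-in for Vec<Vec<Token>>
    input_length: u32,            // was: prompt_token_count
}

fn main() {
    let d = Details { top_tokens: vec![], input_length: 1 };
    // top_tokens is skipped when empty, so only the renamed field appears:
    assert_eq!(serde_json::to_string(&d).unwrap(), r#"{"input_length":1}"#);
}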

router/src/server.rs

@@ -208,7 +208,7 @@ async fn generate(
                 seed: response.generated_text.seed,
                 best_of_sequences,
                 top_tokens: response.top_tokens,
-                prompt_token_count: response.prompt_token_count,
+                input_length: response.input_length,
             })
         }
         false => None,
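At the HTTP layer the value is copied straight from the infer response when details are requested. An illustrative sketch of the surrounding match, with simplified types rather than the real handler:

struct InferResponse {
    input_length: u32,
}

struct Details {
    input_length: u32,
}

fn details(response: &InferResponse, details_requested: bool) -> Option<Details> {
    // Mirrors the diff's `true => Some(Details { .. }) / false => None` shape.
    match details_requested {
        true => Some(Details { input_length: response.input_length }),
        false => None,
    }
}

fn main() {
    let response = InferResponse { input_length: 7 };
    assert_eq!(details(&response, true).map(|d| d.input_length), Some(7));
    assert!(details(&response, false).is_none());
}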