add best of sequences to details

OlivierDehaene 2023-03-09 14:27:39 +01:00
parent 9624d4060f
commit 9f4f2fc8e3
3 changed files with 145 additions and 90 deletions


@@ -185,28 +185,35 @@ impl Infer {
         &self,
         request: GenerateRequest,
         best_of: usize,
-    ) -> Result<InferResponse, InferError> {
+    ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
         // validate best_of parameter separately
         let best_of = self.validation.validate_best_of(best_of)?;

         // create multiple generate requests
-        let infer_responses: Vec<InferResponse> =
+        let mut infer_responses: Vec<InferResponse> =
             try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;

         // get the sequence with the highest log probability per token
+        let mut max_index = 0;
         let mut max_logprob: f32 = f32::MIN;
-        let mut best_response = None;
-        for response in infer_responses {
-            // sum logprobs of the generated tokens
-            let sequence_logprob = response.tokens.iter().map(|token| token.logprob).sum();
+        for (i, response) in infer_responses.iter().enumerate() {
+            // mean logprobs of the generated tokens
+            let sequence_logprob = response
+                .tokens
+                .iter()
+                .map(|token| token.logprob)
+                .sum::<f32>()
+                / response.tokens.len() as f32;

             // set best sequence
             if sequence_logprob > max_logprob {
+                max_index = i;
                 max_logprob = sequence_logprob;
-                best_response = Some(response);
             }
         }
-        Ok(best_response.expect("best_response is None. This is a bug."))
+        let best_response = infer_responses.remove(max_index);
+        Ok((best_response, infer_responses))
     }
 }
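
The ranking criterion changes from the summed to the mean per-token log probability, so a candidate is not penalized merely for being longer, and the runner-up responses are now returned alongside the winner instead of being dropped. A minimal standalone sketch of the same selection rule (the `Seq` type and the sample values are invented for illustration; the real code ranks `InferResponse` values):

// Standalone sketch: pick the candidate with the highest mean token logprob.
struct Seq {
    text: &'static str,
    logprobs: Vec<f32>,
}

fn best_index(candidates: &[Seq]) -> usize {
    let mut max_index = 0;
    let mut max_mean = f32::MIN;
    for (i, candidate) in candidates.iter().enumerate() {
        // mean instead of sum: a 4-token and a 40-token candidate stay comparable
        let mean = candidate.logprobs.iter().sum::<f32>() / candidate.logprobs.len() as f32;
        if mean > max_mean {
            max_index = i;
            max_mean = mean;
        }
    }
    max_index
}

fn main() {
    let candidates = vec![
        Seq { text: "short answer", logprobs: vec![-0.20, -0.40] },
        Seq { text: "a longer, more confident answer", logprobs: vec![-0.10, -0.15, -0.20, -0.25] },
    ];
    // Prints the second candidate: its mean logprob (-0.175) beats the first (-0.30).
    println!("best: {}", candidates[best_index(&candidates)].text);
}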


@@ -168,6 +168,20 @@ pub(crate) enum FinishReason {
     StopSequence,
 }

+#[derive(Serialize, ToSchema)]
+pub(crate) struct BestOfSequence {
+    #[schema(example = "test")]
+    pub generated_text: String,
+    #[schema(example = "length")]
+    pub finish_reason: FinishReason,
+    #[schema(example = 1)]
+    pub generated_tokens: u32,
+    #[schema(example = 42)]
+    pub seed: Option<u64>,
+    pub prefill: Vec<PrefillToken>,
+    pub tokens: Vec<Token>,
+}
+
 #[derive(Serialize, ToSchema)]
 pub(crate) struct Details {
     #[schema(example = "length")]
@@ -178,6 +192,8 @@ pub(crate) struct Details {
     pub seed: Option<u64>,
     pub prefill: Vec<PrefillToken>,
     pub tokens: Vec<Token>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub best_of_sequences: Option<Vec<BestOfSequence>>,
 }

 #[derive(Serialize, ToSchema)]
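
Because `best_of_sequences` is optional and tagged with `skip_serializing_if`, responses for requests that do not use `best_of` keep exactly the same `details` shape as before. A small sketch of that serialization behavior with simplified stand-in types (assumes the `serde` and `serde_json` crates; these are not the real `Details`/`BestOfSequence` structs):

use serde::Serialize;

// Simplified stand-in for the router's Details struct.
#[derive(Serialize)]
struct Details {
    generated_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    best_of_sequences: Option<Vec<String>>,
}

fn main() {
    let without = Details { generated_tokens: 3, best_of_sequences: None };
    let with = Details {
        generated_tokens: 3,
        best_of_sequences: Some(vec!["runner-up text".to_string()]),
    };
    // {"generated_tokens":3}  <- field omitted entirely when best_of is not used
    println!("{}", serde_json::to_string(&without).unwrap());
    // {"generated_tokens":3,"best_of_sequences":["runner-up text"]}
    println!("{}", serde_json::to_string(&with).unwrap());
}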


@@ -1,10 +1,10 @@
 /// HTTP Server logic
-use crate::infer::{InferError, InferStreamResponse};
+use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
-    GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
-    Validation,
+    BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
+    GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
+    StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -87,21 +87,21 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
(whitespace-only change; the lines below are unchanged in content)
 /// Generate tokens
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
     path = "/generate",
     request_body = GenerateRequest,
     responses(
         (status = 200, description = "Generated Text", body = GenerateResponse),
         (status = 424, description = "Generation Error", body = ErrorResponse,
             example = json ! ({"error": "Request failed during generation"})),
         (status = 429, description = "Model is overloaded", body = ErrorResponse,
             example = json ! ({"error": "Model is overloaded"})),
         (status = 422, description = "Input validation error", body = ErrorResponse,
             example = json ! ({"error": "Input validation error"})),
         (status = 500, description = "Incomplete generation", body = ErrorResponse,
             example = json ! ({"error": "Incomplete generation"})),
     )
 )]
 #[instrument(
     skip(infer),
@@ -130,20 +130,51 @@ async fn generate(
     let details = req.0.parameters.details;

     // Inference
-    let response = match req.0.parameters.best_of {
-        Some(best_of) if best_of > 1 => infer.generate_best_of(req.0, best_of).await?,
-        _ => infer.generate(req.0).await?,
+    let (response, best_of_responses) = match req.0.parameters.best_of {
+        Some(best_of) if best_of > 1 => {
+            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+            (response, Some(best_of_responses))
+        }
+        _ => (infer.generate(req.0).await?, None),
     };

     // Token details
     let details = match details {
-        true => Some(Details {
-            finish_reason: FinishReason::from(response.generated_text.finish_reason),
-            generated_tokens: response.generated_text.generated_tokens,
-            prefill: response.prefill,
-            tokens: response.tokens,
-            seed: response.generated_text.seed,
-        }),
+        true => {
+            // convert best_of_responses
+            let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
+                responses
+                    .into_iter()
+                    .map(|response: InferResponse| {
+                        // Add prompt if return_full_text
+                        let mut output_text = response.generated_text.text;
+                        if let Some(prompt) = &add_prompt {
+                            output_text = prompt.clone() + &output_text;
+                        }
+
+                        BestOfSequence {
+                            generated_text: output_text,
+                            finish_reason: FinishReason::from(
+                                response.generated_text.finish_reason,
+                            ),
+                            generated_tokens: response.generated_text.generated_tokens,
+                            prefill: response.prefill,
+                            tokens: response.tokens,
+                            seed: response.generated_text.seed,
+                        }
+                    })
+                    .collect()
+            });
+
+            Some(Details {
+                finish_reason: FinishReason::from(response.generated_text.finish_reason),
+                generated_tokens: response.generated_text.generated_tokens,
+                prefill: response.prefill,
+                tokens: response.tokens,
+                seed: response.generated_text.seed,
+                best_of_sequences,
+            })
+        }
         false => None,
     };
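
With this handler change, a request that sets `best_of` together with `details: true` gets the winning generation at the top level and the remaining candidates under `details.best_of_sequences`. A hypothetical client-side sketch (the URL, port, and prompt are assumptions; uses `reqwest` with its `json` feature, `tokio`, and `serde_json`):

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Assumed local router address; adjust to wherever the server is running.
    let body = json!({
        "inputs": "What is Deep Learning?",
        "parameters": {
            "best_of": 2,
            "do_sample": true,
            "max_new_tokens": 20,
            "details": true
        }
    });

    let resp: serde_json::Value = reqwest::Client::new()
        .post("http://localhost:3000/generate")
        .json(&body)
        .send()
        .await?
        .json()
        .await?;

    // The winning sequence is at the top level; the runner-ups live in the details.
    println!("{}", resp["generated_text"]);
    println!("{}", resp["details"]["best_of_sequences"]);
    Ok(())
}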
@@ -222,26 +253,26 @@ async fn generate(
(whitespace-only change)
 /// Generate a stream of token using Server-Sent Events
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
     path = "/generate_stream",
     request_body = GenerateRequest,
     responses(
         (status = 200, description = "Generated Text", body = StreamResponse,
             content_type = "text/event-stream"),
         (status = 424, description = "Generation Error", body = ErrorResponse,
             example = json ! ({"error": "Request failed during generation"}),
             content_type = "text/event-stream"),
         (status = 429, description = "Model is overloaded", body = ErrorResponse,
             example = json ! ({"error": "Model is overloaded"}),
             content_type = "text/event-stream"),
         (status = 422, description = "Input validation error", body = ErrorResponse,
             example = json ! ({"error": "Input validation error"}),
             content_type = "text/event-stream"),
         (status = 500, description = "Incomplete generation", body = ErrorResponse,
             example = json ! ({"error": "Incomplete generation"}),
             content_type = "text/event-stream"),
     )
 )]
 #[instrument(
     skip(infer),
@@ -403,10 +434,10 @@ async fn generate_stream(
(whitespace-only change)
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
     get,
     tag = "Text Generation Inference",
     path = "/metrics",
     responses((status = 200, description = "Prometheus Metrics", body = String))
 )]
 async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
@ -432,35 +463,36 @@ pub async fn run(
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
#[openapi( #[openapi(
paths( paths(
generate, generate,
generate_stream, generate_stream,
metrics, metrics,
), ),
components( components(
schemas( schemas(
GenerateRequest, GenerateRequest,
GenerateParameters, GenerateParameters,
PrefillToken, PrefillToken,
Token, Token,
GenerateResponse, GenerateResponse,
Details, BestOfSequence,
FinishReason, Details,
StreamResponse, FinishReason,
StreamDetails, StreamResponse,
ErrorResponse, StreamDetails,
) ErrorResponse,
), )
tags( ),
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API") tags(
), (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
info( ),
title = "Text Generation Inference", info(
license( title = "Text Generation Inference",
name = "Apache 2.0", license(
url = "https://www.apache.org/licenses/LICENSE-2.0" name = "Apache 2.0",
) url = "https://www.apache.org/licenses/LICENSE-2.0"
) )
)
)] )]
struct ApiDoc; struct ApiDoc;
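
Adding `BestOfSequence` under `components(schemas(...))` is what exposes the new type in the generated OpenAPI document: utoipa includes the listed `ToSchema` types in the spec's components. A reduced sketch of that mechanism (assumes the utoipa 3.x API; the struct is trimmed to a single field):

use utoipa::{OpenApi, ToSchema};

#[derive(ToSchema)]
#[allow(dead_code)]
struct BestOfSequence {
    generated_text: String,
}

#[derive(OpenApi)]
#[openapi(components(schemas(BestOfSequence)))]
struct ApiDoc;

fn main() {
    // Prints the generated OpenAPI JSON, now containing the BestOfSequence schema.
    println!("{}", ApiDoc::openapi().to_pretty_json().unwrap());
}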