Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)

add best of sequences to details

commit 9f4f2fc8e3 (parent 9624d4060f)
@@ -185,28 +185,35 @@ impl Infer {
         &self,
         request: GenerateRequest,
         best_of: usize,
-    ) -> Result<InferResponse, InferError> {
+    ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
         // validate best_of parameter separately
         let best_of = self.validation.validate_best_of(best_of)?;
 
         // create multiple generate requests
-        let infer_responses: Vec<InferResponse> =
+        let mut infer_responses: Vec<InferResponse> =
             try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
 
         // get the sequence with the highest log probability per token
+        let mut max_index = 0;
         let mut max_logprob: f32 = f32::MIN;
-        let mut best_response = None;
-        for response in infer_responses {
-            // sum logprobs of the generated tokens
-            let sequence_logprob = response.tokens.iter().map(|token| token.logprob).sum();
+
+        for (i, response) in infer_responses.iter().enumerate() {
+            // mean logprobs of the generated tokens
+            let sequence_logprob = response
+                .tokens
+                .iter()
+                .map(|token| token.logprob)
+                .sum::<f32>()
+                / response.tokens.len() as f32;
 
             // set best sequence
             if sequence_logprob > max_logprob {
+                max_index = i;
                 max_logprob = sequence_logprob;
-                best_response = Some(response);
             }
         }
-        Ok(best_response.expect("best_response is None. This is a bug."))
+        let best_response = infer_responses.remove(max_index);
+        Ok((best_response, infer_responses))
     }
 }
 
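Note: the hunk above replaces the old sum-of-logprobs criterion with a mean over the generated tokens (so longer sequences are not penalized simply for having more tokens) and keeps the runner-up sequences instead of discarding them. A minimal standalone sketch of that selection rule, using a hypothetical Candidate type in place of the crate's InferResponse:

struct Candidate {
    // per-token log-probabilities of the generated tokens
    logprobs: Vec<f32>,
}

// Pick the candidate with the highest mean log-probability and keep the rest.
fn pick_best(mut candidates: Vec<Candidate>) -> (Candidate, Vec<Candidate>) {
    let mut max_index = 0;
    let mut max_logprob = f32::MIN;
    for (i, candidate) in candidates.iter().enumerate() {
        let mean = candidate.logprobs.iter().sum::<f32>() / candidate.logprobs.len() as f32;
        if mean > max_logprob {
            max_index = i;
            max_logprob = mean;
        }
    }
    // removing the winner leaves the remaining sequences to be reported separately
    let best = candidates.remove(max_index);
    (best, candidates)
}
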
@@ -168,6 +168,20 @@ pub(crate) enum FinishReason {
     StopSequence,
 }
 
+#[derive(Serialize, ToSchema)]
+pub(crate) struct BestOfSequence {
+    #[schema(example = "test")]
+    pub generated_text: String,
+    #[schema(example = "length")]
+    pub finish_reason: FinishReason,
+    #[schema(example = 1)]
+    pub generated_tokens: u32,
+    #[schema(example = 42)]
+    pub seed: Option<u64>,
+    pub prefill: Vec<PrefillToken>,
+    pub tokens: Vec<Token>,
+}
+
 #[derive(Serialize, ToSchema)]
 pub(crate) struct Details {
     #[schema(example = "length")]

@@ -178,6 +192,8 @@ pub(crate) struct Details {
     pub seed: Option<u64>,
     pub prefill: Vec<PrefillToken>,
     pub tokens: Vec<Token>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub best_of_sequences: Option<Vec<BestOfSequence>>,
 }
 
 #[derive(Serialize, ToSchema)]
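Note: the skip_serializing_if = "Option::is_none" attribute means clients that never request best-of see no new field in the details object. A minimal sketch of that serde behaviour with a simplified struct (hypothetical field values, not the crate's real types):

use serde::Serialize;

#[derive(Serialize)]
struct DetailsSketch {
    generated_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    best_of_sequences: Option<Vec<String>>,
}

fn main() {
    let plain = DetailsSketch { generated_tokens: 3, best_of_sequences: None };
    let best_of = DetailsSketch {
        generated_tokens: 3,
        best_of_sequences: Some(vec!["a runner-up sequence".to_string()]),
    };
    // prints {"generated_tokens":3} -- the None field is omitted entirely
    println!("{}", serde_json::to_string(&plain).unwrap());
    // prints {"generated_tokens":3,"best_of_sequences":["a runner-up sequence"]}
    println!("{}", serde_json::to_string(&best_of).unwrap());
}
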
@@ -1,10 +1,10 @@
 /// HTTP Server logic
-use crate::infer::{InferError, InferStreamResponse};
+use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
-    GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
-    Validation,
+    BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
+    GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
+    StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};

@@ -87,21 +87,21 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
 
 /// Generate tokens
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
     path = "/generate",
     request_body = GenerateRequest,
     responses(
         (status = 200, description = "Generated Text", body = GenerateResponse),
         (status = 424, description = "Generation Error", body = ErrorResponse,
             example = json ! ({"error": "Request failed during generation"})),
         (status = 429, description = "Model is overloaded", body = ErrorResponse,
             example = json ! ({"error": "Model is overloaded"})),
         (status = 422, description = "Input validation error", body = ErrorResponse,
             example = json ! ({"error": "Input validation error"})),
         (status = 500, description = "Incomplete generation", body = ErrorResponse,
             example = json ! ({"error": "Incomplete generation"})),
     )
 )]
 #[instrument(
     skip(infer),

@@ -130,20 +130,51 @@ async fn generate(
     let details = req.0.parameters.details;
 
     // Inference
-    let response = match req.0.parameters.best_of {
-        Some(best_of) if best_of > 1 => infer.generate_best_of(req.0, best_of).await?,
-        _ => infer.generate(req.0).await?,
+    let (response, best_of_responses) = match req.0.parameters.best_of {
+        Some(best_of) if best_of > 1 => {
+            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+            (response, Some(best_of_responses))
+        }
+        _ => (infer.generate(req.0).await?, None),
     };
 
     // Token details
     let details = match details {
-        true => Some(Details {
-            finish_reason: FinishReason::from(response.generated_text.finish_reason),
-            generated_tokens: response.generated_text.generated_tokens,
-            prefill: response.prefill,
-            tokens: response.tokens,
-            seed: response.generated_text.seed,
-        }),
+        true => {
+            // convert best_of_responses
+            let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
+                responses
+                    .into_iter()
+                    .map(|response: InferResponse| {
+                        // Add prompt if return_full_text
+                        let mut output_text = response.generated_text.text;
+                        if let Some(prompt) = &add_prompt {
+                            output_text = prompt.clone() + &output_text;
+                        }
+
+                        BestOfSequence {
+                            generated_text: output_text,
+                            finish_reason: FinishReason::from(
+                                response.generated_text.finish_reason,
+                            ),
+                            generated_tokens: response.generated_text.generated_tokens,
+                            prefill: response.prefill,
+                            tokens: response.tokens,
+                            seed: response.generated_text.seed,
+                        }
+                    })
+                    .collect()
+            });
+
+            Some(Details {
+                finish_reason: FinishReason::from(response.generated_text.finish_reason),
+                generated_tokens: response.generated_text.generated_tokens,
+                prefill: response.prefill,
+                tokens: response.tokens,
+                seed: response.generated_text.seed,
+                best_of_sequences,
+            })
+        }
         false => None,
     };
 
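Note: the new handler path only triggers when a request carries best_of > 1 together with details: true. A hedged sketch of a request body that would exercise it, assuming the endpoint's usual inputs/parameters envelope (field values are illustrative only):

use serde_json::json;

fn main() {
    // sampling is switched on in this sketch, since multiple identical greedy
    // candidates would make best_of pointless
    let body = json!({
        "inputs": "What is Deep Learning?",
        "parameters": {
            "best_of": 2,
            "do_sample": true,
            "details": true,
            "max_new_tokens": 20
        }
    });
    // With details enabled and best_of > 1, the response's details object is
    // expected to also carry a best_of_sequences array (one entry per runner-up).
    println!("{body}");
}
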
@@ -222,26 +253,26 @@ async fn generate(
 
 /// Generate a stream of token using Server-Sent Events
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
     path = "/generate_stream",
     request_body = GenerateRequest,
     responses(
         (status = 200, description = "Generated Text", body = StreamResponse,
             content_type = "text/event-stream"),
         (status = 424, description = "Generation Error", body = ErrorResponse,
             example = json ! ({"error": "Request failed during generation"}),
             content_type = "text/event-stream"),
         (status = 429, description = "Model is overloaded", body = ErrorResponse,
             example = json ! ({"error": "Model is overloaded"}),
             content_type = "text/event-stream"),
         (status = 422, description = "Input validation error", body = ErrorResponse,
             example = json ! ({"error": "Input validation error"}),
             content_type = "text/event-stream"),
         (status = 500, description = "Incomplete generation", body = ErrorResponse,
             example = json ! ({"error": "Incomplete generation"}),
             content_type = "text/event-stream"),
     )
 )]
 #[instrument(
     skip(infer),

@@ -403,10 +434,10 @@ async fn generate_stream(
 
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
     get,
     tag = "Text Generation Inference",
     path = "/metrics",
     responses((status = 200, description = "Prometheus Metrics", body = String))
 )]
 async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()

@@ -432,35 +463,36 @@ pub async fn run(
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
         paths(
             generate,
             generate_stream,
             metrics,
         ),
         components(
             schemas(
                 GenerateRequest,
                 GenerateParameters,
                 PrefillToken,
                 Token,
                 GenerateResponse,
+                BestOfSequence,
                 Details,
                 FinishReason,
                 StreamResponse,
                 StreamDetails,
                 ErrorResponse,
             )
         ),
         tags(
             (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
         ),
         info(
             title = "Text Generation Inference",
             license(
                 name = "Apache 2.0",
                 url = "https://www.apache.org/licenses/LICENSE-2.0"
             )
         )
     )]
     struct ApiDoc;
 