mirror of https://github.com/huggingface/text-generation-inference.git

support seeding

parent 5ef1336997
commit 6d024e5708
@@ -110,6 +110,7 @@ impl Infer {
         let mut stream = self.generate_stream(request).await?;

         // Return values
+        let mut result_prefill = Vec::new();
         let mut result_tokens = Vec::new();
         let mut result_generated_text = None;
         let mut result_start = None;
@@ -119,17 +120,16 @@ impl Infer {
         while let Some(response) = stream.next().await {
             match response? {
                 // Add prefill tokens
-                InferStreamResponse::Prefill(prefill_tokens) => {
+                InferStreamResponse::Prefill(tokens) => {
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
-                    let prefill_tokens = prefill_tokens
+                    result_prefill = tokens
                         .ids
                         .into_iter()
-                        .zip(prefill_tokens.logprobs.into_iter())
-                        .zip(prefill_tokens.texts.into_iter())
+                        .zip(tokens.logprobs.into_iter())
+                        .zip(tokens.texts.into_iter())
                         .map(|((id, logprob), text)| Token(id, text, logprob))
                         .collect();
-                    result_tokens = prefill_tokens;
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
@@ -154,6 +154,7 @@ impl Infer {
             (result_generated_text, result_queued, result_start)
         {
             Ok(InferResponse {
+                prefill: result_prefill,
                 tokens: result_tokens,
                 generated_text,
                 queued,
@@ -333,9 +334,9 @@ pub(crate) enum InferStreamResponse {

 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) prefill: Vec<Token>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
-    pub(crate) seed: Option<u64>,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
 }
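Note: the prefill handling above zips three parallel vectors (ids, logprobs, texts) into a single Vec<Token>. A minimal, self-contained sketch of that zip-and-map pattern, with Token as a simplified stand-in for the router's tuple struct (field types assumed):

// Sketch only: Token here is a stand-in, not the router's actual type.
#[derive(Debug)]
struct Token(u32, String, f32); // (id, text, logprob), types assumed

fn main() {
    // Parallel vectors, as they would arrive from the prefill step.
    let ids = vec![17_u32, 29, 3];
    let logprobs = vec![-0.11_f32, -0.52, -0.23];
    let texts = vec!["Hello".to_string(), " world".to_string(), "!".to_string()];

    // Same shape as the diff: zip ids with logprobs, then with texts,
    // and map each ((id, logprob), text) triple into a Token.
    let prefill: Vec<Token> = ids
        .into_iter()
        .zip(logprobs.into_iter())
        .zip(texts.into_iter())
        .map(|((id, logprob), text)| Token(id, text, logprob))
        .collect();

    assert_eq!(prefill.len(), 3);
    println!("{:?}", prefill);
}

Iterator::zip stops at the shortest input, so the three vectors are expected to have equal length.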
@@ -77,22 +77,24 @@ pub(crate) struct Details {
     pub finish_reason: String,
     pub generated_tokens: u32,
     pub seed: Option<u64>,
-    pub tokens: Vec<Token>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prefill: Option<Vec<Token>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tokens: Option<Vec<Token>>,
 }

 #[derive(Serialize)]
-pub(crate) struct GeneratedText {
+pub(crate) struct GenerateResponse {
     pub generated_text: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub details: Option<Details>,
 }

 #[derive(Serialize)]
-pub(crate) struct StreamToken {
+pub(crate) struct StreamResponse {
     pub token: Token,
-    pub end: bool,
-    pub finish_reason: Option<String>,
     pub generated_text: Option<String>,
+    pub details: Option<Details>,
 }

 #[derive(Serialize)]
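Note: the point of the #[serde(skip_serializing_if = "Option::is_none")] attributes is that the new optional prefill/tokens fields disappear from the JSON entirely when unset, instead of appearing as null. A rough sketch, assuming serde and serde_json are available and reducing the token vectors to bare ids:

use serde::Serialize;

// Sketch only: a trimmed stand-in for the router's Details struct.
#[derive(Serialize)]
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prefill: Option<Vec<u32>>, // reduced to bare ids for the sketch
    #[serde(skip_serializing_if = "Option::is_none")]
    tokens: Option<Vec<u32>>,
}

fn main() {
    // Streaming-style details: no per-token vectors attached.
    let details = Details {
        finish_reason: "length".to_string(),
        generated_tokens: 20,
        seed: Some(42),
        prefill: None,
        tokens: None,
    };
    // Prints {"finish_reason":"length","generated_tokens":20,"seed":42}
    // -- "prefill" and "tokens" are omitted entirely rather than serialized as null.
    println!("{}", serde_json::to_string(&details).unwrap());
}

As shown in the hunk, seed carries no skip attribute, so it would still serialize as null when absent.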
@@ -1,7 +1,7 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, StreamToken,
+    Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer, StreamResponse,
     Validation,
 };
 use axum::extract::Extension;
@@ -77,8 +77,9 @@ async fn generate(
         true => Some(Details {
             finish_reason: response.generated_text.finish_reason,
             generated_tokens: response.generated_text.generated_tokens,
-            tokens: response.tokens,
-            seed: response.seed,
+            prefill: Some(response.prefill),
+            tokens: Some(response.tokens),
+            seed: response.generated_text.seed,
         }),
         false => None,
     };
@@ -119,11 +120,11 @@ async fn generate(
     span.record("queue_time", format!("{:?}", queue_time));
     span.record("inference_time", format!("{:?}", inference_time));
     span.record("time_per_token", format!("{:?}", time_per_token));
-    span.record("seed", format!("{:?}", response.seed));
+    span.record("seed", format!("{:?}", response.generated_text.seed));
     tracing::info!("Output: {}", response.generated_text.text);

     // Send response
-    let response = vec![GeneratedText {
+    let response = vec![GenerateResponse {
         generated_text: response.generated_text.text,
         details,
     }];
@@ -152,6 +153,7 @@ async fn generate_stream(
     // Inference
     let mut end_reached = false;
     let mut error = false;
+    let details = req.0.parameters.details;

     match infer.generate_stream(req.0).await {
         Ok(mut response_stream) => {
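Note: the added `let details = req.0.parameters.details;` has to run before `req.0` is moved into `infer.generate_stream(...)`. A minimal sketch of that ownership constraint, with Request and generate_stream as simplified stand-ins rather than the router's actual types:

// Sketch only: simplified stand-ins for the request type and the consuming call.
struct Parameters {
    details: bool,
}

struct Request {
    parameters: Parameters,
}

// Takes ownership of the request, like infer.generate_stream(req.0).
fn generate_stream(_req: Request) {}

fn main() {
    let req = Request {
        parameters: Parameters { details: true },
    };
    // Copy the bool out first; after the move below, `req` can no longer be read.
    let details = req.parameters.details;
    generate_stream(req);
    println!("details requested: {}", details);
}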
@@ -164,12 +166,11 @@ async fn generate_stream(
                 InferStreamResponse::Prefill(_) => {}
                 // Yield event for every new token
                 InferStreamResponse::Token(token) => {
-                    // StreamToken
-                    let stream_token = StreamToken {
+                    // StreamResponse
+                    let stream_token = StreamResponse {
                         token,
-                        end: end_reached,
-                        finish_reason: None,
                         generated_text: None,
+                        details: None,
                     };

                     yield Ok(Event::default().json_data(stream_token).unwrap())
@@ -181,6 +182,18 @@ async fn generate_stream(
                     start,
                     queued,
                 } => {
+                    // Token details
+                    let details = match details {
+                        true => Some(Details {
+                            finish_reason: generated_text.finish_reason,
+                            generated_tokens: generated_text.generated_tokens,
+                            prefill: None,
+                            tokens: None,
+                            seed: generated_text.seed,
+                        }),
+                        false => None,
+                    };
+
                     // Timings
                     let total_time = start_time.elapsed();
                     let validation_time = queued - start_time;
@@ -199,13 +212,12 @@ async fn generate_stream(
                         .record("time_per_token", format!("{:?}", time_per_token));
                     tracing::info!(parent: &span, "Output: {}", generated_text.text);

-                    // StreamToken
+                    // StreamResponse
                     end_reached = true;
-                    let stream_token = StreamToken {
+                    let stream_token = StreamResponse {
                         token,
-                        end: end_reached,
-                        finish_reason: Some(generated_text.finish_reason),
                         generated_text: Some(generated_text.text),
+                        details
                     };

                     yield Ok(Event::default().json_data(stream_token).unwrap())
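Note: with these streaming hunks, an intermediate event carries only the token (generated_text and details left unset), while the final event adds the full generated_text and, when details were requested, a Details block whose prefill/tokens stay unset. A rough sketch of the two payload shapes, using simplified stand-ins for the router's Serialize structs (serde/serde_json assumed):

use serde::Serialize;

// Sketch only: trimmed stand-ins, not the router's actual types.
#[derive(Serialize)]
struct Token(u32, String, f32); // (id, text, logprob), types assumed

#[derive(Serialize)]
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
}

#[derive(Serialize)]
struct StreamResponse {
    token: Token,
    generated_text: Option<String>,
    details: Option<Details>,
}

fn main() {
    // Intermediate event: only the freshly generated token.
    let intermediate = StreamResponse {
        token: Token(42, " world".to_string(), -0.31),
        generated_text: None,
        details: None,
    };
    // Final event: last token plus the full generated text and details.
    let last = StreamResponse {
        token: Token(13, ".".to_string(), -0.05),
        generated_text: Some("Hello world.".to_string()),
        details: Some(Details {
            finish_reason: "length".to_string(),
            generated_tokens: 2,
            seed: Some(42),
        }),
    };
    // Without skip attributes on these stand-ins, absent values show up as null.
    println!("{}", serde_json::to_string(&intermediate).unwrap());
    println!("{}", serde_json::to_string(&last).unwrap());
}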
@@ -361,9 +361,6 @@ class CausalLM(Model):
                     all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
                 output_text = request.inputs + generated_text
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason
-                )

                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):