support seeding

Author: OlivierDehaene
Date: 2023-01-30 16:16:58 +01:00
parent 5ef1336997
commit 6d024e5708
4 changed files with 39 additions and 27 deletions

File 1 of 4

@@ -110,6 +110,7 @@ impl Infer {
         let mut stream = self.generate_stream(request).await?;

         // Return values
+        let mut result_prefill = Vec::new();
         let mut result_tokens = Vec::new();
         let mut result_generated_text = None;
         let mut result_start = None;
@@ -119,17 +120,16 @@ impl Infer {
         while let Some(response) = stream.next().await {
             match response? {
                 // Add prefill tokens
-                InferStreamResponse::Prefill(prefill_tokens) => {
+                InferStreamResponse::Prefill(tokens) => {
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
-                    let prefill_tokens = prefill_tokens
+                    result_prefill = tokens
                         .ids
                         .into_iter()
-                        .zip(prefill_tokens.logprobs.into_iter())
-                        .zip(prefill_tokens.texts.into_iter())
+                        .zip(tokens.logprobs.into_iter())
+                        .zip(tokens.texts.into_iter())
                         .map(|((id, logprob), text)| Token(id, text, logprob))
                         .collect();
-                    result_tokens = prefill_tokens;
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
@@ -154,6 +154,7 @@ impl Infer {
             (result_generated_text, result_queued, result_start)
         {
             Ok(InferResponse {
+                prefill: result_prefill,
                 tokens: result_tokens,
                 generated_text,
                 queued,
@@ -333,9 +334,9 @@ pub(crate) enum InferStreamResponse {
 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) prefill: Vec<Token>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
-    pub(crate) seed: Option<u64>,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
 }
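Aside: the zip/map chain in the hunk above is the pattern used to fold the three parallel arrays returned for the prefill (ids, logprobs, texts) into a single list of Token values. A minimal standalone sketch, where PrefillTokens and Token are hypothetical simplified stand-ins, not the crate's exact types:

// Stand-ins for the real crate types (simplified for illustration).
struct PrefillTokens {
    ids: Vec<u32>,
    logprobs: Vec<f32>,
    texts: Vec<String>,
}

struct Token(u32, String, f32);

fn to_tokens(tokens: PrefillTokens) -> Vec<Token> {
    // Zip the three parallel vectors so each id lines up with its logprob
    // and decoded text, then map each triple into a single Token.
    tokens
        .ids
        .into_iter()
        .zip(tokens.logprobs.into_iter())
        .zip(tokens.texts.into_iter())
        .map(|((id, logprob), text)| Token(id, text, logprob))
        .collect()
}

The zipped iterators are consumed in a single pass and collect performs the only allocation, which is what the "Rust for loops are faster" comment is getting at.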

File 2 of 4

@@ -77,22 +77,24 @@ pub(crate) struct Details {
     pub finish_reason: String,
     pub generated_tokens: u32,
     pub seed: Option<u64>,
-    pub tokens: Vec<Token>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prefill: Option<Vec<Token>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tokens: Option<Vec<Token>>,
 }

 #[derive(Serialize)]
-pub(crate) struct GeneratedText {
+pub(crate) struct GenerateResponse {
     pub generated_text: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub details: Option<Details>,
 }

 #[derive(Serialize)]
-pub(crate) struct StreamToken {
+pub(crate) struct StreamResponse {
     pub token: Token,
     pub end: bool,
-    pub finish_reason: Option<String>,
-    pub generated_text: Option<String>,
+    pub details: Option<Details>,
 }

 #[derive(Serialize)]
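Aside: the #[serde(skip_serializing_if = "Option::is_none")] attributes decide the wire format: when prefill or tokens is None, the key is omitted from the JSON entirely rather than serialized as null (seed carries no such attribute, so it still appears as null when absent). A compilable sketch, with the token vectors simplified to plain ids; requires the serde (derive) and serde_json crates:

use serde::Serialize;

#[derive(Serialize)]
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prefill: Option<Vec<u32>>, // simplified: ids instead of Token structs
    #[serde(skip_serializing_if = "Option::is_none")]
    tokens: Option<Vec<u32>>,
}

fn main() {
    let streaming_details = Details {
        finish_reason: "length".to_string(),
        generated_tokens: 2,
        seed: Some(42),
        prefill: None,
        tokens: None,
    };
    // "prefill" and "tokens" are skipped, not serialized as null:
    // {"finish_reason":"length","generated_tokens":2,"seed":42}
    println!("{}", serde_json::to_string(&streaming_details).unwrap());
}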

File 3 of 4

@@ -1,7 +1,7 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, StreamToken,
+    Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer, StreamResponse,
     Validation,
 };
 use axum::extract::Extension;
@@ -77,8 +77,9 @@ async fn generate(
         true => Some(Details {
             finish_reason: response.generated_text.finish_reason,
             generated_tokens: response.generated_text.generated_tokens,
-            tokens: response.tokens,
-            seed: response.seed,
+            prefill: Some(response.prefill),
+            tokens: Some(response.tokens),
+            seed: response.generated_text.seed,
         }),
         false => None,
     };
@@ -119,11 +120,11 @@ async fn generate(
     span.record("queue_time", format!("{:?}", queue_time));
     span.record("inference_time", format!("{:?}", inference_time));
     span.record("time_per_token", format!("{:?}", time_per_token));
-    span.record("seed", format!("{:?}", response.seed));
+    span.record("seed", format!("{:?}", response.generated_text.seed));
     tracing::info!("Output: {}", response.generated_text.text);

     // Send response
-    let response = vec![GeneratedText {
+    let response = vec![GenerateResponse {
         generated_text: response.generated_text.text,
         details,
     }];
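Aside: the handler wraps the single result in vec![...], so the non-streaming endpoint responds with a one-element JSON array rather than a bare object. A self-contained sketch of that shape, with Details cut down to a placeholder field; again assumes serde and serde_json:

use serde::Serialize;

#[derive(Serialize)]
struct Details {
    finish_reason: String, // placeholder for the full struct above
}

#[derive(Serialize)]
struct GenerateResponse {
    generated_text: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    details: Option<Details>,
}

fn main() {
    // A one-element Vec serializes as a JSON array with a single object.
    let response = vec![GenerateResponse {
        generated_text: "Hello world".to_string(),
        details: None,
    }];
    // Prints: [{"generated_text":"Hello world"}]
    println!("{}", serde_json::to_string(&response).unwrap());
}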
@@ -152,6 +153,7 @@ async fn generate_stream(
     // Inference
     let mut end_reached = false;
     let mut error = false;
+    let details = req.0.parameters.details;

     match infer.generate_stream(req.0).await {
         Ok(mut response_stream) => {
@@ -164,12 +166,11 @@ async fn generate_stream(
                         InferStreamResponse::Prefill(_) => {}
                         // Yield event for every new token
                         InferStreamResponse::Token(token) => {
-                            // StreamToken
-                            let stream_token = StreamToken {
+                            // StreamResponse
+                            let stream_token = StreamResponse {
                                 token,
                                 end: end_reached,
-                                finish_reason: None,
-                                generated_text: None,
+                                details: None,
                             };

                             yield Ok(Event::default().json_data(stream_token).unwrap())
@@ -181,6 +182,18 @@ async fn generate_stream(
                             start,
                             queued,
                         } => {
+                            // Token details
+                            let details = match details {
+                                true => Some(Details {
+                                    finish_reason: generated_text.finish_reason,
+                                    generated_tokens: generated_text.generated_tokens,
+                                    prefill: None,
+                                    tokens: None,
+                                    seed: generated_text.seed,
+                                }),
+                                false => None,
+                            };
+
                             // Timings
                             let total_time = start_time.elapsed();
                             let validation_time = queued - start_time;
@@ -199,13 +212,12 @@ async fn generate_stream(
                                 .record("time_per_token", format!("{:?}", time_per_token));
                             tracing::info!(parent: &span, "Output: {}", generated_text.text);

-                            // StreamToken
+                            // StreamResponse
                             end_reached = true;
-                            let stream_token = StreamToken {
+                            let stream_token = StreamResponse {
                                 token,
                                 end: end_reached,
-                                finish_reason: Some(generated_text.finish_reason),
-                                generated_text: Some(generated_text.text),
+                                details,
                             };

                             yield Ok(Event::default().json_data(stream_token).unwrap())
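Aside: after this change the stream emits two payload shapes: every intermediate token carries details: None (serialized as null, since the field has no skip attribute), and only the final event with end == true carries the populated Details, including the seed. A standalone sketch of both payloads, with Token simplified to its text:

use serde::Serialize;

#[derive(Serialize)]
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
}

#[derive(Serialize)]
struct StreamResponse {
    token: String, // stand-in for the real Token struct
    end: bool,
    details: Option<Details>,
}

fn main() {
    let intermediate = StreamResponse {
        token: "Hello".into(),
        end: false,
        details: None,
    };
    let last = StreamResponse {
        token: "!".into(),
        end: true,
        details: Some(Details {
            finish_reason: "length".into(),
            generated_tokens: 2,
            seed: Some(42),
        }),
    };
    // Each value becomes the JSON payload of one server-sent event:
    // {"token":"Hello","end":false,"details":null}
    // {"token":"!","end":true,"details":{"finish_reason":"length","generated_tokens":2,"seed":42}}
    println!("{}", serde_json::to_string(&intermediate).unwrap());
    println!("{}", serde_json::to_string(&last).unwrap());
}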

File 4 of 4

@@ -361,9 +361,6 @@ class CausalLM(Model):
                     all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
                 output_text = request.inputs + generated_text
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason
-                )

                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):