support seeding

OlivierDehaene 2023-01-30 16:16:58 +01:00
parent 5ef1336997
commit 6d024e5708
4 changed files with 39 additions and 27 deletions
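
For context on what "support seeding" buys: a fixed per-request seed makes sampling reproducible, and this commit threads that seed from the model server out to the API responses. Below is a minimal, self-contained sketch of seeded sampling, assuming the rand crate; it is not the repository's implementation (the actual seeding happens in the Python model server, as the last file in this diff suggests).

// Sketch only: a seeded RNG pins down the whole sampling trajectory,
// which is what the `seed` field surfaced by this commit lets clients replay.
use rand::distributions::WeightedIndex;
use rand::prelude::*;

fn sample_token(probs: &[f64], seed: u64) -> usize {
    // Same seed => same RNG stream => same sampled index.
    let mut rng = StdRng::seed_from_u64(seed);
    let dist = WeightedIndex::new(probs).expect("weights must be non-negative and non-empty");
    dist.sample(&mut rng)
}

fn main() {
    let probs = [0.1, 0.7, 0.2];
    // Two calls with the same seed and the same distribution agree.
    assert_eq!(sample_token(&probs, 42), sample_token(&probs, 42));
    println!("sampled token id: {}", sample_token(&probs, 42));
}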

View File

@@ -110,6 +110,7 @@ impl Infer {
         let mut stream = self.generate_stream(request).await?;

         // Return values
+        let mut result_prefill = Vec::new();
         let mut result_tokens = Vec::new();
         let mut result_generated_text = None;
         let mut result_start = None;
@@ -119,17 +120,16 @@ impl Infer {
         while let Some(response) = stream.next().await {
             match response? {
                 // Add prefill tokens
-                InferStreamResponse::Prefill(prefill_tokens) => {
+                InferStreamResponse::Prefill(tokens) => {
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
-                    let prefill_tokens = prefill_tokens
+                    result_prefill = tokens
                         .ids
                         .into_iter()
-                        .zip(prefill_tokens.logprobs.into_iter())
-                        .zip(prefill_tokens.texts.into_iter())
+                        .zip(tokens.logprobs.into_iter())
+                        .zip(tokens.texts.into_iter())
                         .map(|((id, logprob), text)| Token(id, text, logprob))
                         .collect();
-                    result_tokens = prefill_tokens;
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
@@ -154,6 +154,7 @@ impl Infer {
             (result_generated_text, result_queued, result_start)
         {
             Ok(InferResponse {
+                prefill: result_prefill,
                 tokens: result_tokens,
                 generated_text,
                 queued,
@@ -333,9 +334,9 @@ pub(crate) enum InferStreamResponse {

 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) prefill: Vec<Token>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
-    pub(crate) seed: Option<u64>,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
 }
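
One detail of the first hunk above worth unpacking: chaining .zip() twice yields nested tuples, which is why the closure destructures ((id, logprob), text). A standalone toy version, with a stand-in Token tuple struct rather than the router's actual type:

// Toy reconstruction of the zip-of-zips pattern; `Token` is a stand-in.
#[derive(Debug)]
struct Token(u32, String, f32);

fn main() {
    let ids = vec![1u32, 2, 3];
    let logprobs = vec![-0.1f32, -0.5, -0.9];
    let texts = vec!["a".to_string(), "b".to_string(), "c".to_string()];

    // Each `.zip` adds one tuple level, hence `((id, logprob), text)`.
    let tokens: Vec<Token> = ids
        .into_iter()
        .zip(logprobs.into_iter())
        .zip(texts.into_iter())
        .map(|((id, logprob), text)| Token(id, text, logprob))
        .collect();

    assert_eq!(tokens.len(), 3);
    println!("{:?}", tokens);
}

Doing this zip in Rust rather than in the Python server is the performance choice the in-code comment calls out.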

View File

@@ -77,22 +77,24 @@ pub(crate) struct Details {
     pub finish_reason: String,
     pub generated_tokens: u32,
     pub seed: Option<u64>,
-    pub tokens: Vec<Token>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prefill: Option<Vec<Token>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tokens: Option<Vec<Token>>,
 }

 #[derive(Serialize)]
-pub(crate) struct GeneratedText {
+pub(crate) struct GenerateResponse {
     pub generated_text: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub details: Option<Details>,
 }

 #[derive(Serialize)]
-pub(crate) struct StreamToken {
+pub(crate) struct StreamResponse {
     pub token: Token,
-    pub end: bool,
-    pub finish_reason: Option<String>,
     pub generated_text: Option<String>,
+    pub details: Option<Details>,
 }

 #[derive(Serialize)]
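
The new #[serde(skip_serializing_if = "Option::is_none")] attributes matter for the streaming path, where prefill and tokens stay None and should not bloat every event. A self-contained sketch of the effect, with the struct re-declared here for illustration and Vec<u32> standing in for Vec<Token>:

// Sketch: fields guarded by `skip_serializing_if` vanish from the JSON
// when they are `None`; unguarded Options serialize as `null` instead.
use serde::Serialize;

#[derive(Serialize)]
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prefill: Option<Vec<u32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tokens: Option<Vec<u32>>,
}

fn main() {
    // Streaming path: prefill/tokens are left out of the payload entirely.
    let d = Details {
        finish_reason: "length".into(),
        generated_tokens: 2,
        seed: Some(42),
        prefill: None,
        tokens: None,
    };
    // Prints: {"finish_reason":"length","generated_tokens":2,"seed":42}
    println!("{}", serde_json::to_string(&d).unwrap());
}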

View File

@@ -1,7 +1,7 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, StreamToken,
+    Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer, StreamResponse,
     Validation,
 };
 use axum::extract::Extension;
@@ -77,8 +77,9 @@ async fn generate(
         true => Some(Details {
             finish_reason: response.generated_text.finish_reason,
             generated_tokens: response.generated_text.generated_tokens,
-            tokens: response.tokens,
-            seed: response.seed,
+            prefill: Some(response.prefill),
+            tokens: Some(response.tokens),
+            seed: response.generated_text.seed,
         }),
         false => None,
     };
@@ -119,11 +120,11 @@ async fn generate(
     span.record("queue_time", format!("{:?}", queue_time));
     span.record("inference_time", format!("{:?}", inference_time));
     span.record("time_per_token", format!("{:?}", time_per_token));
-    span.record("seed", format!("{:?}", response.seed));
+    span.record("seed", format!("{:?}", response.generated_text.seed));
     tracing::info!("Output: {}", response.generated_text.text);

     // Send response
-    let response = vec![GeneratedText {
+    let response = vec![GenerateResponse {
         generated_text: response.generated_text.text,
         details,
     }];
@@ -152,6 +153,7 @@ async fn generate_stream(
     // Inference
     let mut end_reached = false;
     let mut error = false;
+    let details = req.0.parameters.details;

     match infer.generate_stream(req.0).await {
         Ok(mut response_stream) => {
@@ -164,12 +166,11 @@ async fn generate_stream(
                         InferStreamResponse::Prefill(_) => {}
                         // Yield event for every new token
                         InferStreamResponse::Token(token) => {
-                            // StreamToken
-                            let stream_token = StreamToken {
+                            // StreamResponse
+                            let stream_token = StreamResponse {
                                 token,
-                                end: end_reached,
-                                finish_reason: None,
                                 generated_text: None,
+                                details: None,
                             };

                             yield Ok(Event::default().json_data(stream_token).unwrap())
@@ -181,6 +182,18 @@ async fn generate_stream(
                             start,
                             queued,
                         } => {
+                            // Token details
+                            let details = match details {
+                                true => Some(Details {
+                                    finish_reason: generated_text.finish_reason,
+                                    generated_tokens: generated_text.generated_tokens,
+                                    prefill: None,
+                                    tokens: None,
+                                    seed: generated_text.seed,
+                                }),
+                                false => None,
+                            };
+
                             // Timings
                             let total_time = start_time.elapsed();
                             let validation_time = queued - start_time;
@@ -199,13 +212,12 @@ async fn generate_stream(
                                 .record("time_per_token", format!("{:?}", time_per_token));
                             tracing::info!(parent: &span, "Output: {}", generated_text.text);

-                            // StreamToken
+                            // StreamResponse
                             end_reached = true;
-                            let stream_token = StreamToken {
+                            let stream_token = StreamResponse {
                                 token,
-                                end: end_reached,
-                                finish_reason: Some(generated_text.finish_reason),
                                 generated_text: Some(generated_text.text),
+                                details
                             };

                             yield Ok(Event::default().json_data(stream_token).unwrap())
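
Net effect on the SSE wire format: the end and finish_reason fields are gone, replaced by an optional details object that only the final event populates. A rough sketch of the two event shapes, with the types loosely re-declared here and details reduced to a string stand-in for Option<Details>:

// Sketch of the two payload shapes the stream now emits.
use serde::Serialize;

#[derive(Serialize)]
struct Token(u32, String, f32); // stand-in: (id, text, logprob)

#[derive(Serialize)]
struct StreamResponse {
    token: Token,
    generated_text: Option<String>,
    details: Option<String>, // stand-in for Option<Details>
}

fn main() {
    // Intermediate event: just the token; the Options serialize as null.
    let mid = StreamResponse {
        token: Token(42, "hello".into(), -0.4),
        generated_text: None,
        details: None,
    };
    // Final event: full text, plus details when `details=true` was requested.
    let last = StreamResponse {
        token: Token(7, ".".into(), -0.1),
        generated_text: Some("hello world.".into()),
        details: Some("finish_reason=length, seed=42".into()),
    };
    // {"token":[42,"hello",-0.4],"generated_text":null,"details":null}
    println!("{}", serde_json::to_string(&mid).unwrap());
    // {"token":[7,".",-0.1],"generated_text":"hello world.","details":"finish_reason=length, seed=42"}
    println!("{}", serde_json::to_string(&last).unwrap());
}

Clients can thus detect the final event by generated_text being non-null instead of polling an explicit end flag.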

View File

@@ -361,9 +361,6 @@ class CausalLM(Model):
                     all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
                 output_text = request.inputs + generated_text
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason
-                )

                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):