Non-breaking router.

Nicolas Patry 2023-11-30 08:27:25 +00:00
parent a478b276eb
commit b0cb4fa9d0
3 changed files with 55 additions and 38 deletions


@@ -515,6 +515,7 @@ fn send_responses(
     let mut stopped = false;
+    tracing::info!("Generation: {:?}", generation);
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
         entry
@@ -559,6 +560,11 @@ fn send_responses(
         );
         top_tokens.push(local_top_tokens);
     }
+    // Force top_tokens to be the same size as tokens, both are going to be
+    // zipped later
+    if top_tokens.len() != tokens.len() {
+        top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
+    }
     if let Some(generated_text) = generation.generated_text {
         // Generation has ended

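Why the padding matters: `Iterator::zip` stops at the shorter of its two iterators, so a `top_tokens` shorter than `tokens` would silently drop trailing tokens when the router zips the two later. A minimal standalone sketch of the failure mode and the fix (illustrative values only, not the crate's real types):

    fn main() {
        let tokens = vec!["a", "b", "c"];
        let mut top_tokens: Vec<Vec<&str>> = vec![vec!["x"]]; // wrong length

        // Same fix as the hunk above: replace with empty per-token vecs so
        // nothing is lost when the two collections are zipped.
        if top_tokens.len() != tokens.len() {
            top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
        }

        for (token, top) in tokens.into_iter().zip(top_tokens.into_iter()) {
            println!("{token}: {top:?}"); // a: [], b: [], c: []
        }
    }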

@@ -279,10 +279,9 @@ pub(crate) struct StreamDetails {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
-    pub tokens: Vec<Token>,
+    pub token: Token,
     #[serde(skip_serializing_if = "Vec::is_empty")]
-    pub top_tokens: Vec<Vec<Token>>,
-    pub text: String,
+    pub top_tokens: Vec<Token>,
     #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
     #[schema(nullable = true, default = "null")]

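After this revert, each stream event carries a single `token` again, and the `skip_serializing_if` attribute keeps an empty `top_tokens` out of the JSON payload entirely. A minimal serde sketch with stand-in types (assumes `serde` with the `derive` feature and `serde_json` as dependencies; not the crate's real `Token`):

    use serde::Serialize;

    #[derive(Serialize)]
    struct Token {
        id: u32,
        text: String,
    }

    #[derive(Serialize)]
    struct StreamResponse {
        token: Token,
        // Empty vec => the field is omitted from the serialized JSON.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        top_tokens: Vec<Token>,
        generated_text: Option<String>,
    }

    fn main() {
        let event = StreamResponse {
            token: Token { id: 42, text: "hello".into() },
            top_tokens: Vec::new(),
            generated_text: None,
        };
        // Prints: {"token":{"id":42,"text":"hello"},"generated_text":null}
        println!("{}", serde_json::to_string(&event).unwrap());
    }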

@@ -391,38 +391,28 @@ async fn generate_stream(
                     tokens,
                     top_tokens,
                 } => {
+                    tracing::debug!(parent: &span, "Tokens: {:?}", tokens);
+                    for (token, top_tokens) in tokens.into_iter().zip(top_tokens.into_iter()) {
                         // StreamResponse
                         let stream_token = StreamResponse {
-                            tokens,
-                            text,
+                            token,
                             top_tokens,
                             generated_text: None,
                             details: None,
                         };
-                        yield Ok(Event::default().json_data(stream_token).unwrap())
+                        yield Ok(Event::default().json_data(stream_token).unwrap());
+                    }
                 }
                 // Yield event for last token and compute timings
                 InferStreamResponse::End {
                     tokens,
-                    text,
                     generated_text,
                     start,
                     queued,
                     top_tokens,
                 } => {
                     // Token details
-                    let details = match details {
-                        true => Some(StreamDetails {
-                            finish_reason: FinishReason::from(generated_text.finish_reason),
-                            generated_tokens: generated_text.generated_tokens,
-                            seed: generated_text.seed,
-                        }),
-                        false => None,
-                    };
                     // Timings
                     let total_time = start_time.elapsed();
                     let validation_time = queued - start_time;
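
The rewritten Intermediate arm fans a multi-token batch out into one stream event per token, which is what keeps the per-event wire format unchanged for existing clients. A stand-in sketch of the zip-and-emit pattern (illustrative data only):

    fn main() {
        let tokens = vec!["Hello", " world"];
        let top_tokens: Vec<Vec<&str>> = vec![vec!["Hi"], vec![" earth"]];
        // One stream event per token; no generated_text or details mid-stream.
        for (token, top) in tokens.into_iter().zip(top_tokens.into_iter()) {
            println!("event: token={token:?} top_tokens={top:?}");
        }
    }
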
@@ -450,23 +440,45 @@ async fn generate_stream(
                     // StreamResponse
                     end_reached = true;
-                    let mut output_text = generated_text.text;
-                    if let Some(prompt) = add_prompt {
-                        output_text = prompt + &output_text;
-                    }
+                    let n_tokens = tokens.len();
+                    for (i, (token, top_tokens)) in tokens.into_iter().zip(top_tokens.into_iter()).enumerate() {
+                        // StreamResponse
+                        let stream_token = if i < n_tokens - 1 {
+                            StreamResponse {
+                                token,
+                                top_tokens,
+                                generated_text: None,
+                                details: None,
+                            }
+                        } else {
+                            let details = match details {
+                                true => Some(StreamDetails {
+                                    finish_reason: FinishReason::from(generated_text.finish_reason),
+                                    generated_tokens: generated_text.generated_tokens,
+                                    seed: generated_text.seed,
+                                }),
+                                false => None,
+                            };
+                            let output_text = if let Some(prompt) = &add_prompt {
+                                prompt.to_owned() + &generated_text.text
+                            } else {
+                                generated_text.text.to_owned()
+                            };
                             tracing::debug!(parent: &span, "Output: {}", output_text);
                             tracing::info!(parent: &span, "Success");
-                    let stream_token = StreamResponse {
-                        tokens,
+                            StreamResponse {
+                                token,
                                 top_tokens,
-                        text,
                                 generated_text: Some(output_text),
                                 details
+                            }
                         };
                         yield Ok(Event::default().json_data(stream_token).unwrap());
+                    }
                     break;
                 }
             }
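
The End arm applies the same fan-out, except that only the last event of the batch carries `generated_text` and `details`; the `i < n_tokens - 1` guard keeps every earlier event lean. A minimal sketch of that guard with hypothetical data (stand-ins for the real response types):

    fn main() {
        let tokens = vec!["The", " answer", " is"];
        let n_tokens = tokens.len();
        for (i, token) in tokens.into_iter().enumerate() {
            // Only the final event gets the aggregate generated_text.
            let generated_text = if i < n_tokens - 1 {
                None
            } else {
                Some(String::from("The answer is"))
            };
            println!("event {i}: token={token:?} generated_text={generated_text:?}");
        }
    }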