Non-breaking router.

Nicolas Patry 2023-11-30 08:27:25 +00:00
parent a478b276eb
commit b0cb4fa9d0
3 changed files with 55 additions and 38 deletions

View File

@@ -515,6 +515,7 @@ fn send_responses(
     let mut stopped = false;
+    tracing::info!("Generation: {:?}", generation);
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
         entry
@@ -559,6 +560,11 @@ fn send_responses(
         );
         top_tokens.push(local_top_tokens);
     }
+    // Force top_tokens to be the same size as tokens, both are going to be
+    // zipped later
+    if top_tokens.len() != tokens.len() {
+        top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
+    }
     if let Some(generated_text) = generation.generated_text {
         // Generation has ended
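
The padding just added is load-bearing: `Iterator::zip` stops at the end of the shorter of its two inputs, so if the model returned a `top_tokens` list of a different length than `tokens`, the later `zip` in the router would silently drop trailing tokens instead of erroring. A standalone sketch of that invariant (toy `&str` tokens standing in for the real `Token` type):

    fn main() {
        let tokens = vec!["Hello", ",", " world"];
        // Suppose the backend produced alternatives for only one position.
        let mut top_tokens: Vec<Vec<&str>> = vec![vec!["Hi"]];

        // zip pairs only the first token and drops the other two:
        assert_eq!(tokens.iter().zip(top_tokens.iter()).count(), 1);

        // The fix above: force the lengths to match by falling back to
        // empty alternatives for every position.
        if top_tokens.len() != tokens.len() {
            top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
        }
        assert_eq!(tokens.iter().zip(top_tokens.iter()).count(), 3);
    }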

View File

@@ -279,10 +279,9 @@ pub(crate) struct StreamDetails {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
-    pub tokens: Vec<Token>,
+    pub token: Token,
     #[serde(skip_serializing_if = "Vec::is_empty")]
-    pub top_tokens: Vec<Vec<Token>>,
-    pub text: String,
+    pub top_tokens: Vec<Token>,
     #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
     #[schema(nullable = true, default = "null")]
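
With this change each server-sent event carries a single `token` again rather than a `tokens` array, which is what keeps existing stream consumers working. A self-contained sketch of how the new shape serializes, using simplified local copies of the structs (the `id`/`text`/`logprob`/`special` fields on `Token` follow the rest of the schema; `details` is omitted here for brevity):

    use serde::Serialize;

    #[derive(Serialize)]
    struct Token {
        id: u32,
        text: String,
        logprob: f32,
        special: bool,
    }

    #[derive(Serialize)]
    struct StreamResponse {
        token: Token,
        #[serde(skip_serializing_if = "Vec::is_empty")]
        top_tokens: Vec<Token>,
        generated_text: Option<String>,
    }

    fn main() {
        let event = StreamResponse {
            token: Token { id: 42, text: " world".into(), logprob: -0.31, special: false },
            // Empty, so serde omits the field from the payload entirely.
            top_tokens: vec![],
            generated_text: None,
        };
        // {"token":{"id":42,"text":" world","logprob":-0.31,"special":false},"generated_text":null}
        println!("{}", serde_json::to_string(&event).unwrap());
    }

The `skip_serializing_if` attribute means requests that never asked for top tokens see no `top_tokens` key at all, so the per-token payload stays as small as it was before the breaking change.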

View File

@@ -391,38 +391,28 @@ async fn generate_stream(
         tokens,
         top_tokens,
     } => {
-        tracing::debug!(parent: &span, "Tokens: {:?}", tokens);
-        // StreamResponse
-        let stream_token = StreamResponse {
-            tokens,
-            text,
-            top_tokens,
-            generated_text: None,
-            details: None,
-        };
-        yield Ok(Event::default().json_data(stream_token).unwrap())
+        for (token, top_tokens) in tokens.into_iter().zip(top_tokens.into_iter()) {
+            // StreamResponse
+            let stream_token = StreamResponse {
+                token,
+                top_tokens,
+                generated_text: None,
+                details: None,
+            };
+            yield Ok(Event::default().json_data(stream_token).unwrap());
+        }
     }
     // Yield event for last token and compute timings
     InferStreamResponse::End {
         tokens,
-        text,
         generated_text,
         start,
         queued,
         top_tokens,
     } => {
-        // Token details
-        let details = match details {
-            true => Some(StreamDetails {
-                finish_reason: FinishReason::from(generated_text.finish_reason),
-                generated_tokens: generated_text.generated_tokens,
-                seed: generated_text.seed,
-            }),
-            false => None,
-        };
         // Timings
         let total_time = start_time.elapsed();
         let validation_time = queued - start_time;
@@ -450,23 +440,45 @@
         // StreamResponse
         end_reached = true;
-        let mut output_text = generated_text.text;
-        if let Some(prompt) = add_prompt {
-            output_text = prompt + &output_text;
+        let n_tokens = tokens.len();
+        for (i, (token, top_tokens)) in tokens.into_iter().zip(top_tokens.into_iter()).enumerate() {
+            // StreamResponse
+            let stream_token = if i < n_tokens - 1 {
+                StreamResponse {
+                    token,
+                    top_tokens,
+                    generated_text: None,
+                    details: None,
+                }
+            } else {
+                let details = match details {
+                    true => Some(StreamDetails {
+                        finish_reason: FinishReason::from(generated_text.finish_reason),
+                        generated_tokens: generated_text.generated_tokens,
+                        seed: generated_text.seed,
+                    }),
+                    false => None,
+                };
+                let output_text = if let Some(prompt) = &add_prompt {
+                    prompt.to_owned() + &generated_text.text
+                } else {
+                    generated_text.text.to_owned()
+                };
+                tracing::debug!(parent: &span, "Output: {}", output_text);
+                tracing::info!(parent: &span, "Success");
+                StreamResponse {
+                    token,
+                    top_tokens,
+                    generated_text: Some(output_text),
+                    details
+                }
+            };
+            yield Ok(Event::default().json_data(stream_token).unwrap());
+        }
-        tracing::debug!(parent: &span, "Output: {}", output_text);
-        tracing::info!(parent: &span, "Success");
-        let stream_token = StreamResponse {
-            tokens,
-            top_tokens,
-            text,
-            generated_text: Some(output_text),
-            details
-        };
-        yield Ok(Event::default().json_data(stream_token).unwrap());
         break;
     }
 }
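
The net effect of the rewrite: both the Intermediate and End arms now emit one event per generated token, and only the very last event (the `else` branch of `i < n_tokens - 1` above) carries `generated_text` and `details`, matching the single-token protocol clients already speak. A minimal sketch of the consumer side under that assumption (simplified structs, not the actual client library):

    use serde::Deserialize;

    #[derive(Deserialize)]
    struct Token {
        text: String,
    }

    #[derive(Deserialize)]
    struct StreamResponse {
        token: Token,
        generated_text: Option<String>,
    }

    fn main() -> Result<(), serde_json::Error> {
        // Two events as the router now emits them: one intermediate, one
        // final (unknown fields such as "details" are simply ignored).
        let events = [
            r#"{"token":{"text":"Hello"},"generated_text":null,"details":null}"#,
            r#"{"token":{"text":" world"},"generated_text":"Hello world","details":null}"#,
        ];
        for raw in events {
            let event: StreamResponse = serde_json::from_str(raw)?;
            print!("{}", event.token.text);
            if event.generated_text.is_some() {
                println!(); // a non-null generated_text marks the end of the request
            }
        }
        Ok(())
    }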