Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-26 20:42:06 +00:00)
Non breaking router.
This commit is contained in:
parent a478b276eb
commit b0cb4fa9d0
@@ -515,6 +515,7 @@ fn send_responses(
 
     let mut stopped = false;
 
+    tracing::info!("Generation: {:?}", generation);
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
         entry
@@ -559,6 +560,11 @@ fn send_responses(
             );
         top_tokens.push(local_top_tokens);
     }
+    // Force top_tokens to be the same size as tokens, both are going to be
+    // zipped later
+    if top_tokens.len() != tokens.len(){
+        top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
+    }
 
     if let Some(generated_text) = generation.generated_text {
         // Generation has ended
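The hunk above keeps top_tokens the same length as tokens so the two vectors can be zipped one-to-one further down the pipeline. Below is a minimal, self-contained sketch of that invariant; the Token struct is simplified and pad_top_tokens is an illustrative helper, not the router's actual API.

#[derive(Debug)]
struct Token {
    id: u32,
    text: String,
}

// Pad `top_tokens` with empty entries so it can be zipped 1:1 with `tokens`,
// mirroring the check added in the hunk above.
fn pad_top_tokens(tokens: &[Token], mut top_tokens: Vec<Vec<Token>>) -> Vec<Vec<Token>> {
    if top_tokens.len() != tokens.len() {
        top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
    }
    top_tokens
}

fn main() {
    let tokens = vec![
        Token { id: 1, text: "Hello".into() },
        Token { id: 2, text: " world".into() },
    ];
    // Backend returned no per-token alternatives at all.
    let top_tokens: Vec<Vec<Token>> = Vec::new();

    let padded = pad_top_tokens(&tokens, top_tokens);
    assert_eq!(padded.len(), tokens.len());

    // Both vectors can now be zipped safely, one entry per token.
    for (token, alternatives) in tokens.iter().zip(padded.iter()) {
        println!("{} (id {}): {} alternatives", token.text, token.id, alternatives.len());
    }
}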
@@ -279,10 +279,9 @@ pub(crate) struct StreamDetails
 
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
-    pub tokens: Vec<Token>,
+    pub token: Token,
     #[serde(skip_serializing_if = "Vec::is_empty")]
-    pub top_tokens: Vec<Vec<Token>>,
-    pub text: String,
+    pub top_tokens: Vec<Token>,
     #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
     #[schema(nullable = true, default = "null")]
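The schema change above reverts the breaking response shape: every streamed event again carries a single token object and a flat top_tokens list, and the separate text field is dropped. Below is a rough sketch of the resulting JSON shape, using hand-rolled formatting instead of the router's real serde/utoipa derives; the field subset and helper names here are illustrative only.

// Simplified stand-ins for the router's Token and StreamResponse types;
// only the fields touched by this commit are modeled.
struct Token {
    id: u32,
    text: String,
    logprob: f32,
}

struct StreamResponse {
    token: Token,           // was: tokens: Vec<Token>
    top_tokens: Vec<Token>, // was: top_tokens: Vec<Vec<Token>>
    generated_text: Option<String>,
}

// Hand-rolled JSON rendering; the real router derives Serialize instead.
fn token_json(t: &Token) -> String {
    format!(r#"{{"id":{},"text":"{}","logprob":{}}}"#, t.id, t.text, t.logprob)
}

fn to_json(resp: &StreamResponse) -> String {
    let top: Vec<String> = resp.top_tokens.iter().map(token_json).collect();
    let generated = match &resp.generated_text {
        Some(text) => format!(r#""{}""#, text),
        None => "null".to_string(),
    };
    format!(
        r#"{{"token":{},"top_tokens":[{}],"generated_text":{}}}"#,
        token_json(&resp.token),
        top.join(","),
        generated
    )
}

fn main() {
    let resp = StreamResponse {
        token: Token { id: 42, text: "world".into(), logprob: -0.31 },
        top_tokens: vec![Token { id: 7, text: "World".into(), logprob: -1.2 }],
        generated_text: None,
    };
    // One event per token; `token` is a single object again, not an array.
    println!("{}", to_json(&resp));
}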
@@ -391,38 +391,28 @@ async fn generate_stream(
             tokens,
             top_tokens,
         } => {
-            tracing::debug!(parent: &span, "Tokens: {:?}", tokens);
-
+            for (token, top_tokens) in tokens.into_iter().zip(top_tokens.into_iter()) {
             // StreamResponse
             let stream_token = StreamResponse {
-                tokens,
-                text,
+                token,
                 top_tokens,
                 generated_text: None,
                 details: None,
             };
 
-            yield Ok(Event::default().json_data(stream_token).unwrap())
+            yield Ok(Event::default().json_data(stream_token).unwrap());
+            }
         }
         // Yield event for last token and compute timings
         InferStreamResponse::End {
             tokens,
-            text,
             generated_text,
             start,
             queued,
             top_tokens,
         } => {
             // Token details
-            let details = match details {
-                true => Some(StreamDetails {
-                    finish_reason: FinishReason::from(generated_text.finish_reason),
-                    generated_tokens: generated_text.generated_tokens,
-                    seed: generated_text.seed,
-                }),
-                false => None,
-            };
 
             // Timings
             let total_time = start_time.elapsed();
             let validation_time = queued - start_time;
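In the Intermediate arm above, a batched tokens/top_tokens pair is fanned out into one StreamResponse per token, so clients still see one SSE event per generated token even though the backend now returns them in batches. Below is a small sketch of that fan-out with simplified types; fan_out is an illustrative name, and a plain Vec stands in for the yielded SSE stream.

#[derive(Debug)]
struct Token {
    id: u32,
    text: String,
}

#[derive(Debug)]
struct StreamResponse {
    token: Token,
    top_tokens: Vec<Token>,
    generated_text: Option<String>,
}

// Turn one batched intermediate message into one response per token,
// mirroring the zip loop in the hunk above.
fn fan_out(tokens: Vec<Token>, top_tokens: Vec<Vec<Token>>) -> Vec<StreamResponse> {
    tokens
        .into_iter()
        .zip(top_tokens.into_iter())
        .map(|(token, top_tokens)| StreamResponse {
            token,
            top_tokens,
            generated_text: None,
        })
        .collect()
}

fn main() {
    let tokens = vec![
        Token { id: 1, text: "Hel".into() },
        Token { id: 2, text: "lo".into() },
    ];
    // One (possibly empty) list of alternatives per generated token.
    let top_tokens = vec![Vec::new(), vec![Token { id: 9, text: "Lo".into() }]];

    // Each entry here would be serialized and yielded as its own SSE event.
    for event in fan_out(tokens, top_tokens) {
        println!("{:?}", event);
    }
}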
@@ -450,23 +440,45 @@ async fn generate_stream(
             // StreamResponse
             end_reached = true;
 
-            let mut output_text = generated_text.text;
-            if let Some(prompt) = add_prompt {
-                output_text = prompt + &output_text;
+            let n_tokens = tokens.len();
+            for (i, (token, top_tokens)) in tokens.into_iter().zip(top_tokens.into_iter()).enumerate() {
+                // StreamResponse
+                let stream_token = if i < n_tokens - 1 {
+                    StreamResponse {
+                        token,
+                        top_tokens,
+                        generated_text: None,
+                        details: None,
             }
-
+                }else{
+                    let details = match details {
+                        true => Some(StreamDetails {
+                            finish_reason: FinishReason::from(generated_text.finish_reason),
+                            generated_tokens: generated_text.generated_tokens,
+                            seed: generated_text.seed,
+                        }),
+                        false => None,
+                    };
+                    let output_text = if let Some(prompt) = &add_prompt {
+                        prompt.to_owned() + &generated_text.text
+                    }else{
+                        generated_text.text.to_owned()
+                    };
+
             tracing::debug!(parent: &span, "Output: {}", output_text);
             tracing::info!(parent: &span, "Success");
 
-            let stream_token = StreamResponse {
-                tokens,
+                    StreamResponse {
+                        token,
                 top_tokens,
-                text
                 generated_text: Some(output_text),
                 details
+                    }
             };
 
             yield Ok(Event::default().json_data(stream_token).unwrap());
+                }
+
             break;
         }
     }
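The End arm above does the same per-token fan-out, but only the final token carries the prompt-prefixed generated_text and the details block; earlier tokens are emitted like intermediate events. Below is a std-only sketch of that split under simplified types; end_responses and the finished flag are illustrative stand-ins for the real StreamDetails handling.

#[derive(Debug)]
struct Token {
    id: u32,
    text: String,
}

#[derive(Debug)]
struct StreamResponse {
    token: Token,
    generated_text: Option<String>,
    finished: bool, // stands in for the real `details: Option<StreamDetails>`
}

// Emit one response per token; only the last one carries the full output
// text (optionally prefixed with the prompt), as in the hunk above.
fn end_responses(
    tokens: Vec<Token>,
    generated_text: String,
    add_prompt: Option<String>,
) -> Vec<StreamResponse> {
    let n_tokens = tokens.len();
    let output_text = match add_prompt {
        Some(prompt) => prompt + &generated_text,
        None => generated_text,
    };

    tokens
        .into_iter()
        .enumerate()
        .map(|(i, token)| {
            if i < n_tokens - 1 {
                // Intermediate-style event: just the token.
                StreamResponse { token, generated_text: None, finished: false }
            } else {
                // Final event: attach the full text and mark completion.
                StreamResponse {
                    token,
                    generated_text: Some(output_text.clone()),
                    finished: true,
                }
            }
        })
        .collect()
}

fn main() {
    let tokens = vec![
        Token { id: 1, text: " wor".into() },
        Token { id: 2, text: "ld".into() },
    ];
    for event in end_responses(tokens, "Hello world".into(), None) {
        println!("{:?}", event);
    }
}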