mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
fix: prefer index on StreamResponse
This commit is contained in:
parent
f82ff3f64a
commit
446b3b6af7
@@ -469,6 +469,7 @@ pub(crate) struct StreamDetails {
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct StreamResponse {
|
||||
pub index: u32,
|
||||
pub token: Token,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub top_tokens: Vec<Token>,
|
||||
|
@@ -338,7 +338,7 @@ async fn generate_stream(
|
||||
HeaderMap,
|
||||
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||
) {
|
||||
let on_message_callback = |_: u32, stream_token: StreamResponse| {
|
||||
let on_message_callback = |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
event.json_data(stream_token).unwrap()
|
||||
};
|
||||
@@ -352,7 +352,7 @@ async fn generate_stream(
|
||||
async fn generate_stream_internal(
|
||||
infer: Infer,
|
||||
Json(req): Json<GenerateRequest>,
|
||||
on_message_callback: impl Fn(u32, StreamResponse) -> Event,
|
||||
on_message_callback: impl Fn(StreamResponse) -> Event,
|
||||
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
|
||||
let span = tracing::Span::current();
|
||||
let start_time = Instant::now();
|
||||
@@ -414,12 +414,13 @@ async fn generate_stream_internal(
|
||||
|
||||
// StreamResponse
|
||||
let stream_token = StreamResponse {
|
||||
index,
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: None,
|
||||
details: None,
|
||||
};
|
||||
let event = on_message_callback(index, stream_token);
|
||||
let event = on_message_callback(stream_token);
|
||||
yield Ok(event);
|
||||
}
|
||||
// Yield event for last token and compute timings
|
||||
@@ -476,6 +477,7 @@ async fn generate_stream_internal(
|
||||
tracing::info!(parent: &span, "Success");
|
||||
|
||||
let stream_token = StreamResponse {
|
||||
index,
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: Some(output_text),
|
||||
@@ -483,7 +485,7 @@ async fn generate_stream_internal(
|
||||
};
|
||||
|
||||
|
||||
let event = on_message_callback(index, stream_token);
|
||||
let event = on_message_callback(stream_token);
|
||||
yield Ok(event);
|
||||
break;
|
||||
}
|
||||
@@ -607,13 +609,10 @@ async fn chat_completions(
|
||||
// switch on stream
|
||||
if stream {
|
||||
let model_id = info.model_id.clone();
|
||||
let system_fingerprint = format!(
|
||||
"{}-{}",
|
||||
info.version,
|
||||
info.docker_label.unwrap_or("native")
|
||||
);
|
||||
let system_fingerprint =
|
||||
format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
|
||||
// pass this callback to the stream generation and build the required event structure
|
||||
let on_message_callback = move |index: u32, stream_token: StreamResponse| {
|
||||
let on_message_callback = move |stream_token: StreamResponse| {
|
||||
let event = Event::default();
|
||||
|
||||
let current_time = std::time::SystemTime::now()
|
||||
@@ -627,7 +626,7 @@ async fn chat_completions(
|
||||
system_fingerprint.clone(),
|
||||
stream_token.token.text,
|
||||
current_time,
|
||||
index,
|
||||
stream_token.index,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
|
Loading…
Reference in New Issue
Block a user