fix: prefer index on StreamResponse

This commit is contained in:
drbh 2024-01-09 11:59:11 -05:00
parent f82ff3f64a
commit 446b3b6af7
2 changed files with 11 additions and 11 deletions

View File

@@ -469,6 +469,7 @@ pub(crate) struct StreamDetails {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
+    pub index: u32,
     pub token: Token,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Token>,

View File

@@ -338,7 +338,7 @@ async fn generate_stream(
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
-    let on_message_callback = |_: u32, stream_token: StreamResponse| {
+    let on_message_callback = |stream_token: StreamResponse| {
         let event = Event::default();
         event.json_data(stream_token).unwrap()
     };
@@ -352,7 +352,7 @@ async fn generate_stream(
 async fn generate_stream_internal(
     infer: Infer,
     Json(req): Json<GenerateRequest>,
-    on_message_callback: impl Fn(u32, StreamResponse) -> Event,
+    on_message_callback: impl Fn(StreamResponse) -> Event,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
     let span = tracing::Span::current();
     let start_time = Instant::now();
@@ -414,12 +414,13 @@ async fn generate_stream_internal(
     // StreamResponse
     let stream_token = StreamResponse {
+        index,
         token,
         top_tokens,
         generated_text: None,
         details: None,
     };
-    let event = on_message_callback(index, stream_token);
+    let event = on_message_callback(stream_token);
     yield Ok(event);
 }
 // Yield event for last token and compute timings
@@ -476,6 +477,7 @@ async fn generate_stream_internal(
 tracing::info!(parent: &span, "Success");
 let stream_token = StreamResponse {
+    index,
     token,
     top_tokens,
     generated_text: Some(output_text),
@@ -483,7 +485,7 @@ async fn generate_stream_internal(
     };
-    let event = on_message_callback(index, stream_token);
+    let event = on_message_callback(stream_token);
     yield Ok(event);
     break;
 }
@@ -607,13 +609,10 @@ async fn chat_completions(
     // switch on stream
     if stream {
         let model_id = info.model_id.clone();
-        let system_fingerprint = format!(
-            "{}-{}",
-            info.version,
-            info.docker_label.unwrap_or("native")
-        );
+        let system_fingerprint =
+            format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
         // pass this callback to the stream generation and build the required event structure
-        let on_message_callback = move |index: u32, stream_token: StreamResponse| {
+        let on_message_callback = move |stream_token: StreamResponse| {
             let event = Event::default();
             let current_time = std::time::SystemTime::now()
@@ -627,7 +626,7 @@ async fn chat_completions(
     system_fingerprint.clone(),
     stream_token.token.text,
     current_time,
-    index,
+    stream_token.index,
     None,
     None,
 ))