fix: prefer index on StreamResponse

This commit is contained in:
drbh 2024-01-09 11:59:11 -05:00
parent f82ff3f64a
commit 446b3b6af7
2 changed files with 11 additions and 11 deletions

View File

@ -469,6 +469,7 @@ pub(crate) struct StreamDetails {
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
pub index: u32,
pub token: Token,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Token>,

View File

@ -338,7 +338,7 @@ async fn generate_stream(
HeaderMap,
Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
let on_message_callback = |_: u32, stream_token: StreamResponse| {
let on_message_callback = |stream_token: StreamResponse| {
let event = Event::default();
event.json_data(stream_token).unwrap()
};
@ -352,7 +352,7 @@ async fn generate_stream(
async fn generate_stream_internal(
infer: Infer,
Json(req): Json<GenerateRequest>,
on_message_callback: impl Fn(u32, StreamResponse) -> Event,
on_message_callback: impl Fn(StreamResponse) -> Event,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
let span = tracing::Span::current();
let start_time = Instant::now();
@ -414,12 +414,13 @@ async fn generate_stream_internal(
// StreamResponse
let stream_token = StreamResponse {
index,
token,
top_tokens,
generated_text: None,
details: None,
};
let event = on_message_callback(index, stream_token);
let event = on_message_callback(stream_token);
yield Ok(event);
}
// Yield event for last token and compute timings
@ -476,6 +477,7 @@ async fn generate_stream_internal(
tracing::info!(parent: &span, "Success");
let stream_token = StreamResponse {
index,
token,
top_tokens,
generated_text: Some(output_text),
@ -483,7 +485,7 @@ async fn generate_stream_internal(
};
let event = on_message_callback(index, stream_token);
let event = on_message_callback(stream_token);
yield Ok(event);
break;
}
@ -607,13 +609,10 @@ async fn chat_completions(
// switch on stream
if stream {
let model_id = info.model_id.clone();
let system_fingerprint = format!(
"{}-{}",
info.version,
info.docker_label.unwrap_or("native")
);
let system_fingerprint =
format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
// pass this callback to the stream generation and build the required event structure
let on_message_callback = move |index: u32, stream_token: StreamResponse| {
let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default();
let current_time = std::time::SystemTime::now()
@ -627,7 +626,7 @@ async fn chat_completions(
system_fingerprint.clone(),
stream_token.token.text,
current_time,
index,
stream_token.index,
None,
None,
))