diff --git a/router/src/infer.rs b/router/src/infer.rs
index 050005f1..5972e67b 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -515,6 +515,7 @@ fn send_responses(
     let mut stopped = false;
 
+    tracing::info!("Generation: {:?}", generation);
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
         entry
@@ -559,6 +560,11 @@ fn send_responses(
         );
         top_tokens.push(local_top_tokens);
     }
+    // Force top_tokens to be the same size as tokens; both are going to be
+    // zipped together later
+    if top_tokens.len() != tokens.len() {
+        top_tokens = (0..tokens.len()).map(|_| Vec::new()).collect();
+    }
 
     if let Some(generated_text) = generation.generated_text {
         // Generation has ended
diff --git a/router/src/lib.rs b/router/src/lib.rs
index cbc0b478..b547dc15 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -279,10 +279,9 @@ pub(crate) struct StreamDetails {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
-    pub tokens: Vec<Token>,
+    pub token: Token,
     #[serde(skip_serializing_if = "Vec::is_empty")]
-    pub top_tokens: Vec<Vec<Token>>,
-    pub text: String,
+    pub top_tokens: Vec<Token>,
     #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
     #[schema(nullable = true, default = "null")]
diff --git a/router/src/server.rs b/router/src/server.rs
index a26dafd1..789b47e4 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -391,38 +391,28 @@ async fn generate_stream(
                             tokens,
                             top_tokens,
                         } => {
-                            tracing::debug!(parent: &span, "Tokens: {:?}", tokens);
-                            // StreamResponse
-                            let stream_token = StreamResponse {
-                                tokens,
-                                text,
-                                top_tokens,
-                                generated_text: None,
-                                details: None,
-                            };
+                            for (token, top_tokens) in tokens.into_iter().zip(top_tokens.into_iter()) {
+                                // StreamResponse
+                                let stream_token = StreamResponse {
+                                    token,
+                                    top_tokens,
+                                    generated_text: None,
+                                    details: None,
+                                };
 
-                            yield Ok(Event::default().json_data(stream_token).unwrap())
+                                yield Ok(Event::default().json_data(stream_token).unwrap());
+                            }
                         }
                         // Yield event for last token and compute timings
                         InferStreamResponse::End {
                             tokens,
-                            text,
                             generated_text,
                             start,
                             queued,
                             top_tokens,
                         } => {
                             // Token details
-                            let details = match details {
-                                true => Some(StreamDetails {
-                                    finish_reason: FinishReason::from(generated_text.finish_reason),
-                                    generated_tokens: generated_text.generated_tokens,
-                                    seed: generated_text.seed,
-                                }),
-                                false => None,
-                            };
-
                             // Timings
                             let total_time = start_time.elapsed();
                             let validation_time = queued - start_time;
@@ -450,23 +440,45 @@ async fn generate_stream(
                             // StreamResponse
                             end_reached = true;
 
-                            let mut output_text = generated_text.text;
-                            if let Some(prompt) = add_prompt {
-                                output_text = prompt + &output_text;
+                            let n_tokens = tokens.len();
+                            for (i, (token, top_tokens)) in tokens.into_iter().zip(top_tokens.into_iter()).enumerate() {
+                                // StreamResponse: only the last token of the batch
+                                // carries the generation details and output text
+                                let stream_token = if i < n_tokens - 1 {
+                                    StreamResponse {
+                                        token,
+                                        top_tokens,
+                                        generated_text: None,
+                                        details: None,
+                                    }
+                                } else {
+                                    // Token details
+                                    let details = match details {
+                                        true => Some(StreamDetails {
+                                            finish_reason: FinishReason::from(generated_text.finish_reason),
+                                            generated_tokens: generated_text.generated_tokens,
+                                            seed: generated_text.seed,
+                                        }),
+                                        false => None,
+                                    };
+                                    let output_text = if let Some(prompt) = &add_prompt {
+                                        prompt.to_owned() + &generated_text.text
+                                    } else {
+                                        generated_text.text.to_owned()
+                                    };
+
+                                    tracing::debug!(parent: &span, "Output: {}", output_text);
+                                    tracing::info!(parent: &span, "Success");
+
+                                    StreamResponse {
+                                        token,
+                                        top_tokens,
+                                        generated_text: Some(output_text),
+                                        details,
+                                    }
+                                };
+                                yield Ok(Event::default().json_data(stream_token).unwrap());
                             }
-                            tracing::debug!(parent: &span, "Output: {}", output_text);
-                            tracing::info!(parent: &span, "Success");
-
-                            let stream_token = StreamResponse {
-                                tokens,
-                                top_tokens,
-                                text
-                                generated_text: Some(output_text),
-                                details
-                            };
-
-                            yield Ok(Event::default().json_data(stream_token).unwrap());
                             break;
                         }
                     }
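
Illustration (not part of the patch): since StreamResponse now carries a single
`token` instead of a `tokens` array, a batch of N tokens produced in one step is
emitted as N separate SSE events, with `generated_text` and `details` attached
only to the final one. Assuming a two-token final batch, the tail of the stream
would look roughly like this (ids, logprobs, and text are made up for the
example; `top_tokens` is omitted because empty vectors are skipped during
serialization):

    data: {"token":{"id":318,"text":" is","logprob":-0.4,"special":false},"generated_text":null,"details":null}
    data: {"token":{"id":3608,"text":" fine","logprob":-0.7,"special":false},"generated_text":"test is fine","details":{"finish_reason":"eos_token","generated_tokens":12,"seed":null}}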