diff --git a/router/src/infer.rs b/router/src/infer.rs
index 03b5efcf..4c4a7eb8 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -110,6 +110,7 @@ impl Infer {
         let mut stream = self.generate_stream(request).await?;
 
         // Return values
+        let mut result_prefill = Vec::new();
         let mut result_tokens = Vec::new();
         let mut result_generated_text = None;
         let mut result_start = None;
@@ -119,17 +120,16 @@ impl Infer {
         while let Some(response) = stream.next().await {
             match response? {
                 // Add prefill tokens
-                InferStreamResponse::Prefill(prefill_tokens) => {
+                InferStreamResponse::Prefill(tokens) => {
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
-                    let prefill_tokens = prefill_tokens
+                    result_prefill = tokens
                         .ids
                         .into_iter()
-                        .zip(prefill_tokens.logprobs.into_iter())
-                        .zip(prefill_tokens.texts.into_iter())
+                        .zip(tokens.logprobs.into_iter())
+                        .zip(tokens.texts.into_iter())
                         .map(|((id, logprob), text)| Token(id, text, logprob))
                         .collect();
-                    result_tokens = prefill_tokens;
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
@@ -154,6 +154,7 @@ impl Infer {
             (result_generated_text, result_queued, result_start)
         {
             Ok(InferResponse {
+                prefill: result_prefill,
                 tokens: result_tokens,
                 generated_text,
                 queued,
@@ -333,9 +334,9 @@ pub(crate) enum InferStreamResponse {
 
 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    pub(crate) prefill: Vec<Token>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
-    pub(crate) seed: Option<u64>,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 940f06d1..beab7138 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -77,22 +77,24 @@ pub(crate) struct Details {
     pub finish_reason: String,
     pub generated_tokens: u32,
     pub seed: Option<u64>,
-    pub tokens: Vec<Token>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prefill: Option<Vec<Token>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tokens: Option<Vec<Token>>,
 }
 
 #[derive(Serialize)]
-pub(crate) struct GeneratedText {
+pub(crate) struct GenerateResponse {
     pub generated_text: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub details: Option<Details>,
 }
 
 #[derive(Serialize)]
-pub(crate) struct StreamToken {
+pub(crate) struct StreamResponse {
     pub token: Token,
-    pub end: bool,
-    pub finish_reason: Option<String>,
     pub generated_text: Option<String>,
+    pub details: Option<Details>,
 }
 
 #[derive(Serialize)]
diff --git a/router/src/server.rs b/router/src/server.rs
index c9fb4f36..7da56a36 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,7 +1,7 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, StreamToken,
+    Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer, StreamResponse,
     Validation,
 };
 use axum::extract::Extension;
@@ -77,8 +77,9 @@ async fn generate(
         true => Some(Details {
             finish_reason: response.generated_text.finish_reason,
             generated_tokens: response.generated_text.generated_tokens,
-            tokens: response.tokens,
-            seed: response.seed,
+            prefill: Some(response.prefill),
+            tokens: Some(response.tokens),
+            seed: response.generated_text.seed,
         }),
         false => None,
     };
@@ -119,11 +120,11 @@ async fn generate(
     span.record("queue_time", format!("{:?}", queue_time));
     span.record("inference_time", format!("{:?}", inference_time));
     span.record("time_per_token", format!("{:?}", time_per_token));
-    span.record("seed", format!("{:?}", response.seed));
+    span.record("seed", format!("{:?}", response.generated_text.seed));
     tracing::info!("Output: {}", response.generated_text.text);
 
     // Send response
-    let response = vec![GeneratedText {
+    let response = vec![GenerateResponse {
         generated_text: response.generated_text.text,
         details,
     }];
@@ -152,6 +153,7 @@ async fn generate_stream(
         // Inference
         let mut end_reached = false;
         let mut error = false;
+        let details = req.0.parameters.details;
 
         match infer.generate_stream(req.0).await {
             Ok(mut response_stream) => {
@@ -164,12 +166,11 @@ async fn generate_stream(
                                 InferStreamResponse::Prefill(_) => {}
                                 // Yield event for every new token
                                 InferStreamResponse::Token(token) => {
-                                    // StreamToken
-                                    let stream_token = StreamToken {
+                                    // StreamResponse
+                                    let stream_token = StreamResponse {
                                         token,
-                                        end: end_reached,
-                                        finish_reason: None,
                                         generated_text: None,
+                                        details: None,
                                     };
 
                                     yield Ok(Event::default().json_data(stream_token).unwrap())
@@ -181,6 +182,18 @@ async fn generate_stream(
                                     start,
                                     queued,
                                 } => {
+                                    // Token details
+                                    let details = match details {
+                                        true => Some(Details {
+                                            finish_reason: generated_text.finish_reason,
+                                            generated_tokens: generated_text.generated_tokens,
+                                            prefill: None,
+                                            tokens: None,
+                                            seed: generated_text.seed,
+                                        }),
+                                        false => None,
+                                    };
+
                                     // Timings
                                     let total_time = start_time.elapsed();
                                     let validation_time = queued - start_time;
@@ -199,13 +212,12 @@ async fn generate_stream(
                                         .record("time_per_token", format!("{:?}", time_per_token));
                                     tracing::info!(parent: &span, "Output: {}", generated_text.text);
 
-                                    // StreamToken
+                                    // StreamResponse
                                     end_reached = true;
-                                    let stream_token = StreamToken {
+                                    let stream_token = StreamResponse {
                                         token,
-                                        end: end_reached,
-                                        finish_reason: Some(generated_text.finish_reason),
                                         generated_text: Some(generated_text.text),
+                                        details
                                     };
 
                                     yield Ok(Event::default().json_data(stream_token).unwrap())
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index e34e5a38..d2a965b5 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -361,9 +361,6 @@ class CausalLM(Model):
                         all_input_ids[-stopping_criteria.current_tokens :, 0]
                     )
                     output_text = request.inputs + generated_text
-                    generated_text = GeneratedText(
-                        output_text, stopping_criteria.current_tokens, reason
-                    )
 
                     # Get seed
                     if isinstance(next_token_chooser.choice, Sampling):