From 046801278ea18aee0dd73c50a35682931a421a92 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 30 Jan 2023 10:55:54 +0100
Subject: [PATCH] Improved version

---
 launcher/Cargo.toml      |   2 +-
 router/client/src/lib.rs |   2 +-
 router/src/db.rs         |   3 +
 router/src/infer.rs      |  82 ++++++++++++---------
 router/src/lib.rs        |   8 +++
 router/src/server.rs     | 152 ++++++++++++++++++++++++++-------------
 6 files changed, 166 insertions(+), 83 deletions(-)

diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml
index 21d5d3ee..58df28d9 100644
--- a/launcher/Cargo.toml
+++ b/launcher/Cargo.toml
@@ -16,4 +16,4 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
 [dev-dependencies]
 float_eq = "1.0.1"
 reqwest = { version = "0.11.13", features = ["blocking", "json"] }
-serde = "1.0.150"
+serde = { version = "1.0.150", features = ["derive"] }
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index ec90103b..d0724625 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -8,7 +8,7 @@ mod sharded_client;
 pub use client::Client;
 pub use pb::generate::v1::{
     Batch, GeneratedText, Generation, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters,
+    StoppingCriteriaParameters, PrefillTokens,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
diff --git a/router/src/db.rs b/router/src/db.rs
index 9d1fd6f2..f0a62d65 100644
--- a/router/src/db.rs
+++ b/router/src/db.rs
@@ -10,6 +10,7 @@ use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
 use tokio::sync::mpsc::UnboundedSender;
+use tokio::sync::OwnedSemaphorePermit;
 use tokio::time::Instant;
 
 /// Database entry
@@ -25,6 +26,8 @@ pub(crate) struct Entry {
     pub time: Instant,
    /// Instant when this entry was added to a batch
     pub batch_time: Option<Instant>,
+    /// Permit
+    pub _permit: OwnedSemaphorePermit,
 }
 
 /// Request Database
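The `_permit: OwnedSemaphorePermit` field added above ties a concurrency slot to the lifetime of each queue entry: the slot is released when the `Entry` is dropped, not when the HTTP handler returns. A minimal, self-contained sketch of that pattern follows (illustration only, not part of the patch; `Entry` and `admit` are hypothetical names, only `tokio` is assumed):

```rust
use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};

/// Holds a concurrency slot for as long as it lives
struct Entry {
    _permit: OwnedSemaphorePermit,
}

fn admit(limit: &Arc<Semaphore>) -> Result<Entry, TryAcquireError> {
    // `try_acquire_owned` consumes an `Arc` clone, so the returned permit
    // borrows nothing and can travel with the `Entry` across tasks
    let permit = limit.clone().try_acquire_owned()?;
    Ok(Entry { _permit: permit })
}

#[tokio::main]
async fn main() {
    let limit = Arc::new(Semaphore::new(1));
    let first = admit(&limit).expect("one slot is free");
    // Second admission fails fast while the first entry is alive
    assert!(admit(&limit).is_err());
    drop(first);
    // Dropping the entry released the slot
    assert!(admit(&limit).is_ok());
}
```

Failing fast instead of queueing is what lets the router surface overload immediately: the patch maps `TryAcquireError` to `InferError::Overloaded`, which becomes `429 Too Many Requests`.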
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 62cd0248..065313d6 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -5,7 +5,7 @@ use crate::{Db, Entry, Token};
 use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, GeneratedText, Generation, ShardedClient};
+use text_generation_client::{Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
@@ -22,12 +22,12 @@ pub struct Infer {
     db: Db,
     /// Shared state
     shared: Arc<Shared>,
+    /// Inference limit
+    limit_concurrent_requests: Arc<Semaphore>,
 }
 
 /// Infer shared state
 struct Shared {
-    /// Inference limit
-    limit_concurrent_requests: Semaphore,
     /// Batching background Tokio task notifier
     batching_task: Notify,
 }
@@ -43,7 +43,6 @@ impl Infer {
         // Infer shared state
         let db = Db::new();
         let shared = Arc::new(Shared {
-            limit_concurrent_requests: Semaphore::new(max_concurrent_requests),
             batching_task: Notify::new(),
         });
 
@@ -56,10 +55,14 @@ impl Infer {
             shared.clone(),
         ));
 
+        // Inference limit with a semaphore
+        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
+
         Self {
             validation,
             db,
             shared,
+            limit_concurrent_requests: semaphore,
         }
     }
 
@@ -69,7 +72,8 @@ impl Infer {
         request: GenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
-        let _permit = self.shared.limit_concurrent_requests.try_acquire()?;
+        // This permit will live as long as Entry
+        let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
 
         // Validate request
         let (input_length, validated_request) = self.validation.validate(request).await?;
@@ -84,6 +88,7 @@ impl Infer {
             input_length,
             time: Instant::now(),
             batch_time: None,
+            _permit: permit
         });
 
         // Notify the background task that we have a new entry in the database that needs
@@ -113,30 +118,48 @@ impl Infer {
             match response? {
                 // Add prefill tokens
                 InferStreamResponse::Prefill(prefill_tokens) => {
-                    result_tokens.extend(prefill_tokens)
+                    // Create Token objects
+                    // We do that here instead of in the Python code as Rust for loops are faster
+                    let prefill_tokens = prefill_tokens
+                        .ids
+                        .into_iter()
+                        .zip(prefill_tokens.logprobs.into_iter())
+                        .zip(prefill_tokens.texts.into_iter())
+                        .map(|((id, logprob), text)| Token(id, text, logprob))
+                        .collect();
+                    result_tokens = prefill_tokens;
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
                 // Final message
                 // Set return values
                 InferStreamResponse::End {
+                    token,
                     generated_text,
                     start,
                     queued,
                 } => {
+                    result_tokens.push(token);
                     result_generated_text = Some(generated_text);
                     result_start = Some(start);
                     result_queued = Some(queued)
                 }
             }
         }
-        // Unwrap is safe here
-        Ok(InferResponse {
-            tokens: result_tokens,
-            generated_text: result_generated_text.unwrap(),
-            queued: result_queued.unwrap(),
-            start: result_start.unwrap(),
-        })
+
+        // Check that we received an `InferStreamResponse::End` message
+        if let (Some(generated_text), Some(queued), Some(start)) =
+            (result_generated_text, result_queued, result_start)
+        {
+            Ok(InferResponse {
+                tokens: result_tokens,
+                generated_text,
+                queued,
+                start,
+            })
+        } else {
+            Err(InferError::IncompleteGeneration)
+        }
     }
 }
 
@@ -210,7 +233,7 @@ async fn batching_task(
 
 /// Wrap a future inside a match statement to handle errors and send the responses to Infer
 async fn wrap_future(
-    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
+    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
@@ -247,20 +270,11 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
             .expect("ID not found in entries. This is a bug.");
 
         if let Some(prefill_tokens) = generation.prefill_tokens {
-            // Create Token objects
-            // We do that here instead of in the Python code as Rust for loops are faster
-            let prefill_tokens = prefill_tokens
-                .ids
-                .into_iter()
-                .zip(prefill_tokens.logprobs.into_iter())
-                .zip(prefill_tokens.texts.into_iter())
-                .map(|((id, logprob), text)| Token(id, text, logprob))
-                .collect();
             // Send message
             // unwrap_or is valid here as we don't care if the receiver is gone.
             entry
                 .response_tx
                 .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
                 .unwrap_or(());
         }
 
@@ -311,11 +325,12 @@ fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
-    Prefill(Vec<Token>),
+    Prefill(PrefillTokens),
     // Intermediate messages
     Token(Token),
     // Last message
     End {
+        token: Token,
         generated_text: GeneratedText,
         start: Instant,
         queued: Instant,
@@ -330,4 +346,6 @@ pub enum InferError {
     Overloaded(#[from] TryAcquireError),
     #[error("Input validation error: {0}")]
     ValidationError(#[from] ValidationError),
+    #[error("Incomplete generation")]
+    IncompleteGeneration,
 }
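The hunks above move prefill `Token` construction out of Python and into the router, using a double `zip` over the parallel `ids`, `logprobs` and `texts` vectors of the `PrefillTokens` message. A standalone sketch of the idiom (hypothetical values; `Token` here mirrors the router's tuple struct of id, text and logprob):

```rust
// Token mirrors the router's tuple struct: (id, text, logprob)
#[derive(Debug, PartialEq)]
struct Token(u32, String, f32);

fn main() {
    // Parallel vectors, as delivered in one PrefillTokens-style message
    let ids = vec![5u32, 9, 11];
    let logprobs = vec![-0.12f32, -1.3, -0.7];
    let texts = vec!["The".to_string(), " sky".to_string(), " is".to_string()];

    // Two zips produce nested pairs ((id, logprob), text); the closure
    // destructures them into one Token per position
    let tokens: Vec<Token> = ids
        .into_iter()
        .zip(logprobs.into_iter())
        .zip(texts.into_iter())
        .map(|((id, logprob), text)| Token(id, text, logprob))
        .collect();

    assert_eq!(tokens[0], Token(5, "The".to_string(), -0.12));
    println!("{tokens:?}");
}
```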
diff --git a/router/src/lib.rs b/router/src/lib.rs
index de3b7d78..940f06d1 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -87,6 +87,14 @@ pub(crate) struct GeneratedText {
     pub details: Option<Details>,
 }
 
+#[derive(Serialize)]
+pub(crate) struct StreamToken {
+    pub token: Token,
+    pub end: bool,
+    pub finish_reason: Option<String>,
+    pub generated_text: Option<String>,
+}
+
 #[derive(Serialize)]
 pub(crate) struct ErrorResponse {
     pub error: String,
 }
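For context on the wire format: a `StreamToken` with `end: false` leaves the two trailing fields as JSON `null`, and only the final message carries `finish_reason` and the full `generated_text`. A sketch of how such a struct serializes, assuming `serde` and `serde_json` as dependencies (the `Token` layout is an assumption mirroring `router/src/lib.rs`):

```rust
use serde::Serialize;

// Assumed to mirror the router's tuple struct; serializes as a JSON array
#[derive(Serialize)]
struct Token(u32, String, f32);

#[derive(Serialize)]
struct StreamToken {
    token: Token,
    end: bool,
    finish_reason: Option<String>,
    generated_text: Option<String>,
}

fn main() {
    // Intermediate event: `end` is false, trailing fields serialize as null
    let event = StreamToken {
        token: Token(42, "Hello".to_string(), -0.25),
        end: false,
        finish_reason: None,
        generated_text: None,
    };
    // {"token":[42,"Hello",-0.25],"end":false,"finish_reason":null,"generated_text":null}
    println!("{}", serde_json::to_string(&event).unwrap());
}
```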
diff --git a/router/src/server.rs b/router/src/server.rs
index 56161597..18d9136d 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,7 +1,8 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, Validation,
+    Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Infer, StreamToken,
+    Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
@@ -10,6 +11,7 @@ use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use futures::Stream;
+use std::convert::Infallible;
 use std::net::SocketAddr;
 use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
@@ -60,6 +62,7 @@ async fn generate(
     infer: Extension<Infer>,
     req: Json<GenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
     let start_time = Instant::now();
 
     // Inference
@@ -111,12 +114,12 @@ async fn generate(
     );
 
     // Tracing metadata
-    tracing::Span::current().record("total_time", format!("{:?}", total_time));
-    tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
-    tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
-    tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
-    tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
-    tracing::Span::current().record("seed", format!("{:?}", response.seed));
+    span.record("total_time", format!("{:?}", total_time));
+    span.record("validation_time", format!("{:?}", validation_time));
+    span.record("queue_time", format!("{:?}", queue_time));
+    span.record("inference_time", format!("{:?}", inference_time));
+    span.record("time_per_token", format!("{:?}", time_per_token));
+    span.record("seed", format!("{:?}", response.seed));
     tracing::info!("Output: {}", response.generated_text.text);
 
     // Send response
@@ -141,57 +144,97 @@ async fn generate(
 async fn generate_stream(
     infer: Extension<Infer>,
     req: Json<GenerateRequest>,
-) -> Sse<impl Stream<Item = Result<Event, InferError>>> {
+) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
+    let span = tracing::Span::current();
+    let start_time = Instant::now();
+
     let stream = async_stream::stream! {
-        let start_time = Instant::now();
-
-        // Inference
-        let mut response_stream = infer.generate_stream(req.0).await?;
+        let mut end_reached = false;
+        let mut error = false;
 
-        // Server Side Event stream
-        while let Some(response) = response_stream.next().await {
-            match response {
-                Ok(response) => {
-                    match response {
-                        // Prefill is ignored
-                        InferStreamResponse::Prefill(_) => {}
-                        // Yield event for every new token
-                        InferStreamResponse::Token(token) => {
-                            yield Ok(Event::default().json_data(token).unwrap())
-                        }
-                        // End is used for timings metadata and logging
-                        InferStreamResponse::End {
-                            generated_text,
-                            start,
-                            queued,
-                        } => {
-                            // Timings
-                            let total_time = start_time.elapsed();
-                            let validation_time = queued - start_time;
-                            let queue_time = start - queued;
-                            let inference_time = Instant::now() - start;
-                            let time_per_token = inference_time / generated_text.generated_tokens;
-
-                            // Tracing metadata
-                            tracing::Span::current().record("total_time", format!("{:?}", total_time));
-                            tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
-                            tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
-                            tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
-                            tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
-                            tracing::info!("Output: {}", generated_text.text);
-                        }
-                    }
-                }
-                // Trace and yield error
-                Err(err) => {
-                    tracing::error!("{}", err.to_string());
-                    yield Err(err);
-                }
-            }
-        }
+        match infer.generate_stream(req.0).await {
+            Ok(mut response_stream) => {
+                // Server Side Event stream
+                while let Some(response) = response_stream.next().await {
+                    match response {
+                        Ok(response) => {
+                            match response {
+                                // Prefill is ignored
+                                InferStreamResponse::Prefill(_) => {}
+                                // Yield event for every new token
+                                InferStreamResponse::Token(token) => {
+                                    // StreamToken
+                                    let stream_token = StreamToken {
+                                        token,
+                                        end: end_reached,
+                                        finish_reason: None,
+                                        generated_text: None,
+                                    };
+
+                                    yield Ok(Event::default().json_data(stream_token).unwrap())
+                                }
+                                // End is used for timings metadata and logging
+                                InferStreamResponse::End {
+                                    token,
+                                    generated_text,
+                                    start,
+                                    queued,
+                                } => {
+                                    // Timings
+                                    let total_time = start_time.elapsed();
+                                    let validation_time = queued - start_time;
+                                    let queue_time = start - queued;
+                                    let inference_time = Instant::now() - start;
+                                    let time_per_token = inference_time / generated_text.generated_tokens;
+
+                                    // Tracing metadata
+                                    span.record("total_time", format!("{:?}", total_time));
+                                    span.record("validation_time", format!("{:?}", validation_time));
+                                    span.record("queue_time", format!("{:?}", queue_time));
+                                    span.record("inference_time", format!("{:?}", inference_time));
+                                    span.record("time_per_token", format!("{:?}", time_per_token));
+                                    tracing::info!(parent: &span, "Output: {}", generated_text.text);
+
+                                    // StreamToken
+                                    end_reached = true;
+                                    let stream_token = StreamToken {
+                                        token,
+                                        end: end_reached,
+                                        finish_reason: Some(generated_text.finish_reason),
+                                        generated_text: Some(generated_text.text),
+                                    };
+
+                                    yield Ok(Event::default().json_data(stream_token).unwrap())
+                                }
+                            }
+                        }
+                        // Trace and yield error
+                        Err(err) => {
+                            error = true;
+                            tracing::error!("{}", err.to_string());
+                            yield Ok(Event::from(err))
+                        }
+                    }
+                }
+            }
+            // Trace and yield error
+            Err(err) => {
+                error = true;
+                tracing::error!("{}", err.to_string());
+                yield Ok(Event::from(err))
+            }
+        }
+        // Check if generation reached the end
+        // Skip if we already sent an error
+        if !end_reached && !error {
+            let err = InferError::IncompleteGeneration;
+            tracing::error!("{}", err.to_string());
+            yield Ok(Event::from(err))
+        }
     };
 
     Sse::new(stream).keep_alive(KeepAlive::default())
 }
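The rewritten `generate_stream` above never fails the stream itself: the item type is `Result<Event, Infallible>`, and every failure is serialized into an event instead. A reduced sketch of that axum + `async-stream` shape (a standalone toy handler, not the router's; assumes axum 0.6-era APIs plus the `async-stream`, `futures` and `tokio` crates):

```rust
use std::convert::Infallible;

use axum::response::sse::{Event, KeepAlive, Sse};
use axum::routing::get;
use axum::Router;
use futures::Stream;

// With `Infallible` as the item error, the SSE transport can never fail;
// application errors must be encoded as regular events, as the patch does
async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let stream = async_stream::stream! {
        for i in 0..3 {
            // `json_data` serializes the payload into the event's data field
            yield Ok(Event::default().json_data(i).unwrap());
        }
    };
    Sse::new(stream).keep_alive(KeepAlive::default())
}

#[tokio::main]
async fn main() {
    let app = Router::new().route("/stream", get(sse_handler));
    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}
```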
@@ -264,13 +307,14 @@ async fn shutdown_signal() {
     tracing::info!("signal received, starting graceful shutdown");
 }
 
-/// Convert to Axum supported format
+/// Convert to Axum supported formats
 impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
     fn from(err: InferError) -> Self {
         let status_code = match err {
             InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
             InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
             InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
         };
 
         (
@@ -281,3 +325,13 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
         )
     }
 }
+
+impl From<InferError> for Event {
+    fn from(err: InferError) -> Self {
+        Event::default()
+            .json_data(ErrorResponse {
+                error: err.to_string(),
+            })
+            .unwrap()
+    }
+}
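A compact, self-contained rendering of the error-to-event conversion the patch ends with, showing how `yield Ok(Event::from(err))` stays possible under an `Infallible` stream (a sketch, with axum, serde and thiserror assumed as dependencies; only `IncompleteGeneration` is reproduced here):

```rust
use axum::response::sse::Event;
use serde::Serialize;
use thiserror::Error;

#[derive(Serialize)]
struct ErrorResponse {
    error: String,
}

#[derive(Debug, Error)]
enum InferError {
    #[error("Incomplete generation")]
    IncompleteGeneration,
}

// Same shape as the impl added above: the error is serialized into the
// event payload instead of terminating the SSE stream
impl From<InferError> for Event {
    fn from(err: InferError) -> Self {
        Event::default()
            .json_data(ErrorResponse {
                error: err.to_string(),
            })
            .unwrap()
    }
}

fn main() {
    // A stream can now emit this event and keep its error type `Infallible`;
    // the client receives `{"error":"Incomplete generation"}` as event data
    let _event = Event::from(InferError::IncompleteGeneration);
}
```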