From 017a2a8c2f72c4c30b01eab0da53c00c9c1f7057 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 31 Jan 2023 17:04:00 +0100 Subject: [PATCH] feat: Add token streaming using ServerSideEvents support (#41) --- Cargo.lock | 2 + launcher/Cargo.toml | 2 +- launcher/tests/bloom_560m.json | 6 +- launcher/tests/mt0_base.json | 6 +- proto/generate.proto | 66 ++-- router/Cargo.toml | 2 + router/client/src/client.rs | 28 +- router/client/src/lib.rs | 3 +- router/client/src/sharded_client.rs | 68 ++-- router/src/batcher.rs | 236 ------------- router/src/db.rs | 50 +-- router/src/infer.rs | 353 ++++++++++++++++++++ router/src/lib.rs | 21 +- router/src/server.rs | 307 +++++++++++------ router/src/validation.rs | 73 ++-- server/tests/models/test_bloom.py | 105 +++--- server/tests/models/test_causal_lm.py | 100 +++--- server/tests/models/test_santacoder.py | 30 +- server/tests/models/test_seq2seq_lm.py | 93 +++--- server/text_generation/models/causal_lm.py | 110 +++--- server/text_generation/models/seq2seq_lm.py | 102 +++--- server/text_generation/models/types.py | 59 +++- server/text_generation/server.py | 20 +- 23 files changed, 1084 insertions(+), 758 deletions(-) delete mode 100644 router/src/batcher.rs create mode 100644 router/src/infer.rs diff --git a/Cargo.lock b/Cargo.lock index 1030e8fd9..5b671fa88 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1829,6 +1829,7 @@ dependencies = [ name = "text-generation-router" version = "0.1.0" dependencies = [ + "async-stream", "axum", "clap 4.0.22", "futures", @@ -1841,6 +1842,7 @@ dependencies = [ "thiserror", "tokenizers", "tokio", + "tokio-stream", "tracing", "tracing-subscriber", ] diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index 21d5d3ee6..58df28d96 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -16,4 +16,4 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] } [dev-dependencies] float_eq = "1.0.1" reqwest = { version = "0.11.13", features = ["blocking", "json"] } -serde = "1.0.150" +serde = { version = "1.0.150", features = ["derive"] } diff --git a/launcher/tests/bloom_560m.json b/launcher/tests/bloom_560m.json index d17f1ed4d..a81d1982f 100644 --- a/launcher/tests/bloom_560m.json +++ b/launcher/tests/bloom_560m.json @@ -3,7 +3,7 @@ "details": { "finish_reason": "length", "generated_tokens": 20, - "tokens": [ + "prefill": [ [ 10264, "Test", @@ -13,7 +13,9 @@ 8821, " request", -11.895094 - ], + ] + ], + "tokens": [ [ 17, ".", diff --git a/launcher/tests/mt0_base.json b/launcher/tests/mt0_base.json index 1b7722822..51cb8b5cd 100644 --- a/launcher/tests/mt0_base.json +++ b/launcher/tests/mt0_base.json @@ -3,12 +3,14 @@ "details": { "finish_reason": "length", "generated_tokens": 20, - "tokens": [ + "prefill": [ [ 0, "", null - ], + ] + ], + "tokens": [ [ 259, "", diff --git a/proto/generate.proto b/proto/generate.proto index 81039a7c3..8f431c5cf 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -7,10 +7,10 @@ service TextGenerationService { rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} /// Empties batch cache rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); - /// Generate tokens for a batch - rpc Generate (GenerateRequest) returns (GenerateResponse); - /// Generate tokens for a list of cached batches - rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse); + /// Prefill batch and decode first token + rpc Prefill (PrefillRequest) returns (PrefillResponse); + /// Decode token for a list of prefilled batches + rpc 
Decode (DecodeRequest) returns (DecodeResponse); } /// Empty request @@ -70,44 +70,60 @@ message Batch { } message GeneratedText { - /// Request - Request request = 1; /// Output - string output_text = 2; + string text = 1; /// Number of generated tokens - uint32 generated_tokens = 3; - /// Tokens - repeated string tokens = 4; - /// Token IDs - repeated uint32 token_ids = 5; - /// Logprobs - repeated float logprobs = 6; + uint32 generated_tokens = 2; /// Finish reason - string finish_reason = 7; + string finish_reason = 3; /// Seed - optional uint64 seed = 8; + optional uint64 seed = 4; } -message GenerateRequest { +message PrefillTokens { + /// Prefill Token IDs + repeated uint32 ids = 1; + /// Prefill Logprobs + repeated float logprobs = 2; + /// Prefill tokens + repeated string texts = 3; +} + +message Generation { + /// Request ID + uint64 request_id = 1; + /// Prefill tokens (optional) + PrefillTokens prefill_tokens = 2; + /// Token ID + uint32 token_id = 3; + /// Logprob + float token_logprob = 4; + /// Text + string token_text = 5; + /// Complete generated text + GeneratedText generated_text = 6; +} + +message PrefillRequest { /// Batch Batch batch = 1; } -message GenerateResponse { - /// Finished requests - repeated GeneratedText generated_texts = 1; +message PrefillResponse { + /// Generation + repeated Generation generations = 1; /// Next batch (cached) optional Batch batch = 2; } -message GenerateWithCacheRequest { +message DecodeRequest { /// Cached batches repeated Batch batches = 1; } -message GenerateWithCacheResponse { - /// Finished requests - repeated GeneratedText generated_texts = 1; +message DecodeResponse { + /// Decodes + repeated Generation generations = 1; /// Next batch (cached) optional Batch batch = 2; -} +} \ No newline at end of file diff --git a/router/Cargo.toml b/router/Cargo.toml index 17724bcc4..3abbc80b2 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -13,6 +13,7 @@ name = "text-generation-router" path = "src/main.rs" [dependencies] +async-stream = "0.3.3" axum = { version = "0.5.16", features = ["json", "serde_json"] } text-generation-client = { path = "client" } clap = { version = "4.0.15", features = ["derive", "env"] } @@ -25,6 +26,7 @@ serde_json = "1.0.85" thiserror = "1.0.37" tokenizers = "0.13.0" tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.11" tracing = "0.1.36" tracing-subscriber = { version = "0.3.15", features = ["json"] } diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 172d0bf73..77a431104 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -70,36 +70,36 @@ impl Client { /// Generate one token for each request in the given batch /// - /// Returns a list of generated texts of request that met their stopping criteria + /// Returns Generation for each request in batch /// and the next cached batch #[instrument(skip(self))] - pub async fn generate(&mut self, batch: Batch) -> Result<(Vec, Option)> { - let request = tonic::Request::new(GenerateRequest { batch: Some(batch) }); + pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec, Option)> { + let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }); let response = self .stub - .generate(request) - .instrument(info_span!("generate")) + .prefill(request) + .instrument(info_span!("prefill")) .await? 
.into_inner(); - Ok((response.generated_texts, response.batch)) + Ok((response.generations, response.batch)) } - /// Generate one token for each request in the given cached batch + /// Generate one token for each request in the given cached batches /// - /// Returns a list of generated texts of request that met their stopping criteria + /// Returns Generation for each request in batches /// and the next cached batch #[instrument(skip(self))] - pub async fn generate_with_cache( + pub async fn decode( &mut self, batches: Vec, - ) -> Result<(Vec, Option)> { - let request = tonic::Request::new(GenerateWithCacheRequest { batches }); + ) -> Result<(Vec, Option)> { + let request = tonic::Request::new(DecodeRequest { batches }); let response = self .stub - .generate_with_cache(request) - .instrument(info_span!("generate_with_cache")) + .decode(request) + .instrument(info_span!("decode")) .await? .into_inner(); - Ok((response.generated_texts, response.batch)) + Ok((response.generations, response.batch)) } } diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 295b009b4..e0546b16d 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -7,7 +7,8 @@ mod sharded_client; pub use client::Client; pub use pb::generate::v1::{ - Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, + Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request, + StoppingCriteriaParameters, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 6c70afca1..56335f923 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,6 +1,6 @@ /// Multi shard Client use crate::Result; -use crate::{Batch, Client, GeneratedText}; +use crate::{Batch, Client, Generation}; use futures::future::join_all; use futures::future::select_all; use tonic::transport::Uri; @@ -37,39 +37,6 @@ impl ShardedClient { Self::from_master_client(master_client).await } - /// Generate one token for each request in the given batch - /// - /// Returns a list of generated texts of request that met their stopping criteria - /// and the next cached batch - pub async fn generate(&mut self, batch: Batch) -> Result<(Vec, Option)> { - let futures: Vec<_> = self - .clients - .iter_mut() - .map(|client| Box::pin(client.generate(batch.clone()))) - .collect(); - // As soon as we receive one response, we can return as all shards will return the same - let (result, _, _) = select_all(futures).await; - result - } - - /// Generate one token for each request in the given cached batch - /// - /// Returns a list of generated texts of request that met their stopping criteria - /// and the next cached batch - pub async fn generate_with_cache( - &mut self, - batches: Vec, - ) -> Result<(Vec, Option)> { - let futures: Vec<_> = self - .clients - .iter_mut() - .map(|client| Box::pin(client.generate_with_cache(batches.clone()))) - .collect(); - // As soon as we receive one response, we can return as all shards will return the same - let (result, _, _) = select_all(futures).await; - result - } - /// Clear the past generations cache pub async fn clear_cache(&mut self) -> Result<()> { let futures: Vec<_> = self @@ -79,4 +46,37 @@ impl ShardedClient { .collect(); join_all(futures).await.into_iter().collect() } + + /// Generate one token for each request in the given batch + /// + /// Returns Generation for each request in batch + /// and the next cached batch + pub 
async fn prefill(&mut self, batch: Batch) -> Result<(Vec, Option)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.prefill(batch.clone()))) + .collect(); + // As soon as we receive one response, we can return as all shards will return the same + let (result, _, _) = select_all(futures).await; + result + } + + /// Generate one token for each request in the given cached batches + /// + /// Returns Generation for each request in batches + /// and the next cached batch + pub async fn decode( + &mut self, + batches: Vec, + ) -> Result<(Vec, Option)> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.decode(batches.clone()))) + .collect(); + // As soon as we receive one response, we can return as all shards will return the same + let (result, _, _) = select_all(futures).await; + result + } } diff --git a/router/src/batcher.rs b/router/src/batcher.rs deleted file mode 100644 index baf58af40..000000000 --- a/router/src/batcher.rs +++ /dev/null @@ -1,236 +0,0 @@ -/// Batching and inference logic -use crate::{Db, Entry}; -use crate::{ErrorResponse, GenerateRequest}; -use axum::http::StatusCode; -use axum::Json; -use nohash_hasher::IntMap; -use std::future::Future; -use std::sync::Arc; -use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient}; -use thiserror::Error; -use tokio::sync::{oneshot, Notify}; -use tokio::time::Instant; -use tracing::instrument; - -/// Batcher -#[derive(Clone)] -pub struct Batcher { - /// Request database - db: Db, - /// Shared state - shared: Arc, -} - -/// Batcher shared state -struct Shared { - /// Batching background Tokio task notifier - batching_task: Notify, -} - -impl Batcher { - pub(crate) fn new( - client: ShardedClient, - max_batch_size: usize, - max_waiting_tokens: usize, - ) -> Self { - // Batcher shared state - let db = Db::new(); - let shared = Arc::new(Shared { - batching_task: Notify::new(), - }); - - // Spawn batching background task that contains all the inference logic - tokio::spawn(batching_task( - client, - max_batch_size, - max_waiting_tokens, - db.clone(), - shared.clone(), - )); - - Self { db, shared } - } - - /// Add a new request to the database and return a future that will generate the text - pub(crate) async fn infer( - &self, - input_length: usize, - request: GenerateRequest, - ) -> Result { - // One shot channel to communicate with the background batching task - let (response_tx, response_rx) = oneshot::channel(); - - // Try to append the request to the database - self.db.append(Entry { - request, - response_tx, - input_length, - time: Instant::now(), - batch_time: None, - }); - - // Notify the background task that we have a new entry in the database that needs - // to be batched - self.shared.batching_task.notify_one(); - - // Await on the response from the background task - // We can safely unwrap as the background task will never drop the sender - response_rx - .await - .unwrap() - .map_err(|err| InferError::GenerationError(err.to_string())) - } -} - -/// Batching logic -/// Will be launched in a background Tokio task -/// -/// Batches requests and sends them to the inference server -#[instrument(skip(client, db, shared))] -async fn batching_task( - mut client: ShardedClient, - max_batch_size: usize, - max_waiting_tokens: usize, - db: Db, - shared: Arc, -) { - // Minimum batch size after which we try to add more requests - let limit_min_batch_size = (max_batch_size / 2) as u32; - - // Infinite loop - loop { - // Wait for a notification from 
the Batcher struct - shared.batching_task.notified().await; - - // Get the next batch from the DB - // This batch might be smaller than the maximum batch size if there are not enough requests - // waiting in the DB - while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) { - let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await; - let mut waiting_tokens = 1; - - // We loop until we do not receive any cached batch from the inference server (== until - // all requests have met their stopping criteria) - while let Some(batch) = cached_batch { - // Get current batch info - let batch_size = batch.size; - let mut batches = vec![batch]; - - // If the current batch is too small, we try to add more requests to it - if batch_size <= limit_min_batch_size { - let min_size = match waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - _ if waiting_tokens >= max_waiting_tokens => None, - // Minimum size criteria - _ => Some(limit_min_batch_size as usize), - }; - - // Try to get a new batch - if let Some((mut new_entries, new_batch)) = - db.next_batch(min_size, max_batch_size - batch_size as usize) - { - // Generate one token for this new batch to have the attention past in cache - let new_cached_batch = - wrap_future(client.generate(new_batch), &mut new_entries).await; - // Reset waiting counter - waiting_tokens = 1; - // Extend current batch with the new batch - if let Some(new_cached_batch) = new_cached_batch { - entries.extend(new_entries); - batches.push(new_cached_batch); - } - } - } - - cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await; - waiting_tokens += 1; - } - } - } -} - -/// Wrap a future inside a match statement to handle errors and send the response to the Batcher -async fn wrap_future( - future: impl Future, Option), ClientError>>, - entries: &mut IntMap, -) -> Option { - match future.await { - Ok((generated_texts, next_batch)) => { - send_generated(generated_texts, entries); - next_batch - } - // If we have an error, we discard the whole batch - Err(err) => { - send_error(err, entries); - None - } - } -} - -/// Send errors to the Batcher for all `entries` -fn send_error(error: ClientError, entries: &mut IntMap) { - entries.drain().for_each(|(_, entry)| { - // unwrap_or is valid here as we don't care if the receiver is gone. - entry.response_tx.send(Err(error.clone())).unwrap_or(()); - }); -} - -/// Send `generated_text` to the Batcher for all `finished` -fn send_generated(finished: Vec, entries: &mut IntMap) { - finished.into_iter().for_each(|output| { - // We can `expect` here as the request id should always be in the entries - let entry = entries - .remove(&output.request.unwrap().id) - .expect("ID not found in entries. This is a bug."); - - let response = InferResponse { - output_text: output.output_text, - generated_tokens: output.generated_tokens, - token_ids: output.token_ids, - tokens: output.tokens, - logprobs: output.logprobs, - finish_reason: output.finish_reason, - seed: output.seed, - queued: entry.time, - start: entry.batch_time.unwrap(), // unwrap is always valid - end: Instant::now(), - }; - // unwrap_or is valid here as we don't care if the receiver is gone. 
- entry.response_tx.send(Ok(response)).unwrap_or(()); - }); -} - -#[derive(Debug)] -pub(crate) struct InferResponse { - pub(crate) output_text: String, - pub(crate) generated_tokens: u32, - pub(crate) token_ids: Vec, - pub(crate) tokens: Vec, - pub(crate) logprobs: Vec, - pub(crate) finish_reason: String, - pub(crate) seed: Option, - pub(crate) queued: Instant, - pub(crate) start: Instant, - pub(crate) end: Instant, -} - -#[derive(Debug, Error)] -pub enum InferError { - #[error("Request failed during generation: {0}")] - GenerationError(String), -} - -/// Convert to Axum supported format -impl From for (StatusCode, Json) { - fn from(err: InferError) -> Self { - match err { - InferError::GenerationError(_) => ( - StatusCode::FAILED_DEPENDENCY, - Json(ErrorResponse { - error: err.to_string(), - }), - ), - } - } -} diff --git a/router/src/db.rs b/router/src/db.rs index 442d7b9c0..246e4d5de 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -1,29 +1,29 @@ /// This code is massively inspired by Tokio mini-redis -use crate::InferResponse; -use crate::{GenerateParameters, GenerateRequest}; +use crate::infer::InferError; +use crate::infer::InferStreamResponse; +use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; use parking_lot::Mutex; use std::collections::BTreeMap; use std::sync::Arc; -use text_generation_client::{ - Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, -}; -use tokio::sync::oneshot::Sender; +use text_generation_client::{Batch, Request}; +use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::OwnedSemaphorePermit; use tokio::time::Instant; /// Database entry #[derive(Debug)] pub(crate) struct Entry { /// Request - pub request: GenerateRequest, - /// Response sender to communicate between the Batcher and the batching_task - pub response_tx: Sender>, - /// Number of tokens in the input - pub input_length: usize, + pub request: ValidGenerateRequest, + /// Response sender to communicate between the Infer struct and the batching_task + pub response_tx: UnboundedSender>, /// Instant when this entry was created pub time: Instant, /// Instant when this entry was added to a batch pub batch_time: Option, + /// Permit + pub _permit: OwnedSemaphorePermit, } /// Request Database @@ -71,9 +71,9 @@ impl State { requests.push(Request { id: *id, inputs: entry.request.inputs.clone(), - input_length: entry.input_length as u32, - parameters: Some((&entry.request.parameters).into()), - stopping_parameters: Some(entry.request.parameters.clone().into()), + input_length: entry.request.input_length, + parameters: Some(entry.request.parameters.clone()), + stopping_parameters: Some(entry.request.stopping_parameters.clone()), }); ids.push(*id); @@ -158,25 +158,3 @@ impl Db { None } } - -impl From<&GenerateParameters> for NextTokenChooserParameters { - fn from(parameters: &GenerateParameters) -> Self { - Self { - temperature: parameters.temperature, - top_k: parameters.top_k as u32, - top_p: parameters.top_p, - do_sample: parameters.do_sample, - // FIXME: remove unwrap - seed: parameters.seed.unwrap(), - } - } -} - -impl From for StoppingCriteriaParameters { - fn from(parameters: GenerateParameters) -> Self { - Self { - stop_sequences: parameters.stop, - max_new_tokens: parameters.max_new_tokens, - } - } -} diff --git a/router/src/infer.rs b/router/src/infer.rs new file mode 100644 index 000000000..23e842652 --- /dev/null +++ b/router/src/infer.rs @@ -0,0 +1,353 @@ +/// Batching and inference logic +use 
crate::validation::{Validation, ValidationError}; +use crate::GenerateRequest; +use crate::{Db, Entry, Token}; +use nohash_hasher::IntMap; +use std::future::Future; +use std::sync::Arc; +use text_generation_client::{ + Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient, +}; +use thiserror::Error; +use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError}; +use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::StreamExt; +use tracing::instrument; + +/// Inference struct +#[derive(Clone)] +pub struct Infer { + /// Validation + validation: Validation, + /// Request database + db: Db, + /// Shared state + shared: Arc, + /// Inference limit + limit_concurrent_requests: Arc, +} + +/// Infer shared state +struct Shared { + /// Batching background Tokio task notifier + batching_task: Notify, +} + +impl Infer { + pub(crate) fn new( + client: ShardedClient, + validation: Validation, + max_batch_size: usize, + max_waiting_tokens: usize, + max_concurrent_requests: usize, + ) -> Self { + // Infer shared state + let db = Db::new(); + let shared = Arc::new(Shared { + batching_task: Notify::new(), + }); + + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + client, + max_batch_size, + max_waiting_tokens, + db.clone(), + shared.clone(), + )); + + // Inference limit with a semaphore + let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); + + Self { + validation, + db, + shared, + limit_concurrent_requests: semaphore, + } + } + + /// Add a new request to the database and return a stream of InferStreamResponse + pub(crate) async fn generate_stream( + &self, + request: GenerateRequest, + ) -> Result>, InferError> { + // Limit concurrent requests by acquiring a permit from the semaphore + // This permit will live as long as Entry + let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?; + + // Validate request + let valid_request = self.validation.validate(request).await?; + + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + // Append the request to the database + self.db.append(Entry { + request: valid_request, + response_tx, + time: Instant::now(), + batch_time: None, + _permit: permit, + }); + + // Notify the background task that we have a new entry in the database that needs + // to be batched + self.shared.batching_task.notify_one(); + + // Return stream + Ok(UnboundedReceiverStream::new(response_rx)) + } + + /// Add a new request to the database and return a InferResponse + pub(crate) async fn generate( + &self, + request: GenerateRequest, + ) -> Result { + // Create stream + let mut stream = self.generate_stream(request).await?; + + // Return values + let mut result_prefill = Vec::new(); + let mut result_tokens = Vec::new(); + let mut result_generated_text = None; + let mut result_start = None; + let mut result_queued = None; + + // Iterate on stream + while let Some(response) = stream.next().await { + match response? 
{ + // Add prefill tokens + InferStreamResponse::Prefill(tokens) => { + // Create Token objects + // We do that here instead of in the Python code as Rust for loops are faster + result_prefill = tokens + .ids + .into_iter() + .zip(tokens.logprobs.into_iter()) + .zip(tokens.texts.into_iter()) + .map(|((id, logprob), text)| Token(id, text, logprob)) + .collect(); + } + // Push last token + InferStreamResponse::Token(token) => result_tokens.push(token), + // Final message + // Set return values + InferStreamResponse::End { + token, + generated_text, + start, + queued, + } => { + result_tokens.push(token); + result_generated_text = Some(generated_text); + result_start = Some(start); + result_queued = Some(queued) + } + } + } + + // Check that we received a `InferStreamResponse::End` message + if let (Some(generated_text), Some(queued), Some(start)) = + (result_generated_text, result_queued, result_start) + { + Ok(InferResponse { + prefill: result_prefill, + tokens: result_tokens, + generated_text, + queued, + start, + }) + } else { + Err(InferError::IncompleteGeneration) + } + } +} + +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[instrument(skip(client, db, shared))] +async fn batching_task( + mut client: ShardedClient, + max_batch_size: usize, + max_waiting_tokens: usize, + db: Db, + shared: Arc, +) { + // Minimum batch size after which we try to add more requests + let limit_min_batch_size = (max_batch_size / 2) as u32; + + // Infinite loop + loop { + // Wait for a notification from the Infer struct + shared.batching_task.notified().await; + + // Get the next batch from the DB + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the DB + while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) { + let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await; + let mut waiting_tokens = 1; + + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) + while let Some(batch) = cached_batch { + // Get current batch info + let batch_size = batch.size; + let mut batches = vec![batch]; + + // If the current batch is too small, we try to add more requests to it + if batch_size <= limit_min_batch_size { + let min_size = match waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + _ if waiting_tokens >= max_waiting_tokens => None, + // Minimum size criteria + _ => Some(limit_min_batch_size as usize), + }; + + // Try to get a new batch + if let Some((mut new_entries, new_batch)) = + db.next_batch(min_size, max_batch_size - batch_size as usize) + { + // Generate one token for this new batch to have the attention past in cache + let new_cached_batch = + wrap_future(client.prefill(new_batch), &mut new_entries).await; + // Reset waiting counter + waiting_tokens = 1; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + entries.extend(new_entries); + batches.push(new_cached_batch); + } + } + } + + cached_batch = wrap_future(client.decode(batches), &mut entries).await; + waiting_tokens += 1; + } + } + } +} + +/// Wrap a future inside a match statement to handle errors and send the responses to Infer +async fn wrap_future( + future: impl Future, Option), ClientError>>, + entries: &mut IntMap, +) -> 
Option { + match future.await { + Ok((generations, next_batch)) => { + send_generations(generations, entries); + next_batch + } + // If we have an error, we discard the whole batch + Err(err) => { + send_error(err, entries); + None + } + } +} + +/// Send errors to Infer for all `entries` +fn send_error(error: ClientError, entries: &mut IntMap) { + entries.drain().for_each(|(_, entry)| { + // unwrap_or is valid here as we don't care if the receiver is gone. + entry + .response_tx + .send(Err(InferError::GenerationError(error.to_string()))) + .unwrap_or(()); + }); +} + +/// Send one or multiple `InferStreamResponse` to Infer for all `entries` +fn send_generations(generations: Vec, entries: &mut IntMap) { + generations.into_iter().for_each(|generation| { + // Get entry + // We can `expect` here as the request id should always be in the entries + let entry = entries + .get(&generation.request_id) + .expect("ID not found in entries. This is a bug."); + + if let Some(prefill_tokens) = generation.prefill_tokens { + // Send message + // unwrap_or is valid here as we don't care if the receiver is gone. + entry + .response_tx + .send(Ok(InferStreamResponse::Prefill(prefill_tokens))) + .unwrap_or(()); + } + + // Create last Token + let token = Token( + generation.token_id, + generation.token_text, + generation.token_logprob, + ); + + if let Some(generated_text) = generation.generated_text { + // Remove entry as this is the last message + // We can `expect` here as the request id should always be in the entries + let entry = entries + .remove(&generation.request_id) + .expect("ID not found in entries. This is a bug."); + + // Send message + // unwrap_or is valid here as we don't care if the receiver is gone. + entry + .response_tx + .send(Ok(InferStreamResponse::End { + token, + generated_text, + queued: entry.time, + start: entry.batch_time.unwrap(), + })) + .unwrap_or(()); + } else { + // Send message + // unwrap_or is valid here as we don't care if the receiver is gone. 
+ entry + .response_tx + .send(Ok(InferStreamResponse::Token(token))) + .unwrap_or(()); + } + }); +} + +#[derive(Debug)] +pub(crate) enum InferStreamResponse { + // Optional first message + Prefill(PrefillTokens), + // Intermediate messages + Token(Token), + // Last message + End { + token: Token, + generated_text: GeneratedText, + start: Instant, + queued: Instant, + }, +} + +#[derive(Debug)] +pub(crate) struct InferResponse { + pub(crate) prefill: Vec, + pub(crate) tokens: Vec, + pub(crate) generated_text: GeneratedText, + pub(crate) queued: Instant, + pub(crate) start: Instant, +} + +#[derive(Debug, Error)] +pub enum InferError { + #[error("Request failed during generation: {0}")] + GenerationError(String), + #[error("Model is overloaded")] + Overloaded(#[from] TryAcquireError), + #[error("Input validation error: {0}")] + ValidationError(#[from] ValidationError), + #[error("Incomplete generation")] + IncompleteGeneration, +} diff --git a/router/src/lib.rs b/router/src/lib.rs index 1aeac302c..beab71380 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,11 +1,11 @@ /// Text Generation Inference Webserver -mod batcher; mod db; +mod infer; pub mod server; mod validation; -use batcher::{Batcher, InferResponse}; use db::{Db, Entry}; +use infer::Infer; use serde::{Deserialize, Serialize}; use validation::Validation; @@ -69,21 +69,34 @@ pub(crate) struct GenerateRequest { pub parameters: GenerateParameters, } +#[derive(Debug, Serialize)] +pub struct Token(u32, String, f32); + #[derive(Serialize)] pub(crate) struct Details { pub finish_reason: String, pub generated_tokens: u32, pub seed: Option, - pub tokens: Vec<(u32, String, f32)>, + #[serde(skip_serializing_if = "Option::is_none")] + pub prefill: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tokens: Option>, } #[derive(Serialize)] -pub(crate) struct GeneratedText { +pub(crate) struct GenerateResponse { pub generated_text: String, #[serde(skip_serializing_if = "Option::is_none")] pub details: Option
<Details>,
 }
 
+#[derive(Serialize)]
+pub(crate) struct StreamResponse {
+    pub token: Token,
+    pub generated_text: Option<String>,
+    pub details: Option<Details>
, +} + #[derive(Serialize)] pub(crate) struct ErrorResponse { pub error: String, diff --git a/router/src/server.rs b/router/src/server.rs index 86041b96e..ef3782d6c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,71 +1,54 @@ +/// HTTP Server logic +use crate::infer::{InferError, InferStreamResponse}; use crate::{ - Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation, + Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer, + StreamResponse, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, StatusCode}; +use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::IntoResponse; use axum::routing::{get, post}; use axum::{Json, Router}; +use futures::Stream; +use std::convert::Infallible; use std::net::SocketAddr; -use std::sync::Arc; use text_generation_client::ShardedClient; use tokenizers::Tokenizer; use tokio::signal; -use tokio::sync::Semaphore; use tokio::time::Instant; +use tokio_stream::StreamExt; use tracing::instrument; -// Server shared state -#[derive(Clone)] -struct ServerState { - validation: Validation, - batcher: Batcher, - limit_concurrent_requests: Arc, -} - /// Health check method -#[instrument(skip(state), fields(time, time_per_token))] -async fn health(state: Extension) -> Result<(), (StatusCode, Json)> { +#[instrument(skip(infer))] +async fn health(infer: Extension) -> Result<(), (StatusCode, Json)> { // TODO: while this is the best health check we can do, it is a bit on the heavy side and might // be a bit too slow for a health check. // What we should do instead if check if the gRPC channels are still healthy. - // Limit concurrent requests by acquiring a permit from the semaphore - let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| { - ( - StatusCode::TOO_MANY_REQUESTS, - Json(ErrorResponse { - error: "Model is overloaded".to_string(), - }), - ) - })?; - // Send a small inference request - state - .batcher - .infer( - 1, - GenerateRequest { - inputs: "liveness".to_string(), - parameters: GenerateParameters { - temperature: 1.0, - top_k: 0, - top_p: 1.0, - do_sample: false, - max_new_tokens: 1, - stop: vec![], - details: false, - seed: None, - }, + infer + .generate(GenerateRequest { + inputs: "liveness".to_string(), + parameters: GenerateParameters { + temperature: 1.0, + top_k: 0, + top_p: 1.0, + do_sample: false, + max_new_tokens: 1, + stop: vec![], + details: false, + seed: None, }, - ) + }) .await?; Ok(()) } /// Generate method #[instrument( - skip(state), + skip(infer), fields( total_time, validation_time, @@ -76,56 +59,28 @@ async fn health(state: Extension) -> Result<(), (StatusCode, Json, + infer: Extension, req: Json, ) -> Result)> { + let span = tracing::Span::current(); let start_time = Instant::now(); - // Limit concurrent requests by acquiring a permit from the semaphore - let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| { - tracing::error!("Model is overloaded"); - ( - StatusCode::TOO_MANY_REQUESTS, - Json(ErrorResponse { - error: "Model is overloaded".to_string(), - }), - ) - })?; - - // Validate request - let details = req.0.parameters.details; - let (input_length, validated_request) = - state.validation.validate(req.0).await.map_err(|err| { - tracing::error!("{}", err.to_string()); - err - })?; // Inference - let response = state - .batcher - .infer(input_length, validated_request) - .await - .map_err(|err| { - tracing::error!("{}", err.to_string()); - err - })?; + let details = 
req.0.parameters.details; + let response = infer.generate(req.0).await.map_err(|err| { + tracing::error!("{}", err.to_string()); + err + })?; // Token details let details = match details { - true => { - let tokens = response - .token_ids - .into_iter() - .zip(response.tokens.into_iter()) - .zip(response.logprobs.into_iter()) - .map(|((id, text), logprob)| (id, text, logprob)) - .collect(); - Some(Details { - seed: response.seed, - finish_reason: response.finish_reason, - generated_tokens: response.generated_tokens, - tokens, - }) - } + true => Some(Details { + finish_reason: response.generated_text.finish_reason, + generated_tokens: response.generated_text.generated_tokens, + prefill: Some(response.prefill), + tokens: Some(response.tokens), + seed: response.generated_text.seed, + }), false => None, }; @@ -133,8 +88,8 @@ async fn generate( let total_time = start_time.elapsed(); let validation_time = response.queued - start_time; let queue_time = response.start - response.queued; - let inference_time = response.end - response.start; - let time_per_token = inference_time / response.generated_tokens; + let inference_time = Instant::now() - response.start; + let time_per_token = inference_time / response.generated_text.generated_tokens; // Headers let mut headers = HeaderMap::new(); @@ -160,22 +115,143 @@ async fn generate( ); // Tracing metadata - tracing::Span::current().record("total_time", format!("{:?}", total_time)); - tracing::Span::current().record("validation_time", format!("{:?}", validation_time)); - tracing::Span::current().record("queue_time", format!("{:?}", queue_time)); - tracing::Span::current().record("inference_time", format!("{:?}", inference_time)); - tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token)); - tracing::Span::current().record("seed", format!("{:?}", response.seed)); - tracing::info!("Output: {}", response.output_text); + span.record("total_time", format!("{:?}", total_time)); + span.record("validation_time", format!("{:?}", validation_time)); + span.record("queue_time", format!("{:?}", queue_time)); + span.record("inference_time", format!("{:?}", inference_time)); + span.record("time_per_token", format!("{:?}", time_per_token)); + span.record("seed", format!("{:?}", response.generated_text.seed)); + tracing::info!("Output: {}", response.generated_text.text); // Send response - let response = vec![GeneratedText { - generated_text: response.output_text, + let response = vec![GenerateResponse { + generated_text: response.generated_text.text, details, }]; Ok((headers, Json(response))) } +/// Generate stream method +#[instrument( + skip(infer), + fields( + total_time, + validation_time, + queue_time, + inference_time, + time_per_token + ) +)] +async fn generate_stream( + infer: Extension, + req: Json, +) -> Sse>> { + let span = tracing::Span::current(); + let start_time = Instant::now(); + + let stream = async_stream::stream! 
{ + // Inference + let mut end_reached = false; + let mut error = false; + let details = req.0.parameters.details; + + match infer.generate_stream(req.0).await { + Ok(mut response_stream) => { + // Server Side Event stream + while let Some(response) = response_stream.next().await { + match response { + Ok(response) => { + match response { + // Prefill is ignored + InferStreamResponse::Prefill(_) => {} + // Yield event for every new token + InferStreamResponse::Token(token) => { + // StreamResponse + let stream_token = StreamResponse { + token, + generated_text: None, + details: None, + }; + + yield Ok(Event::default().json_data(stream_token).unwrap()) + } + // Yield event for last token and compute timings + InferStreamResponse::End { + token, + generated_text, + start, + queued, + } => { + // Token details + let details = match details { + true => Some(Details { + finish_reason: generated_text.finish_reason, + generated_tokens: generated_text.generated_tokens, + prefill: None, + tokens: None, + seed: generated_text.seed, + }), + false => None, + }; + + // Timings + let total_time = start_time.elapsed(); + let validation_time = queued - start_time; + let queue_time = start - queued; + let inference_time = Instant::now() - start; + let time_per_token = inference_time / generated_text.generated_tokens; + + // Tracing metadata + span.record("total_time", format!("{:?}", total_time)); + span + .record("validation_time", format!("{:?}", validation_time)); + span.record("queue_time", format!("{:?}", queue_time)); + span + .record("inference_time", format!("{:?}", inference_time)); + span + .record("time_per_token", format!("{:?}", time_per_token)); + tracing::info!(parent: &span, "Output: {}", generated_text.text); + + // StreamResponse + end_reached = true; + let stream_token = StreamResponse { + token, + generated_text: Some(generated_text.text), + details + }; + + yield Ok(Event::default().json_data(stream_token).unwrap()) + } + } + } + // Trace and yield error + Err(err) => { + error = true; + tracing::error!("{}", err.to_string()); + yield Ok(Event::from(err)) + } + } + } + }, + // Trace and yield error + Err(err) => { + error = true; + tracing::error!("{}", err.to_string()); + yield Ok(Event::from(err)) + } + } + // Check if generation reached the end + // Skip if we already sent an error + if !end_reached && !error { + let err = InferError::IncompleteGeneration; + tracing::error!("{}", err.to_string()); + yield Ok(Event::from(err)) + } + }; + + Sse::new(stream).keep_alive(KeepAlive::default()) +} + /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( @@ -189,21 +265,23 @@ pub async fn run( addr: SocketAddr, ) { // Create state - let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens); let validation = Validation::new(validation_workers, tokenizer, max_input_length); - let shared_state = ServerState { + let infer = Infer::new( + client, validation, - batcher, - limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)), - }; + max_batch_size, + max_waiting_tokens, + max_concurrent_requests, + ); // Create router let app = Router::new() .route("/", post(generate)) .route("/generate", post(generate)) + .route("/generate_stream", post(generate_stream)) .route("/", get(health)) .route("/health", get(health)) - .layer(Extension(shared_state.clone())); + .layer(Extension(infer)); // Run server axum::Server::bind(&addr) @@ -240,3 +318,32 @@ async fn shutdown_signal() { tracing::info!("signal received, starting graceful shutdown"); } + +/// 
Convert to Axum supported formats +impl From for (StatusCode, Json) { + fn from(err: InferError) -> Self { + let status_code = match err { + InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY, + InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS, + InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, + InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, + }; + + ( + status_code, + Json(ErrorResponse { + error: err.to_string(), + }), + ) + } +} + +impl From for Event { + fn from(err: InferError) -> Self { + Event::default() + .json_data(ErrorResponse { + error: err.to_string(), + }) + .unwrap() + } +} diff --git a/router/src/validation.rs b/router/src/validation.rs index 674635d37..39da2a834 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,9 +1,8 @@ /// Payload validation logic -use crate::{ErrorResponse, GenerateRequest}; -use axum::http::StatusCode; -use axum::Json; +use crate::{GenerateParameters, GenerateRequest}; use rand::rngs::ThreadRng; use rand::Rng; +use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::{mpsc, oneshot}; @@ -40,7 +39,7 @@ impl Validation { pub(crate) async fn validate( &self, request: GenerateRequest, - ) -> Result<(usize, GenerateRequest), ValidationError> { + ) -> Result { // Create response channel let (sender, receiver) = oneshot::channel(); // Send request to the background validation task @@ -106,11 +105,11 @@ fn validation_worker( } fn validate( - mut request: GenerateRequest, + request: GenerateRequest, tokenizer: &Tokenizer, max_input_length: usize, rng: &mut ThreadRng, -) -> Result<(usize, GenerateRequest), ValidationError> { +) -> Result { if request.parameters.temperature <= 0.0 { return Err(ValidationError::Temperature); } @@ -131,19 +130,48 @@ fn validate( } // If seed is None, assign a random one - if request.parameters.seed.is_none() { - request.parameters.seed = Some(rng.gen()); - } + let seed = match request.parameters.seed { + None => rng.gen(), + Some(seed) => seed, + }; // Get the number of tokens in the input match tokenizer.encode(request.inputs.clone(), true) { - Ok(inputs) => { - let input_length = inputs.len(); + Ok(encoding) => { + let input_length = encoding.len(); if input_length > max_input_length { Err(ValidationError::InputLength(input_length, max_input_length)) } else { - Ok((input_length, request)) + // Return ValidGenerateRequest + let GenerateParameters { + temperature, + top_k, + top_p, + do_sample, + max_new_tokens, + stop: stop_sequences, + .. 
+ } = request.parameters; + + let parameters = NextTokenChooserParameters { + temperature, + top_k: top_k as u32, + top_p, + do_sample, + seed, + }; + let stopping_parameters = StoppingCriteriaParameters { + max_new_tokens, + stop_sequences, + }; + + Ok(ValidGenerateRequest { + inputs: request.inputs, + input_length: input_length as u32, + parameters, + stopping_parameters, + }) } } Err(err) => Err(ValidationError::Tokenizer(err.to_string())), @@ -152,9 +180,17 @@ fn validate( type ValidationRequest = ( GenerateRequest, - oneshot::Sender>, + oneshot::Sender>, ); +#[derive(Debug)] +pub(crate) struct ValidGenerateRequest { + pub inputs: String, + pub input_length: u32, + pub parameters: NextTokenChooserParameters, + pub stopping_parameters: StoppingCriteriaParameters, +} + #[derive(Error, Debug)] pub enum ValidationError { #[error("temperature must be strictly positive")] @@ -172,14 +208,3 @@ pub enum ValidationError { #[error("tokenizer error {0}")] Tokenizer(String), } - -impl From for (StatusCode, Json) { - fn from(err: ValidationError) -> Self { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - }), - ) - } -} diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 1a788ce52..9f96efc32 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -91,9 +91,9 @@ def test_causal_lm_batch_type(default_bloom): def test_causal_lm_generate_token(default_bloom, default_bloom_batch): sequence_length = len(default_bloom_batch.all_input_ids[0]) - generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch) + generations, next_batch = default_bloom.generate_token(default_bloom_batch) - assert generated_texts == [] + assert len(generations) == len(default_bloom_batch) assert isinstance(next_batch, CausalLMBatch) assert not next_batch.keys_head_dim_last @@ -122,24 +122,30 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch): assert all( [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values] ) + assert all([generation.generated_text is None for generation in generations]) + assert all([len(generation.prefill_tokens) == 1 for generation in generations]) + assert all([generation.token_id.item() == 10264 for generation in generations]) + assert all([generation.token_text == "Test" for generation in generations]) + assert generations[0].request_id == 0 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch): next_batch = default_bloom_batch for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1): - generated_texts, next_batch = default_bloom.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_bloom.generate_token(next_batch) + assert len(generations) == len(default_bloom_batch) - generated_texts, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch = default_bloom.generate_token(next_batch) assert next_batch is None - assert len(generated_texts) == 1 + assert len(generations) == 1 assert ( - generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" + generations[0].generated_text.text + == "TestTestTestTestTestTestTestTestTestTestTest" ) - assert generated_texts[0].request == default_bloom_batch.requests[0] + assert generations[0].request_id == default_bloom_batch.requests[0].id assert ( - generated_texts[0].generated_tokens + generations[0].generated_text.generated_tokens == 
default_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -152,17 +158,19 @@ def test_causal_lm_generate_token_completion_multi( for i in range( default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1 ): - generated_texts, next_batch = default_bloom.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_bloom.generate_token(next_batch) + assert len(generations) == len(default_multi_requests_bloom_batch) - generated_texts, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch = default_bloom.generate_token(next_batch) assert next_batch is not None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "TestTestTestTestTestTest" - assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1] + assert len(generations) == 2 + assert generations[1].generated_text.text == "TestTestTestTestTestTest" assert ( - generated_texts[0].generated_tokens + generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id + ) + assert ( + generations[1].generated_text.generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) @@ -171,19 +179,22 @@ def test_causal_lm_generate_token_completion_multi( - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1 ): - generated_texts, next_batch = default_bloom.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_bloom.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch = default_bloom.generate_token(next_batch) assert next_batch is None - assert len(generated_texts) == 1 + assert len(generations) == 1 assert ( - generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" + generations[0].generated_text.text + == "TestTestTestTestTestTestTestTestTestTestTest" ) - assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0] assert ( - generated_texts[0].generated_tokens + generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id + ) + assert ( + generations[0].generated_text.generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -243,17 +254,19 @@ def test_batch_concatenate( for _ in range( default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2 ): - generated_texts, next_batch = default_bloom.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_bloom.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch = default_bloom.generate_token(next_batch) assert next_batch is not None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "TestTestTestTestTestTest" - assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1] + assert len(generations) == 3 + assert generations[2].generated_text.text == "TestTestTestTestTestTest" assert ( - generated_texts[0].generated_tokens + generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id + ) + assert ( + generations[2].generated_text.generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) @@ -262,19 +275,20 @@ def test_batch_concatenate( - 
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2 ): - generated_texts, next_batch = default_bloom.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_bloom.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch = default_bloom.generate_token(next_batch) assert next_batch is not None - assert len(generated_texts) == 1 + assert len(generations) == 2 assert ( - generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" + generations[0].generated_text.text + == "TestTestTestTestTestTestTestTestTestTestTest" ) - assert generated_texts[0].request == default_bloom_batch.requests[0] + assert generations[0].request_id == default_bloom_batch.requests[0].id assert ( - generated_texts[0].generated_tokens + generations[0].generated_text.generated_tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -284,18 +298,21 @@ def test_batch_concatenate( - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 4 ): - generated_texts, next_batch = default_bloom.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_bloom.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_bloom.generate_token(next_batch) + generations, next_batch = default_bloom.generate_token(next_batch) assert next_batch is None - assert len(generated_texts) == 1 + assert len(generations) == 1 assert ( - generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" + generations[0].generated_text.text + == "TestTestTestTestTestTestTestTestTestTestTest" ) - assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0] assert ( - generated_texts[0].generated_tokens + generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id + ) + assert ( + generations[0].generated_text.generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens ) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index bedb65baa..f9762b304 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -88,11 +88,9 @@ def test_causal_lm_batch_type(default_causal_lm): def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): sequence_length = len(default_causal_lm_batch.all_input_ids[0]) - generated_texts, next_batch = default_causal_lm.generate_token( - default_causal_lm_batch - ) + generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch) - assert generated_texts == [] + assert len(generations) == len(next_batch) assert isinstance(next_batch, CausalLMBatch) assert len(next_batch.all_input_ids) == next_batch.size @@ -121,6 +119,11 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): assert all( [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values] ) + assert all([generation.generated_text is None for generation in generations]) + assert all([len(generation.prefill_tokens) == 1 for generation in generations]) + assert all([generation.token_id.item() == 13 for generation in generations]) + assert all([generation.token_text == "." 
for generation in generations]) + assert generations[0].request_id == 0 def test_causal_lm_generate_token_completion( @@ -128,18 +131,17 @@ def test_causal_lm_generate_token_completion( ): next_batch = default_causal_lm_batch for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1): - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." - assert generated_texts[0].request == default_causal_lm_batch.requests[0] - assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs) + assert len(generations) == 1 + assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." + assert generations[0].request_id == default_causal_lm_batch.requests[0].id assert ( - generated_texts[0].generated_tokens + generations[0].generated_text.generated_tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -152,19 +154,20 @@ def test_causal_lm_generate_token_completion_multi( for i in range( default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1 ): - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is not None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "Test.java:784)" + assert len(generations) == 2 + assert generations[1].generated_text.text == "Test.java:784)" assert ( - generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] + generations[1].request_id + == default_multi_requests_causal_lm_batch.requests[1].id ) assert ( - generated_texts[0].generated_tokens + generations[1].generated_text.generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) @@ -173,19 +176,20 @@ def test_causal_lm_generate_token_completion_multi( - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1 ): - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." + assert len(generations) == 1 + assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." 
assert ( - generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] + generations[0].request_id + == default_multi_requests_causal_lm_batch.requests[0].id ) assert ( - generated_texts[0].generated_tokens + generations[0].generated_text.generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -244,19 +248,20 @@ def test_batch_concatenate( for _ in range( default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 ): - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is not None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "Test.java:784)" + assert len(generations) == 3 + assert generations[2].generated_text.text == "Test.java:784)" assert ( - generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] + generations[2].request_id + == default_multi_requests_causal_lm_batch.requests[1].id ) assert ( - generated_texts[0].generated_tokens + generations[2].generated_text.generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) @@ -265,17 +270,17 @@ def test_batch_concatenate( - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 ): - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is not None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." - assert generated_texts[0].request == default_causal_lm_batch.requests[0] + assert len(generations) == 2 + assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." + assert generations[0].request_id == default_causal_lm_batch.requests[0].id assert ( - generated_texts[0].generated_tokens + generations[0].generated_text.generated_tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -285,18 +290,19 @@ def test_batch_concatenate( - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 4 ): - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_causal_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_causal_lm.generate_token(next_batch) + generations, next_batch = default_causal_lm.generate_token(next_batch) assert next_batch is None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." + assert len(generations) == 1 + assert generations[0].generated_text.text == "Test.java:784) at net.minecraft." 
assert ( - generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] + generations[0].request_id + == default_multi_requests_causal_lm_batch.requests[0].id ) assert ( - generated_texts[0].generated_tokens + generations[0].generated_text.generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index acebec042..1b69477d6 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -50,18 +50,17 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat next_batch = batch for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): - generated_texts, next_batch = default_santacoder.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_santacoder.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_santacoder.generate_token(next_batch) + generations, next_batch = default_santacoder.generate_token(next_batch) assert next_batch is None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "def test_get_all_users_with_" - assert generated_texts[0].request == batch.requests[0] - assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs) + assert len(generations) == 1 + assert generations[0].generated_text.text == "def test_get_all_users_with_" + assert generations[0].request_id == batch.requests[0].id assert ( - generated_texts[0].generated_tokens + generations[0].generated_text.generated_tokens == batch.stopping_criterias[0].max_new_tokens ) @@ -76,20 +75,19 @@ def test_fim_santacoder_generate_token_completion( next_batch = batch for _ in range(batch.stopping_criterias[0].max_new_tokens - 1): - generated_texts, next_batch = default_santacoder.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_santacoder.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_santacoder.generate_token(next_batch) + generations, next_batch = default_santacoder.generate_token(next_batch) assert next_batch is None - assert len(generated_texts) == 1 + assert len(generations) == 1 assert ( - generated_texts[0].output_text + generations[0].generated_text.text == """defworldineProperty(exports, "__esModule", { value""" ) - assert generated_texts[0].request == batch.requests[0] - assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs) + assert generations[0].request_id == batch.requests[0].id assert ( - generated_texts[0].generated_tokens + generations[0].generated_text.generated_tokens == batch.stopping_criterias[0].max_new_tokens ) diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index de1a48297..22c6ac9cd 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -99,11 +99,11 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm): def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch): sequence_length = len(default_seq2seq_lm_batch.input_ids[0]) - generated_texts, next_batch = default_seq2seq_lm.generate_token( + generations, next_batch = default_seq2seq_lm.generate_token( default_seq2seq_lm_batch ) - assert generated_texts == [] + assert len(generations) == len(next_batch) assert isinstance(next_batch, Seq2SeqLMBatch) assert 
torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids) @@ -145,6 +145,11 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) for p in next_batch.past_key_values ] ) + assert all([generation.generated_text is None for generation in generations]) + assert all([len(generation.prefill_tokens) == 1 for generation in generations]) + assert all([generation.token_id.item() == 259 for generation in generations]) + assert all([generation.token_text == "" for generation in generations]) + assert generations[0].request_id == 0 def test_seq2seq_lm_generate_token_completion( @@ -152,16 +157,16 @@ def test_seq2seq_lm_generate_token_completion( ): next_batch = default_seq2seq_lm_batch for _ in range(6): - generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "a few weeks" - assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0] - assert generated_texts[0].generated_tokens == 7 + assert len(generations) == 1 + assert generations[0].generated_text.text == "a few weeks" + assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id + assert generations[0].generated_text.generated_tokens == 7 def test_seq2seq_lm_generate_token_completion_multi( @@ -170,33 +175,33 @@ def test_seq2seq_lm_generate_token_completion_multi( next_batch = default_multi_requests_seq2seq_lm_batch for i in range(4): - generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "a few " + assert len(generations) == 2 + assert generations[1].generated_text.text == "a few " assert ( - generated_texts[0].request - == default_multi_requests_seq2seq_lm_batch.requests[1] + generations[1].request_id + == default_multi_requests_seq2seq_lm_batch.requests[1].id ) - assert generated_texts[0].generated_tokens == 5 + assert generations[1].generated_text.generated_tokens == 5 - generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "a few weeks" + assert len(generations) == 1 + assert generations[0].generated_text.text == "a few weeks" assert ( - generated_texts[0].request - == default_multi_requests_seq2seq_lm_batch.requests[0] + generations[0].request_id + == default_multi_requests_seq2seq_lm_batch.requests[0].id ) - assert generated_texts[0].generated_tokens == 7 + assert 
generations[0].generated_text.generated_tokens == 7 def test_batch_concatenate( @@ -291,35 +296,35 @@ def test_batch_concatenate( ) for _ in range(3): - generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) - assert generated_texts == [] + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) + assert len(generations) == len(next_batch) - generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "a few " + assert len(generations) == 3 + assert generations[2].generated_text.text == "a few " assert ( - generated_texts[0].request - == default_multi_requests_seq2seq_lm_batch.requests[1] + generations[2].request_id + == default_multi_requests_seq2seq_lm_batch.requests[1].id ) - assert generated_texts[0].generated_tokens == 5 + assert generations[2].generated_text.generated_tokens == 5 - generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "a few weeks" - assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0] - assert generated_texts[0].generated_tokens == 7 + assert len(generations) == 2 + assert generations[0].generated_text.text == "a few weeks" + assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id + assert generations[0].generated_text.generated_tokens == 7 - generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) + generations, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None - assert len(generated_texts) == 1 - assert generated_texts[0].output_text == "a few weeks" + assert len(generations) == 1 + assert generations[0].generated_text.text == "a few weeks" assert ( - generated_texts[0].request - == default_multi_requests_seq2seq_lm_batch.requests[0] + generations[0].request_id + == default_multi_requests_seq2seq_lm_batch.requests[0].id ) - assert generated_texts[0].generated_tokens == 7 + assert generations[0].generated_text.generated_tokens == 7 diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index bf49d1347..31996e065 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize from typing import Optional, Tuple, List, Type from text_generation.models import Model -from text_generation.models.types import GeneratedText, Batch +from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText from text_generation.pb import generate_pb2 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -23,7 +23,6 @@ class CausalLMBatch(Batch): # All tokens all_input_ids: List[torch.Tensor] - all_logprobs: List[Optional[torch.Tensor]] # Lengths of all generations present in the batch input_lengths: List[int] @@ -57,7 +56,6 @@ class CausalLMBatch(Batch): next_token_choosers = [] stopping_criterias = [] input_lengths = [] - all_logprobs = [] # Parse batch for r in pb.requests: @@ -67,7 +65,6 @@ class CausalLMBatch(Batch): stopping_criterias.append( StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) - 
all_logprobs.append(None) pad_to_multiple_of = 8 if device.type == "cuda" else None tokenized_inputs = tokenizer( @@ -89,7 +86,6 @@ class CausalLMBatch(Batch): position_ids=position_ids, past_key_values=None, all_input_ids=all_input_ids, - all_logprobs=all_logprobs, input_lengths=input_lengths, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, @@ -107,7 +103,6 @@ class CausalLMBatch(Batch): requests = [] input_lengths = [] all_input_ids = [] - all_logprobs = [] next_token_choosers = [] stopping_criterias = [] @@ -124,7 +119,6 @@ class CausalLMBatch(Batch): requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) all_input_ids.extend(batch.all_input_ids) - all_logprobs.extend(batch.all_logprobs) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -225,7 +219,6 @@ class CausalLMBatch(Batch): position_ids=position_ids, past_key_values=past_key_values, all_input_ids=all_input_ids, - all_logprobs=all_logprobs, input_lengths=input_lengths, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, @@ -234,6 +227,9 @@ class CausalLMBatch(Batch): keys_head_dim_last=batches[0].keys_head_dim_last, ) + def __len__(self): + return len(self.requests) + class CausalLM(Model): def __init__(self, model_name: str, quantize=False): @@ -289,7 +285,7 @@ class CausalLM(Model): def generate_token( self, batch: CausalLMBatch - ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]: + ) -> Tuple[List[Generation], Optional[CausalLMBatch]]: # For some reason, inference_mode does not work well with GLOO which we use on CPU context_manager = ( torch.no_grad if self.device.type == "cpu" else torch.inference_mode @@ -309,14 +305,13 @@ class CausalLM(Model): next_batch_input_lengths = [] next_batch_input_ids = [] next_batch_all_input_ids = [] - next_batch_all_logprobs = [] # Metadata next_batch_size = 0 next_batch_max_sequence_length = 0 - # Finished requests - generated_texts: List[GeneratedText] = [] + # Results + generations: List[Generation] = [] # Zipped iterator iterator = zip( @@ -326,7 +321,6 @@ class CausalLM(Model): batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, - batch.all_logprobs, ) # For each member of the batch @@ -337,44 +331,36 @@ class CausalLM(Model): next_token_chooser, stopping_criteria, all_input_ids, - all_logprobs, ) in enumerate(iterator): # Select next token tokens, logprobs = next_token_chooser(all_input_ids, logits) - next_token = tokens[-1].view(1, 1) + next_token_id = tokens[-1].view(1, 1) # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token]) + all_input_ids = torch.cat([all_input_ids, next_token_id]) new_input_length = input_length + 1 - if all_logprobs is None: - # logprobs of all prompt tokens (except the first one) and the generated token - all_logprobs = logprobs.gather(1, all_input_ids[1:]) - else: - # logprob of the generated token - next_token_logprob = logprobs[-1, next_token] - all_logprobs = torch.cat([all_logprobs, next_token_logprob]) + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text = self.tokenizer.decode( + next_token_id_squeezed, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) # Evaluate stopping criteria stop, reason = stopping_criteria( - next_token.squeeze(), - self.tokenizer.decode( - next_token.squeeze(), clean_up_tokenization_spaces=False - ), + next_token_id_squeezed, 
+ next_token_text, ) + if stop: # Decode generated tokens generated_text = self.decode( all_input_ids[-stopping_criteria.current_tokens :, 0] ) output_text = request.inputs + generated_text - # Slice with input_length to remove padding - token_ids = all_input_ids[-new_input_length:] - tokens = self.tokenizer.batch_decode(token_ids) - # Add NaN for the first prompt token - logprobs = [float("nan")] + all_logprobs[-input_length:].squeeze( - 1 - ).tolist() # Get seed if isinstance(next_token_chooser.choice, Sampling): @@ -382,39 +368,58 @@ class CausalLM(Model): else: seed = None - # Add to the list of finished generations with the original request - generated_texts.append( - GeneratedText( - request=request, - output_text=output_text, - generated_tokens=stopping_criteria.current_tokens, - tokens=tokens, - token_ids=token_ids.squeeze(1).tolist(), - logprobs=logprobs, - reason=reason, - seed=seed, - ) + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed ) - # add to the next batch else: + # Keep request in the batch + generated_text = None next_batch_keep_indices.append(i) - next_batch_input_ids.append(next_token) + next_batch_input_ids.append(next_token_id) next_batch_all_input_ids.append(all_input_ids) - next_batch_all_logprobs.append(all_logprobs) next_batch_size += 1 next_batch_input_lengths.append(new_input_length) next_batch_max_sequence_length = max( next_batch_max_sequence_length, new_input_length ) + # Prefill + if stopping_criteria.current_tokens == 1: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + logprobs.gather( + 1, all_input_ids[1:] + ).squeeze(1)[-new_input_length:-1].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = PrefillTokens( + prefill_token_ids, prefill_logprobs, prefill_texts + ) + else: + prefill_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + next_token_id_squeezed, + next_token_logprob, + next_token_text, + generated_text, + ) + + generations.append(generation) + # We finished all generations in the batch; there is no next batch if not next_batch_keep_indices: - return generated_texts, None + return generations, None next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0) # If we finished at least one generation, we need to evict the indices of the generations that finished # from the values of the next batch - if generated_texts: + if len(next_batch_keep_indices) != len(batch): # Apply indices to attention mask, past key values and other items that need to be cached next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] next_batch_position_ids = batch.position_ids[next_batch_keep_indices] @@ -461,7 +466,6 @@ class CausalLM(Model): position_ids=next_batch_position_ids, past_key_values=next_batch_past_key_values, all_input_ids=next_batch_all_input_ids, - all_logprobs=next_batch_all_logprobs, input_lengths=next_batch_input_lengths, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, @@ -469,4 +473,4 @@ class CausalLM(Model): max_sequence_length=next_batch_max_sequence_length, keys_head_dim_last=batch.keys_head_dim_last, ) - return generated_texts, next_batch + return generations, next_batch diff --git a/server/text_generation/models/seq2seq_lm.py 
b/server/text_generation/models/seq2seq_lm.py index 26ebc7d7a..6d5dc22e1 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokeniz from typing import Optional, Tuple, List, Type from text_generation.models import Model -from text_generation.models.types import GeneratedText, Batch +from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens from text_generation.pb import generate_pb2 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -30,7 +30,6 @@ class Seq2SeqLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] decoder_input_lengths: List[int] - decoder_logprobs: List[Optional[torch.Tensor]] # Generation helpers next_token_choosers: List[NextTokenChooser] @@ -64,7 +63,6 @@ class Seq2SeqLMBatch(Batch): decoder_input_ids = [] decoder_input_lengths = [] - decoder_logprobs = [] # Parse batch for r in pb.requests: @@ -77,7 +75,6 @@ class Seq2SeqLMBatch(Batch): stopping_criterias.append( StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) - decoder_logprobs.append(None) # Tokenize batch pad_to_multiple_of = 8 if device.type == "cuda" else None @@ -102,7 +99,6 @@ class Seq2SeqLMBatch(Batch): past_key_values=None, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, - decoder_logprobs=decoder_logprobs, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=len(pb.requests), @@ -125,7 +121,6 @@ class Seq2SeqLMBatch(Batch): requests = [] input_lengths = [] decoder_input_lengths = [] - decoder_logprobs = [] next_token_choosers = [] stopping_criterias = [] @@ -146,7 +141,6 @@ class Seq2SeqLMBatch(Batch): requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths) - decoder_logprobs.extend(batch.decoder_logprobs) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -283,7 +277,6 @@ class Seq2SeqLMBatch(Batch): past_key_values=past_key_values, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, - decoder_logprobs=decoder_logprobs, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=total_batch_size, @@ -291,6 +284,9 @@ class Seq2SeqLMBatch(Batch): max_decoder_input_length=max_decoder_input_length, ) + def __len__(self): + return len(self.requests) + class Seq2SeqLM(Model): def __init__(self, model_name: str, quantize=False): @@ -364,7 +360,7 @@ class Seq2SeqLM(Model): def generate_token( self, batch: Seq2SeqLMBatch - ) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]: + ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]: # For some reason, inference_mode does not work well with GLOO which we use on CPU context_manager = ( torch.no_grad if self.device.type == "cpu" else torch.inference_mode @@ -386,7 +382,6 @@ class Seq2SeqLM(Model): next_batch_input_lengths = [] next_batch_decoder_input_ids = [] next_batch_decoder_input_lengths = [] - next_batch_decoder_logprobs = [] # Metadata next_batch_size = 0 @@ -394,14 +389,13 @@ class Seq2SeqLM(Model): next_batch_max_decoder_input_length = 0 # Finished requests - generated_texts: List[GeneratedText] = [] + generations: List[Generation] = [] # Zipped iterator iterator = zip( batch.requests, batch.input_lengths, batch.decoder_input_lengths, - batch.decoder_logprobs, logits, 
batch.next_token_choosers, batch.stopping_criterias, @@ -414,7 +408,6 @@ class Seq2SeqLM(Model): request, input_length, decoder_input_length, - decoder_logprobs, logits, next_token_chooser, stopping_criteria, @@ -422,35 +415,28 @@ class Seq2SeqLM(Model): decoder_input_ids, ) in enumerate(iterator): # Select next token - next_token, logprobs = next_token_chooser(decoder_input_ids, logits) + next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits) # Append next token to decoder tokens - decoder_input_ids = torch.cat([decoder_input_ids, next_token]) + decoder_input_ids = torch.cat([decoder_input_ids, next_token_id]) new_decoder_input_length = decoder_input_length + 1 - next_token_logprob = logprobs[-1, next_token] - if decoder_logprobs is None: - decoder_logprobs = next_token_logprob - else: - decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob]) + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text = self.tokenizer.decode( + next_token_id_squeezed, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) # Evaluate stopping criteria - stop, reason = stopping_criteria( - next_token.squeeze(), - self.tokenizer.decode( - next_token.squeeze(), clean_up_tokenization_spaces=False - ), - ) + stop, reason = stopping_criteria(next_token_id, next_token_text) + if stop: # Slice with decoder_input_length to remove padding # Decode all tokens - token_ids = decoder_input_ids[-new_decoder_input_length:] - output_text = self.decode(token_ids) - tokens = self.tokenizer.batch_decode(token_ids) - # Add NaN for the bos token - logprobs = [float("nan")] + decoder_logprobs[ - -decoder_input_length: - ].tolist() + output_text = self.decode(decoder_input_ids[-new_decoder_input_length:]) # Get seed if isinstance(next_token_chooser.choice, Sampling): @@ -458,27 +444,17 @@ class Seq2SeqLM(Model): else: seed = None - # Add to the list of finished generations with the original request - generated_texts.append( - GeneratedText( - request=request, - output_text=output_text, - generated_tokens=stopping_criteria.current_tokens, - tokens=tokens, - token_ids=token_ids.tolist(), - logprobs=logprobs, - reason=reason, - seed=seed, - ) + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed ) - # add to the next batch else: + # Keep request in the batch + generated_text = None next_batch_keep_indices.append(i) next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0)) next_batch_size += 1 next_batch_input_lengths.append(input_length) next_batch_decoder_input_lengths.append(new_decoder_input_length) - next_batch_decoder_logprobs.append(decoder_logprobs) next_batch_max_input_length = max( next_batch_max_input_length, input_length ) @@ -486,14 +462,39 @@ class Seq2SeqLM(Model): next_batch_max_decoder_input_length, new_decoder_input_length ) + # Prefill + if stopping_criteria.current_tokens == 1: + prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = PrefillTokens( + prefill_token_ids, [float("nan")], prefill_texts + ) + else: + prefill_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + next_token_id_squeezed, + next_token_logprob, + next_token_text, + generated_text, + ) + + generations.append(generation) + # We finished all generations in the batch; there is no 
next batch if not next_batch_keep_indices: - return generated_texts, None + return generations, None next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids) # If we finished at least one generation, we need to evict the indices of the generations that finished # from the values of the next batch - if generated_texts: + if len(next_batch_keep_indices) != len(batch): # Apply indices to attention mask, past key values and other items that need to be cached next_batch_input_ids = batch.input_ids[next_batch_keep_indices] next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] @@ -551,11 +552,10 @@ class Seq2SeqLM(Model): past_key_values=next_batch_past_key_values, input_lengths=next_batch_input_lengths, decoder_input_lengths=next_batch_decoder_input_lengths, - decoder_logprobs=next_batch_decoder_logprobs, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, size=next_batch_size, max_input_length=next_batch_max_input_length, max_decoder_input_length=next_batch_max_decoder_input_length, ) - return generated_texts, next_batch + return generations, next_batch diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index 4ee3cb327..30cd716af 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -29,26 +29,61 @@ class Batch(ABC): def concatenate(cls, batches: List["Batch"]) -> "Batch": raise NotImplementedError + @abstractmethod + def __len__(self): + raise NotImplementedError + @dataclass class GeneratedText: - request: generate_pb2.Request - output_text: str + text: str generated_tokens: int - tokens: List[str] - token_ids: List[int] - logprobs: List[float] - reason: str + finish_reason: str seed: Optional[int] def to_pb(self) -> generate_pb2.GeneratedText: return generate_pb2.GeneratedText( - request=self.request, - output_text=self.output_text, + text=self.text, generated_tokens=self.generated_tokens, - tokens=self.tokens, - token_ids=self.token_ids, - logprobs=self.logprobs, - finish_reason=self.reason, + finish_reason=self.finish_reason, seed=self.seed, ) + + +@dataclass +class PrefillTokens: + token_ids: List[int] + logprobs: List[float] + texts: List[str] + + def to_pb(self) -> generate_pb2.PrefillTokens: + return generate_pb2.PrefillTokens( + ids=self.token_ids, logprobs=self.logprobs, texts=self.texts + ) + + def __len__(self): + return len(self.token_ids) + + +@dataclass +class Generation: + request_id: int + prefill_tokens: Optional[PrefillTokens] + token_id: int + token_logprob: float + token_text: str + generated_text: Optional[GeneratedText] + + def to_pb(self) -> generate_pb2.Generation: + return generate_pb2.Generation( + request_id=self.request_id, + prefill_tokens=self.prefill_tokens.to_pb() + if self.prefill_tokens is not None + else None, + token_id=self.token_id, + token_logprob=self.token_logprob, + token_text=self.token_text, + generated_text=self.generated_text.to_pb() + if self.generated_text is not None + else None, + ) diff --git a/server/text_generation/server.py b/server/text_generation/server.py index 5fd3072e8..a2bad8a73 100644 --- a/server/text_generation/server.py +++ b/server/text_generation/server.py @@ -27,22 +27,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): self.cache.clear() return generate_pb2.ClearCacheResponse() - async def Generate(self, request, context): + async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( request.batch, 
self.model.tokenizer, self.model.device ) - generated_texts, next_batch = self.model.generate_token(batch) + generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) - return generate_pb2.GenerateResponse( - generated_texts=[ - generated_text.to_pb() for generated_text in generated_texts - ], + return generate_pb2.PrefillResponse( + generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, ) - async def GenerateWithCache(self, request, context): + async def Decode(self, request, context): if len(request.batches) == 0: raise ValueError("Must provide at least one batch") @@ -58,13 +56,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): else: batch = batches[0] - generated_texts, next_batch = self.model.generate_token(batch) + generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) - return generate_pb2.GenerateWithCacheResponse( - generated_texts=[ - generated_text.to_pb() for generated_text in generated_texts - ], + return generate_pb2.DecodeResponse( + generations=[generation.to_pb() for generation in generations], batch=next_batch.to_pb() if next_batch else None, )
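
The three dataclasses added in server/text_generation/models/types.py carry everything a decoding step can report: prefill tokens on the first step only, exactly one generated token per step, and a GeneratedText payload only once a request has finished. The sketch below is not part of the patch; it builds both variants by hand and serializes them with to_pb(). It assumes the text_generation package from this revision and its compiled protobuf stubs are importable, and every field value is purely illustrative.

from text_generation.models.types import GeneratedText, PrefillTokens, Generation

# First decoding step for a request: prefill tokens are attached, but the
# request has not met its stopping criteria yet, so generated_text stays None.
in_flight = Generation(
    request_id=0,
    prefill_tokens=PrefillTokens(
        token_ids=[10264], logprobs=[float("nan")], texts=["Test"]
    ),
    token_id=13,
    token_logprob=-1.2,
    token_text=".",
    generated_text=None,
)

# Final step for the same request: no prefill tokens, and the completed text
# plus finish metadata are carried in GeneratedText.
finished = Generation(
    request_id=0,
    prefill_tokens=None,
    token_id=13,
    token_logprob=-0.8,
    token_text=".",
    generated_text=GeneratedText(
        text="Test.java:784) at net.minecraft.",
        generated_tokens=10,
        finish_reason="length",
        seed=None,
    ),
)

# Both variants serialize to the Generation protobuf message exchanged by the
# Prefill/Decode RPCs.
print(in_flight.to_pb())
print(finished.to_pb())

Keeping generated_text optional on Generation is what lets a consumer stream intermediate tokens and still learn, from the same message, when a request has finished and why.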
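Every updated test follows the same contract for generate_token: it now returns one Generation per request in the batch it was given, never an empty list, and a request that met its stopping criteria appears one last time with generated_text set before being dropped from the returned batch. A minimal driver loop built on that contract could look like the following; model and batch are assumed to be constructed the way the test fixtures do (for example a CausalLM with a CausalLMBatch), and run_to_completion is a hypothetical helper name, not code from the patch.

def run_to_completion(model, batch):
    """Run generate_token until every request in the batch has finished."""
    completed = {}
    while batch is not None:
        generations, batch = model.generate_token(batch)
        # One Generation per input request; the surviving batch can only shrink.
        assert batch is None or len(generations) >= len(batch)
        for generation in generations:
            if generation.generated_text is not None:
                # Reported exactly once, on the step the stopping criteria fired.
                completed[generation.request_id] = generation.generated_text.text
    return completed

# Usage against the test fixtures, for example:
# completed = run_to_completion(default_causal_lm, default_causal_lm_batch)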
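On the wire, the renamed RPC pair replaces the previous single-shot flow: Prefill runs a fresh batch through the model and returns the first Generation per request plus a cached batch, and Decode keeps advancing cached batches one token at a time until no batch is returned. The client sketch below is an assumption-heavy illustration rather than code from the patch: the channel target, the synchronous stub usage, the generated class name TextGenerationServiceStub, and the pre-built Batch protobuf are all assumptions; the RPC names and the request and response fields follow the messages used by server.py above.

import grpc
from text_generation.pb import generate_pb2, generate_pb2_grpc

def generate(channel_target, batch):
    # `batch` is an already-built generate_pb2.Batch; constructing one is out
    # of scope here since the Request fields are not shown in this hunk.
    generations = []
    with grpc.insecure_channel(channel_target) as channel:
        stub = generate_pb2_grpc.TextGenerationServiceStub(channel)
        # Prefill: run the prompts and decode the first token for every request.
        response = stub.Prefill(generate_pb2.PrefillRequest(batch=batch))
        generations.extend(response.generations)
        # Decode: keep generating from the cached batch until it is exhausted.
        while response.HasField("batch"):
            response = stub.Decode(
                generate_pb2.DecodeRequest(batches=[response.batch])
            )
            generations.extend(response.generations)
    return generations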