From d917ae895517668b73de2aa20b96f8a2cb4066f8 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 27 Jan 2023 19:46:58 +0100 Subject: [PATCH] wip --- Cargo.lock | 2 + proto/generate.proto | 66 +++++---- router/Cargo.toml | 2 + router/client/src/client.rs | 22 +-- router/client/src/lib.rs | 3 +- router/client/src/sharded_client.rs | 12 +- router/src/batcher.rs | 161 +++++++++++++++------ router/src/db.rs | 9 +- router/src/lib.rs | 7 +- router/src/server.rs | 77 +++++++--- server/Makefile | 2 +- server/text_generation/models/causal_lm.py | 152 +++++++++---------- server/text_generation/models/types.py | 60 +++++--- server/text_generation/server.py | 20 +-- 14 files changed, 370 insertions(+), 225 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 33f5d181..4c44820b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1829,6 +1829,7 @@ dependencies = [ name = "text-generation-router" version = "0.1.0" dependencies = [ + "async-stream", "axum", "clap 4.0.22", "futures", @@ -1840,6 +1841,7 @@ dependencies = [ "thiserror", "tokenizers", "tokio", + "tokio-stream", "tracing", "tracing-subscriber", ] diff --git a/proto/generate.proto b/proto/generate.proto index 921bd5c0..32ec9681 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -7,10 +7,10 @@ service TextGenerationService { rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} /// Empties batch cache rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); - /// Generate tokens for a batch - rpc Generate (GenerateRequest) returns (GenerateResponse); - /// Generate tokens for a list of cached batches - rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse); + /// Prefill batch and decode first token + rpc Prefill (PrefillRequest) returns (PrefillResponse); + /// Decode token for a list of prefilled batches + rpc Decode (DecodeRequest) returns (DecodeResponse); } /// Empty request @@ -70,44 +70,60 @@ message Batch { } message GeneratedText { - /// Request - Request request = 1; /// Output - string output_text = 2; + string text = 1; /// Number of generated tokens - uint32 generated_tokens = 3; - /// Tokens - repeated string tokens = 4; - /// Token IDs - repeated uint32 token_ids = 5; - /// Logprobs - repeated float logprobs = 6; + uint32 generated_tokens = 2; /// Finish reason - string finish_reason = 7; + string finish_reason = 3; /// Seed - optional uint64 seed = 8; + optional uint64 seed = 4; } -message GenerateRequest { +message PrefillTokens { + /// Prefill Token IDs + repeated uint32 ids = 1; + /// Prefill Logprobs + repeated float logprobs = 2; + /// Prefill tokens + repeated string texts = 3; +} + +message Generation { + /// Request ID + uint64 request_id = 1; + /// Prefill tokens (optional) + PrefillTokens prefill_tokens = 2; + /// Token ID + uint32 token_id = 3; + /// Logprob + float token_logprob = 4; + /// Text + string token_text = 5; + /// Complete generated text + GeneratedText generated_text = 6; +} + +message PrefillRequest { /// Batch Batch batch = 1; } -message GenerateResponse { - /// Finished requests - repeated GeneratedText generated_texts = 1; +message PrefillResponse { + /// Generation + repeated Generation generations = 1; /// Next batch (cached) optional Batch batch = 2; } -message GenerateWithCacheRequest { +message DecodeRequest { /// Cached batches repeated Batch batches = 1; } -message GenerateWithCacheResponse { - /// Finished requests - repeated GeneratedText generated_texts = 1; 
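Note: the new PrefillTokens and Generation messages above replace the old request-carrying GeneratedText. A minimal sketch of how they look once compiled, assuming the stubs that `make gen-server` generates into text_generation/pb further down in this patch; all field values are purely illustrative.

    # Sketch only: builds the messages defined in the proto above.
    from text_generation.pb import generate_pb2

    prefill = generate_pb2.PrefillTokens(
        ids=[15496, 995],                 # illustrative prompt token ids
        logprobs=[float("nan"), -1.73],   # the first prompt token has no logprob
        texts=["Hello", " world"],
    )

    generation = generate_pb2.Generation(
        request_id=0,
        prefill_tokens=prefill,           # only populated on the prefill step
        token_id=0,
        token_logprob=-0.42,
        token_text="!",
        # generated_text is left unset while the request is still decoding
    )

    assert not generation.HasField("generated_text")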
+message DecodeResponse { + /// Decodes + repeated Generation generations = 1; /// Next batch (cached) optional Batch batch = 2; -} +} \ No newline at end of file diff --git a/router/Cargo.toml b/router/Cargo.toml index 546f127f..d30d3b48 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -13,6 +13,7 @@ name = "text-generation-router" path = "src/main.rs" [dependencies] +async-stream = "0.3.3" axum = { version = "0.5.16", features = ["json", "serde_json"] } text-generation-client = { path = "client" } clap = { version = "4.0.15", features = ["derive", "env"] } @@ -24,6 +25,7 @@ serde_json = "1.0.85" thiserror = "1.0.37" tokenizers = "0.13.0" tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.11" tracing = "0.1.36" tracing-subscriber = { version = "0.3.15", features = ["json"] } diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 172d0bf7..01bf1f17 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -73,15 +73,15 @@ impl Client { /// Returns a list of generated texts of request that met their stopping criteria /// and the next cached batch #[instrument(skip(self))] - pub async fn generate(&mut self, batch: Batch) -> Result<(Vec, Option)> { - let request = tonic::Request::new(GenerateRequest { batch: Some(batch) }); + pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec, Option)> { + let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }); let response = self .stub - .generate(request) - .instrument(info_span!("generate")) + .prefill(request) + .instrument(info_span!("prefill")) .await? .into_inner(); - Ok((response.generated_texts, response.batch)) + Ok((response.generations, response.batch)) } /// Generate one token for each request in the given cached batch @@ -89,17 +89,17 @@ impl Client { /// Returns a list of generated texts of request that met their stopping criteria /// and the next cached batch #[instrument(skip(self))] - pub async fn generate_with_cache( + pub async fn decode( &mut self, batches: Vec, - ) -> Result<(Vec, Option)> { - let request = tonic::Request::new(GenerateWithCacheRequest { batches }); + ) -> Result<(Vec, Option)> { + let request = tonic::Request::new(DecodeRequest { batches }); let response = self .stub - .generate_with_cache(request) - .instrument(info_span!("generate_with_cache")) + .decode(request) + .instrument(info_span!("decode")) .await? 
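Note: the prefill-then-decode handshake the Rust Client implements above can also be driven from Python. A hedged sketch, not the project's client: the stub class name follows the standard grpcio naming for TextGenerationService, and the channel target and batch contents are placeholders.

    import grpc
    from text_generation.pb import generate_pb2, generate_pb2_grpc

    def run_batch(target: str, batch: generate_pb2.Batch):
        with grpc.insecure_channel(target) as channel:
            stub = generate_pb2_grpc.TextGenerationServiceStub(channel)

            # Prefill runs the prompts and emits the first token of every request
            response = stub.Prefill(generate_pb2.PrefillRequest(batch=batch))
            yield from response.generations
            cached = response.batch if response.HasField("batch") else None

            # Decode emits one more token per call for the requests still running
            while cached is not None:
                response = stub.Decode(generate_pb2.DecodeRequest(batches=[cached]))
                yield from response.generations
                cached = response.batch if response.HasField("batch") else None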
.into_inner(); - Ok((response.generated_texts, response.batch)) + Ok((response.generations, response.batch)) } } diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 295b009b..ec90103b 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -7,7 +7,8 @@ mod sharded_client; pub use client::Client; pub use pb::generate::v1::{ - Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, + Batch, GeneratedText, Generation, NextTokenChooserParameters, Request, + StoppingCriteriaParameters, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 6c70afca..015d431d 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,6 +1,6 @@ /// Multi shard Client use crate::Result; -use crate::{Batch, Client, GeneratedText}; +use crate::{Batch, Client, Generation}; use futures::future::join_all; use futures::future::select_all; use tonic::transport::Uri; @@ -41,11 +41,11 @@ impl ShardedClient { /// /// Returns a list of generated texts of request that met their stopping criteria /// and the next cached batch - pub async fn generate(&mut self, batch: Batch) -> Result<(Vec, Option)> { + pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec, Option)> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.generate(batch.clone()))) + .map(|client| Box::pin(client.prefill(batch.clone()))) .collect(); // As soon as we receive one response, we can return as all shards will return the same let (result, _, _) = select_all(futures).await; @@ -56,14 +56,14 @@ impl ShardedClient { /// /// Returns a list of generated texts of request that met their stopping criteria /// and the next cached batch - pub async fn generate_with_cache( + pub async fn decode( &mut self, batches: Vec, - ) -> Result<(Vec, Option)> { + ) -> Result<(Vec, Option)> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| Box::pin(client.generate_with_cache(batches.clone()))) + .map(|client| Box::pin(client.decode(batches.clone()))) .collect(); // As soon as we receive one response, we can return as all shards will return the same let (result, _, _) = select_all(futures).await; diff --git a/router/src/batcher.rs b/router/src/batcher.rs index baf58af4..d85efedf 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -1,15 +1,17 @@ /// Batching and inference logic -use crate::{Db, Entry}; +use crate::{Db, Entry, Token}; use crate::{ErrorResponse, GenerateRequest}; use axum::http::StatusCode; use axum::Json; use nohash_hasher::IntMap; use std::future::Future; use std::sync::Arc; -use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient}; +use text_generation_client::{Batch, ClientError, GeneratedText, Generation, ShardedClient}; use thiserror::Error; -use tokio::sync::{oneshot, Notify}; +use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::StreamExt; use tracing::instrument; /// Batcher @@ -51,14 +53,14 @@ impl Batcher { Self { db, shared } } - /// Add a new request to the database and return a future that will generate the text - pub(crate) async fn infer( + /// Add a new request to the database and return a stream of tokens + pub(crate) fn infer_stream( &self, input_length: usize, request: GenerateRequest, - ) -> Result { - // One shot channel to communicate with the background batching task - let 
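Note: the ShardedClient change above keeps the existing fan-out pattern — the same prefill/decode call goes to every shard and the first reply wins, since all shards return identical results. An analogous sketch in Python/asyncio (illustrative names, not part of the codebase):

    import asyncio

    async def prefill_first_shard(shard_clients, batch):
        # Fan the same call out to every shard...
        tasks = [asyncio.create_task(client.prefill(batch)) for client in shard_clients]
        # ...and take whichever answers first, like futures::future::select_all.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()          # mirror Rust dropping the remaining futures
        return done.pop().result()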
(response_tx, response_rx) = oneshot::channel(); + ) -> UnboundedReceiverStream> { + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); // Try to append the request to the database self.db.append(Entry { @@ -73,12 +75,45 @@ impl Batcher { // to be batched self.shared.batching_task.notify_one(); - // Await on the response from the background task - // We can safely unwrap as the background task will never drop the sender - response_rx - .await - .unwrap() - .map_err(|err| InferError::GenerationError(err.to_string())) + // Return stream + UnboundedReceiverStream::new(response_rx) + } + + pub(crate) async fn infer( + &self, + input_length: usize, + request: GenerateRequest, + ) -> Result { + let mut stream = self.infer_stream(input_length, request); + + let mut result_tokens = Vec::new(); + let mut result_generated_text = None; + let mut result_start = None; + let mut result_queued = None; + + while let Some(response) = stream.next().await { + match response? { + InferStreamResponse::Prefill(prefill_tokens) => { + result_tokens.extend(prefill_tokens) + } + InferStreamResponse::Token(token) => result_tokens.push(token), + InferStreamResponse::End { + generated_text, + start, + queued, + } => { + result_generated_text = Some(generated_text); + result_start = Some(start); + result_queued = Some(queued) + } + } + } + Ok(InferResponse { + tokens: result_tokens, + generated_text: result_generated_text.unwrap(), + queued: result_queued.unwrap(), + start: result_start.unwrap(), + }) } } @@ -106,7 +141,7 @@ async fn batching_task( // This batch might be smaller than the maximum batch size if there are not enough requests // waiting in the DB while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) { - let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await; + let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await; let mut waiting_tokens = 1; // We loop until we do not receive any cached batch from the inference server (== until @@ -132,7 +167,7 @@ async fn batching_task( { // Generate one token for this new batch to have the attention past in cache let new_cached_batch = - wrap_future(client.generate(new_batch), &mut new_entries).await; + wrap_future(client.prefill(new_batch), &mut new_entries).await; // Reset waiting counter waiting_tokens = 1; // Extend current batch with the new batch @@ -143,7 +178,7 @@ async fn batching_task( } } - cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await; + cached_batch = wrap_future(client.decode(batches), &mut entries).await; waiting_tokens += 1; } } @@ -152,12 +187,12 @@ async fn batching_task( /// Wrap a future inside a match statement to handle errors and send the response to the Batcher async fn wrap_future( - future: impl Future, Option), ClientError>>, + future: impl Future, Option), ClientError>>, entries: &mut IntMap, ) -> Option { match future.await { - Ok((generated_texts, next_batch)) => { - send_generated(generated_texts, entries); + Ok((generations, next_batch)) => { + send_generated(generations, entries); next_batch } // If we have an error, we discard the whole batch @@ -172,47 +207,79 @@ async fn wrap_future( fn send_error(error: ClientError, entries: &mut IntMap) { entries.drain().for_each(|(_, entry)| { // unwrap_or is valid here as we don't care if the receiver is gone. 
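Note: the new non-streaming `infer` above is just a fold over `infer_stream`: prefill tokens and per-step tokens are accumulated, and the End event supplies the final text and timestamps. A rough Python rendering of that fold; the (kind, payload) event shapes are illustrative stand-ins for the Rust enum.

    async def collect(stream):
        """Fold a stream of (kind, payload) events into one response dict."""
        tokens, generated_text, start, queued = [], None, None, None
        async for kind, payload in stream:
            if kind == "prefill":        # InferStreamResponse::Prefill(Vec<Token>)
                tokens.extend(payload)
            elif kind == "token":        # InferStreamResponse::Token(Token)
                tokens.append(payload)
            elif kind == "end":          # InferStreamResponse::End { .. }
                generated_text, start, queued = payload
        return {"tokens": tokens, "generated_text": generated_text,
                "start": start, "queued": queued}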
-        entry.response_tx.send(Err(error.clone())).unwrap_or(());
+        entry
+            .response_tx
+            .send(Err(InferError::GenerationError(error.to_string())))
+            .unwrap_or(());
     });
 }
 
 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
-    finished.into_iter().for_each(|output| {
-        // We can `expect` here as the request id should always be in the entries
+fn send_generated(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+    generations.into_iter().for_each(|generation| {
         let entry = entries
-            .remove(&output.request.unwrap().id)
+            .get(&generation.request_id)
             .expect("ID not found in entries. This is a bug.");
-        let response = InferResponse {
-            output_text: output.output_text,
-            generated_tokens: output.generated_tokens,
-            token_ids: output.token_ids,
-            tokens: output.tokens,
-            logprobs: output.logprobs,
-            finish_reason: output.finish_reason,
-            seed: output.seed,
-            queued: entry.time,
-            start: entry.batch_time.unwrap(), // unwrap is always valid
-            end: Instant::now(),
-        };
-        // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry.response_tx.send(Ok(response)).unwrap_or(());
+        if let Some(prefill_tokens) = generation.prefill_tokens {
+            let tokens = prefill_tokens
+                .ids
+                .into_iter()
+                .zip(prefill_tokens.logprobs.into_iter())
+                .zip(prefill_tokens.texts.into_iter())
+                .map(|((id, logprob), text)| Token(id, text, logprob))
+                .collect();
+            entry
+                .response_tx
+                .send(Ok(InferStreamResponse::Prefill(tokens)))
+                .unwrap_or(());
+        }
+
+        let token = Token(
+            generation.token_id,
+            generation.token_text,
+            generation.token_logprob,
+        );
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Token(token)))
+            .unwrap_or(());
+
+        if let Some(generated_text) = generation.generated_text {
+            let entry = entries
+                .remove(&generation.request_id)
+                .expect("ID not found in entries.
This is a bug."); + + entry + .response_tx + .send(Ok(InferStreamResponse::End { + generated_text, + queued: entry.time, + start: entry.batch_time.unwrap(), + })) + .unwrap_or(()); + } }); } +#[derive(Debug)] +pub(crate) enum InferStreamResponse { + Prefill(Vec), + Token(Token), + End { + generated_text: GeneratedText, + start: Instant, + queued: Instant, + }, +} + #[derive(Debug)] pub(crate) struct InferResponse { - pub(crate) output_text: String, - pub(crate) generated_tokens: u32, - pub(crate) token_ids: Vec, - pub(crate) tokens: Vec, - pub(crate) logprobs: Vec, - pub(crate) finish_reason: String, - pub(crate) seed: Option, + pub(crate) tokens: Vec, + pub(crate) generated_text: GeneratedText, + pub(crate) seed: Option pub(crate) queued: Instant, pub(crate) start: Instant, - pub(crate) end: Instant, } #[derive(Debug, Error)] diff --git a/router/src/db.rs b/router/src/db.rs index 15007b64..1be63b6b 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -1,14 +1,15 @@ +use crate::batcher::InferError; /// This code is massively inspired by Tokio mini-redis -use crate::InferResponse; +use crate::batcher::InferStreamResponse; use crate::{GenerateParameters, GenerateRequest}; use nohash_hasher::{BuildNoHashHasher, IntMap}; use parking_lot::Mutex; use std::collections::BTreeMap; use std::sync::Arc; use text_generation_client::{ - Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, + Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; -use tokio::sync::oneshot::Sender; +use tokio::sync::mpsc::UnboundedSender; use tokio::time::Instant; /// Database entry @@ -17,7 +18,7 @@ pub(crate) struct Entry { /// Request pub request: GenerateRequest, /// Response sender to communicate between the Batcher and the batching_task - pub response_tx: Sender>, + pub response_tx: UnboundedSender>, /// Number of tokens in the input pub input_length: usize, /// Instant when this entry was created diff --git a/router/src/lib.rs b/router/src/lib.rs index 1aeac302..aab253b7 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -4,7 +4,7 @@ mod db; pub mod server; mod validation; -use batcher::{Batcher, InferResponse}; +use batcher::Batcher; use db::{Db, Entry}; use serde::{Deserialize, Serialize}; use validation::Validation; @@ -69,12 +69,15 @@ pub(crate) struct GenerateRequest { pub parameters: GenerateParameters, } +#[derive(Debug, Serialize)] +pub struct Token(u32, String, f32); + #[derive(Serialize)] pub(crate) struct Details { pub finish_reason: String, pub generated_tokens: u32, pub seed: Option, - pub tokens: Vec<(u32, String, f32)>, + pub tokens: Vec, } #[derive(Serialize)] diff --git a/router/src/server.rs b/router/src/server.rs index 86041b96..cda38fbe 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,11 +1,14 @@ +use crate::batcher::InferStreamResponse; use crate::{ Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, StatusCode}; +use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::IntoResponse; use axum::routing::{get, post}; -use axum::{Json, Router}; +use axum::{BoxError, Json, Router}; +use futures::Stream; use std::net::SocketAddr; use std::sync::Arc; use text_generation_client::ShardedClient; @@ -13,6 +16,7 @@ use tokenizers::Tokenizer; use tokio::signal; use tokio::sync::Semaphore; use tokio::time::Instant; +use tokio_stream::StreamExt; use tracing::instrument; // Server shared state @@ -111,21 +115,12 @@ 
async fn generate( // Token details let details = match details { - true => { - let tokens = response - .token_ids - .into_iter() - .zip(response.tokens.into_iter()) - .zip(response.logprobs.into_iter()) - .map(|((id, text), logprob)| (id, text, logprob)) - .collect(); - Some(Details { - seed: response.seed, - finish_reason: response.finish_reason, - generated_tokens: response.generated_tokens, - tokens, - }) - } + true => Some(Details { + finish_reason: response.generated_text.finish_reason, + generated_tokens: response.generated_text.generated_tokens, + tokens: response.tokens, + seed: response.seed, + }), false => None, }; @@ -133,8 +128,8 @@ async fn generate( let total_time = start_time.elapsed(); let validation_time = response.queued - start_time; let queue_time = response.start - response.queued; - let inference_time = response.end - response.start; - let time_per_token = inference_time / response.generated_tokens; + let inference_time = Instant::now() - response.start; + let time_per_token = inference_time / response.generated_text.generated_tokens; // Headers let mut headers = HeaderMap::new(); @@ -166,16 +161,57 @@ async fn generate( tracing::Span::current().record("inference_time", format!("{:?}", inference_time)); tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token)); tracing::Span::current().record("seed", format!("{:?}", response.seed)); - tracing::info!("Output: {}", response.output_text); + tracing::info!("Output: {}", response.generated_text.text); // Send response let response = vec![GeneratedText { - generated_text: response.output_text, + generated_text: response.generated_text.text, details, }]; Ok((headers, Json(response))) } +async fn generate_stream( + state: Extension, + req: Json, +) -> Sse>> { + let stream = async_stream::stream! 
{ + // Limit concurrent requests by acquiring a permit from the semaphore + let _permit = state.limit_concurrent_requests.try_acquire().map_err(| err | { + tracing::error!("Model is overloaded"); + err + })?; + + // Validate request + let (input_length, validated_request) = + state.validation.validate(req.0).await.map_err(|err| { + tracing::error!("{}", err); + err + })?; + + // Inference + let mut response_stream = state + .batcher + .infer_stream(input_length, validated_request); + + while let Some(response) = response_stream.next().await { + match response { + Ok(response) => { + if let InferStreamResponse::Token(token) = response { + yield Ok(Event::default().json_data(token).unwrap()); + } + } + Err(err) => { + tracing::error!("{}", err.to_string()); + yield Ok(Event::default().data(err.to_string())); + } + } + } + }; + + Sse::new(stream).keep_alive(KeepAlive::default()) +} + /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( @@ -201,6 +237,7 @@ pub async fn run( let app = Router::new() .route("/", post(generate)) .route("/generate", post(generate)) + .route("/generate_stream", post(generate_stream)) .route("/", get(health)) .route("/health", get(health)) .layer(Extension(shared_state.clone())); diff --git a/server/Makefile b/server/Makefile index 6961178b..82fff0db 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,6 +1,6 @@ gen-server: # Compile protos - pip install grpcio-tools==1.49.1 --no-cache-dir + #pip install grpcio-tools==1.49.1 --no-cache-dir mkdir text_generation/pb || true python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index ccd4c3ba..3f55d271 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize from typing import Optional, Tuple, List, Type from text_generation.models import Model -from text_generation.models.types import GeneratedText, Batch +from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText from text_generation.pb import generate_pb2 from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -23,7 +23,6 @@ class CausalLMBatch(Batch): # All tokens all_input_ids: List[torch.Tensor] - all_logprobs: List[Optional[torch.Tensor]] # Lengths of all generations present in the batch input_lengths: List[int] @@ -48,16 +47,15 @@ class CausalLMBatch(Batch): @classmethod def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - device: torch.device, + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, ) -> "CausalLMBatch": inputs = [] next_token_choosers = [] stopping_criterias = [] input_lengths = [] - all_logprobs = [] # Parse batch for r in pb.requests: @@ -67,7 +65,6 @@ class CausalLMBatch(Batch): stopping_criterias.append( StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) - all_logprobs.append(None) pad_to_multiple_of = 8 if device.type == "cuda" else None tokenized_inputs = tokenizer( @@ -89,7 +86,6 @@ class CausalLMBatch(Batch): position_ids=position_ids, past_key_values=None, all_input_ids=all_input_ids, - all_logprobs=all_logprobs, input_lengths=input_lengths, 
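Note: the `/generate_stream` route added in server.rs above emits Server-Sent Events; in this wip only `Token` items are forwarded (each as a JSON `[id, text, logprob]` array via `json_data`), and errors come back as plain-text `data:` payloads. A hedged client-side sketch with `requests` — the body fields follow the usual GenerateRequest shape, and `max_new_tokens` plus the address are assumptions, not part of this patch.

    import json
    import requests

    def stream_tokens(base_url: str, prompt: str):
        with requests.post(
            f"{base_url}/generate_stream",
            json={"inputs": prompt, "parameters": {"max_new_tokens": 20}},
            stream=True,
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines(decode_unicode=True):
                if not line or not line.startswith("data:"):
                    continue                      # skip keep-alives and blank separators
                data = line[len("data:"):].strip()
                try:
                    token_id, text, logprob = json.loads(data)
                except ValueError:
                    print("stream error:", data)  # errors arrive as plain text
                    break
                yield token_id, text, logprob

    # for _, text, _ in stream_tokens("http://127.0.0.1:3000", "Hello"):
    #     print(text, end="", flush=True)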
next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, @@ -107,7 +103,6 @@ class CausalLMBatch(Batch): requests = [] input_lengths = [] all_input_ids = [] - all_logprobs = [] next_token_choosers = [] stopping_criterias = [] @@ -124,7 +119,6 @@ class CausalLMBatch(Batch): requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) all_input_ids.extend(batch.all_input_ids) - all_logprobs.extend(batch.all_logprobs) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -151,8 +145,8 @@ class CausalLMBatch(Batch): # We need to slice the attention mask to remove padding from previous steps attention_mask[ - start_index:end_index, -batch.max_sequence_length : - ] = batch.attention_mask[:, -batch.max_sequence_length :] + start_index:end_index, -batch.max_sequence_length: + ] = batch.attention_mask[:, -batch.max_sequence_length:] # Create empty tensor # position_ids is always of shape [batch_size, 1] @@ -198,22 +192,22 @@ class CausalLMBatch(Batch): # We slice the past keys and values to remove the padding from previous batches if batch.keys_head_dim_last: past_key_values[j][0][ - start_index:end_index, - :, - -(batch.max_sequence_length - 1) :, - :, - ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :] + start_index:end_index, + :, + -(batch.max_sequence_length - 1):, + :, + ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :] else: past_key_values[j][0][ - start_index:end_index, - :, - :, - -(batch.max_sequence_length - 1) :, - ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :] + start_index:end_index, + :, + :, + -(batch.max_sequence_length - 1):, + ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):] past_key_values[j][1][ - start_index:end_index, :, -(batch.max_sequence_length - 1) :, : - ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :] + start_index:end_index, :, -(batch.max_sequence_length - 1):, : + ] = past_values[:, :, -(batch.max_sequence_length - 1):, :] start_index += batch.size @@ -225,7 +219,6 @@ class CausalLMBatch(Batch): position_ids=position_ids, past_key_values=past_key_values, all_input_ids=all_input_ids, - all_logprobs=all_logprobs, input_lengths=input_lengths, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, @@ -234,6 +227,9 @@ class CausalLMBatch(Batch): keys_head_dim_last=batches[0].keys_head_dim_last, ) + def __len__(self): + return len(self.requests) + class CausalLM(Model): def __init__(self, model_name: str, quantize=False): @@ -275,7 +271,7 @@ class CausalLM(Model): ) def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward outputs = self.model.forward( @@ -288,8 +284,8 @@ class CausalLM(Model): return outputs.logits, outputs.past_key_values def generate_token( - self, batch: CausalLMBatch - ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]: + self, batch: CausalLMBatch + ) -> Tuple[List[Generation], Optional[CausalLMBatch]]: # For some reason, inference_mode does not work well with GLOO which we use on CPU context_manager = ( torch.no_grad if self.device.type == "cpu" else torch.inference_mode @@ -309,14 +305,13 @@ class CausalLM(Model): next_batch_input_lengths = [] next_batch_input_ids = [] next_batch_all_input_ids = [] - next_batch_all_logprobs = [] # Metadata next_batch_size = 0 
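Note: the `concatenate` hunk above preserves the left-padding invariant when batches are merged: each batch's attention mask is copied into the rightmost `max_sequence_length` columns of the wider tensor (and the past keys/values into the rightmost `max_sequence_length - 1` positions). A minimal torch sketch of the mask copy, with illustrative shapes:

    import torch

    def merge_attention_masks(masks, max_sequence_lengths):
        total = sum(mask.shape[0] for mask in masks)
        width = max(max_sequence_lengths)
        merged = torch.zeros(total, width, dtype=masks[0].dtype)
        start = 0
        for mask, seq_len in zip(masks, max_sequence_lengths):
            end = start + mask.shape[0]
            # Only the last seq_len positions are real tokens; the rest stays padding
            merged[start:end, -seq_len:] = mask[:, -seq_len:]
            start = end
        return merged

    # merge_attention_masks([torch.ones(2, 5), torch.ones(3, 8)], [5, 8]).shape  # (5, 8)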
next_batch_max_sequence_length = 0 - # Finished requests - generated_texts: List[GeneratedText] = [] + # Results + results = [] # Zipped iterator iterator = zip( @@ -326,55 +321,42 @@ class CausalLM(Model): batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, - batch.all_logprobs, ) # For each member of the batch for i, ( - request, - input_length, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, - all_logprobs, + request, + input_length, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, ) in enumerate(iterator): # Select next token tokens, logprobs = next_token_chooser(all_input_ids, logits) - next_token = tokens[-1].view(1, 1) + next_token_id = tokens[-1].view(1, 1) # Append next token to all tokens - all_input_ids = torch.cat([all_input_ids, next_token]) + all_input_ids = torch.cat([all_input_ids, next_token_id]) new_input_length = input_length + 1 - if all_logprobs is None: - # logprobs of all prompt tokens (except the first one) and the generated token - all_logprobs = logprobs.gather(1, all_input_ids[1:]) - else: - # logprob of the generated token - next_token_logprob = logprobs[-1, next_token] - all_logprobs = torch.cat([all_logprobs, next_token_logprob]) + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text = self.decode(next_token_id.squeeze()) # Evaluate stopping criteria stop, reason = stopping_criteria( - next_token.squeeze(), - self.tokenizer.decode( - next_token.squeeze(), clean_up_tokenization_spaces=False - ), + next_token_id_squeezed, + next_token_text, ) + if stop: # Decode generated tokens generated_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens :, 0] + all_input_ids[-stopping_criteria.current_tokens:, 0] ) output_text = request.inputs + generated_text - # Slice with input_length to remove padding - token_ids = all_input_ids[-new_input_length:] - tokens = self.tokenizer.batch_decode(token_ids) - # Add NaN for the first prompt token - logprobs = [float("nan")] + all_logprobs[-input_length:].squeeze( - 1 - ).tolist() # Get seed if isinstance(next_token_chooser.choice, Sampling): @@ -382,39 +364,48 @@ class CausalLM(Model): else: seed = None - # Add to the list of finished generations with the original request - generated_texts.append( - GeneratedText( - request=request, - output_text=output_text, - generated_tokens=stopping_criteria.current_tokens, - tokens=tokens, - token_ids=token_ids.squeeze(1).tolist(), - logprobs=logprobs, - reason=reason, - seed=seed, - ) - ) - # add to the next batch + generated_text = GeneratedText(output_text, stopping_criteria.current_tokens, reason, seed) else: + # Keep request in the batch + generated_text = None next_batch_keep_indices.append(i) - next_batch_input_ids.append(next_token) + next_batch_input_ids.append(next_token_id) next_batch_all_input_ids.append(all_input_ids) - next_batch_all_logprobs.append(all_logprobs) next_batch_size += 1 next_batch_input_lengths.append(new_input_length) next_batch_max_sequence_length = max( next_batch_max_sequence_length, new_input_length ) + # Prefill + if stopping_criteria.current_tokens == 0: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + logprobs[-new_input_length:-1].gather(1, all_input_ids[ + -new_input_length:-1]).squeeze( + 1).tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode(prefill_token_ids, 
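Note: the prefill branch above reports a NaN placeholder for the first prompt token (which has no preceding context) and, for every other prompt token, the logprob the model assigned to it from the previous position's distribution. A self-contained sketch of that idea; the exact tensor slicing in the patch differs slightly because it also has to exclude the freshly generated token.

    import torch

    def prompt_token_logprobs(logits: torch.Tensor, input_ids: torch.Tensor):
        # logits: [prompt_len, vocab], input_ids: [prompt_len] (long)
        logprobs = torch.log_softmax(logits, dim=-1)
        # Position i predicts token i+1, so align logprobs[:-1] with input_ids[1:]
        gathered = logprobs[:-1].gather(1, input_ids[1:].unsqueeze(1)).squeeze(1)
        return [float("nan")] + gathered.tolist()

    # prompt_token_logprobs(torch.randn(4, 50257), torch.tensor([318, 257, 1332, 13]))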
+ clean_up_tokenization_spaces=False, + skip_special_tokens=False) + prefill_tokens = PrefillTokens(prefill_token_ids, + prefill_logprobs, + prefill_texts) + else: + prefill_tokens = None + + result = Generation(request.id, prefill_tokens, next_token_id_squeezed, next_token_logprob, next_token_text, + generated_text) + + results.append(result) + # We finished all generations in the batch; there is no next batch if not next_batch_keep_indices: - return generated_texts, None + return results, None next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0) # If we finished at least one generation, we need to evict the indices of the generations that finished # from the values of the next batch - if generated_texts: + if len(next_batch_keep_indices) != len(batch): # Apply indices to attention mask, past key values and other items that need to be cached next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] next_batch_position_ids = batch.position_ids[next_batch_keep_indices] @@ -461,7 +452,6 @@ class CausalLM(Model): position_ids=next_batch_position_ids, past_key_values=next_batch_past_key_values, all_input_ids=next_batch_all_input_ids, - all_logprobs=next_batch_all_logprobs, input_lengths=next_batch_input_lengths, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, @@ -469,4 +459,4 @@ class CausalLM(Model): max_sequence_length=next_batch_max_sequence_length, keys_head_dim_last=batch.keys_head_dim_last, ) - return generated_texts, next_batch + return results, next_batch diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index 4ee3cb32..0ad8cc87 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -17,10 +17,10 @@ class Batch(ABC): @classmethod @abstractmethod def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - device: torch.device, + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, ) -> "Batch": raise NotImplementedError @@ -32,23 +32,49 @@ class Batch(ABC): @dataclass class GeneratedText: - request: generate_pb2.Request - output_text: str + text: str generated_tokens: int - tokens: List[str] - token_ids: List[int] - logprobs: List[float] - reason: str + finish_reason: str seed: Optional[int] def to_pb(self) -> generate_pb2.GeneratedText: return generate_pb2.GeneratedText( - request=self.request, - output_text=self.output_text, + text=self.text, generated_tokens=self.generated_tokens, - tokens=self.tokens, - token_ids=self.token_ids, - logprobs=self.logprobs, - finish_reason=self.reason, - seed=self.seed, + finish_reason=self.finish_reason + seed=self.seed + ) + + +@dataclass +class PrefillTokens: + token_ids: List[int] + logprobs: List[float] + texts: List[str] + + def to_pb(self) -> generate_pb2.PrefillTokens: + return generate_pb2.PrefillTokens( + ids=self.token_ids, + logprobs=self.logprobs, + texts=self.texts + ) + + +@dataclass +class Generation: + request_id: int + prefill_tokens: Optional[PrefillTokens] + token_id: int + token_logprob: float + token_text: str + generated_text: Optional[GeneratedText] + + def to_pb(self) -> generate_pb2.Generation: + return generate_pb2.Generation( + request_id=self.request_id, + prefill_tokens=self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None, + token_id=self.token_id, + token_logprob=self.token_logprob, + token_text=self.token_text, + generated_text=self.generated_text.to_pb() if 
self.generated_text is not None else None, ) diff --git a/server/text_generation/server.py b/server/text_generation/server.py index 5fd3072e..1cf8de95 100644 --- a/server/text_generation/server.py +++ b/server/text_generation/server.py @@ -27,22 +27,22 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): self.cache.clear() return generate_pb2.ClearCacheResponse() - async def Generate(self, request, context): + async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.device ) - generated_texts, next_batch = self.model.generate_token(batch) + generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) - return generate_pb2.GenerateResponse( - generated_texts=[ - generated_text.to_pb() for generated_text in generated_texts + return generate_pb2.PrefillResponse( + generations=[ + generation.to_pb() for generation in generations ], batch=next_batch.to_pb() if next_batch else None, ) - async def GenerateWithCache(self, request, context): + async def Decode(self, request, context): if len(request.batches) == 0: raise ValueError("Must provide at least one batch") @@ -58,12 +58,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): else: batch = batches[0] - generated_texts, next_batch = self.model.generate_token(batch) + generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) - return generate_pb2.GenerateWithCacheResponse( - generated_texts=[ - generated_text.to_pb() for generated_text in generated_texts + return generate_pb2.DecodeResponse( + generations=[ + generation.to_pb() for generation in generations ], batch=next_batch.to_pb() if next_batch else None, )
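Note: a sketch of how the dataclasses added in types.py fit together the way `generate_token` builds them, before Prefill/Decode serialize them to the wire format; all values are illustrative.

    from text_generation.models.types import GeneratedText, PrefillTokens, Generation

    prefill = PrefillTokens(
        token_ids=[15496, 995],
        logprobs=[float("nan"), -1.73],
        texts=["Hello", " world"],
    )
    final = GeneratedText(text="Hello world!", generated_tokens=1,
                          finish_reason="length", seed=None)

    generation = Generation(
        request_id=0,
        prefill_tokens=prefill,    # only set on the prefill step
        token_id=0,
        token_logprob=-0.42,
        token_text="!",
        generated_text=final,      # only set once the stopping criteria trip
    )

    pb = generation.to_pb()        # a generate_pb2.Generation, ready for a PrefillResponse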