From b1ef80583c925ff7e60da4e7cb7d199d348962b1 Mon Sep 17 00:00:00 2001
From: Yannic Kilcher
Date: Thu, 26 Jan 2023 14:50:57 +0100
Subject: [PATCH] added streaming endpoint

---
 Cargo.lock                                 |  2 +
 proto/generate.proto                       | 11 +++
 router/Cargo.toml                          |  3 +-
 router/client/Cargo.toml                   |  1 +
 router/client/src/client.rs                |  8 +--
 router/client/src/lib.rs                   |  9 ++-
 router/client/src/sharded_client.rs        |  6 +-
 router/src/batcher.rs                      | 60 ++++++++++++++--
 router/src/db.rs                           | 12 +++-
 router/src/server.rs                       | 82 +++++++++++++++++++++-
 server/text_generation/models/causal_lm.py | 26 +++++--
 server/text_generation/models/types.py     | 16 +++++
 server/text_generation/server.py           |  6 +-
 13 files changed, 214 insertions(+), 28 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 752c4886..3ae40755 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1795,6 +1795,7 @@ version = "0.1.0"
 dependencies = [
  "futures",
  "prost",
+ "serde",
  "thiserror",
  "tokio",
  "tonic",
@@ -1823,6 +1824,7 @@ dependencies = [
 name = "text-generation-router"
 version = "0.1.0"
 dependencies = [
+ "async-stream",
  "axum",
  "clap 4.0.22",
  "futures",
diff --git a/proto/generate.proto b/proto/generate.proto
index 16539f8b..025f1aba 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -84,6 +84,13 @@ message GeneratedText {
     string finish_reason = 7;
 }
 
+message Intermediate {
+    uint64 request_id = 1;
+    uint32 token_id = 2;
+    float logprob = 3;
+    string token = 4;
+}
+
 message GenerateRequest {
     /// Batch
     Batch batch = 1;
@@ -94,6 +101,8 @@ message GenerateResponse {
     repeated GeneratedText generated_texts = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
+
+    repeated Intermediate intermediates = 3;
 }
 
 message GenerateWithCacheRequest {
@@ -106,4 +115,6 @@ message GenerateWithCacheResponse {
     repeated GeneratedText generated_texts = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
+
+    repeated Intermediate intermediates = 3;
 }
diff --git a/router/Cargo.toml b/router/Cargo.toml
index f99069d3..2a51773d 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -22,7 +22,8 @@
 serde = "1.0.145"
 serde_json = "1.0.85"
 thiserror = "1.0.37"
 tokenizers = "0.13.0"
-tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "net"] }
 tracing = "0.1.36"
 tracing-subscriber = { version = "0.3.15", features = ["json"] }
+async-stream = "0.3.3"
diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml
index fdd32494..d8a61562 100644
--- a/router/client/Cargo.toml
+++ b/router/client/Cargo.toml
@@ -12,6 +12,7 @@ tonic = "^0.6"
 tower = "^0.4"
 tracing = "^0.1"
 tracing-error = "^0.2"
+serde = { version = "1.0", features = ["derive"] }
 
 [build-dependencies]
 tonic-build = "0.6.2"
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 172d0bf7..47876298 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -73,7 +73,7 @@ impl Client {
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
     #[instrument(skip(self))]
-    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
         let response = self
             .stub
             .generate(request)
             .instrument(info_span!("generate"))
             .await?
             .into_inner();
-        Ok((response.generated_texts, response.batch))
+        Ok((response.generated_texts, response.batch, response.intermediates))
     }
 
     /// Generate one token for each request in the given cached batch
@@ -92,7 +92,7 @@
     pub async fn generate_with_cache(
         &mut self,
         batches: Vec<Batch>,
-    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let request = tonic::Request::new(GenerateWithCacheRequest { batches });
         let response = self
             .stub
             .generate_with_cache(request)
             .instrument(info_span!("generate_with_cache"))
             .await?
             .into_inner();
-        Ok((response.generated_texts, response.batch))
+        Ok((response.generated_texts, response.batch, response.intermediates))
     }
 }
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index 295b009b..13cc196b 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -7,7 +7,7 @@ mod sharded_client;
 
 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+    Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
@@ -35,3 +35,10 @@ impl From<transport::Error> for ClientError {
 }
 
 pub type Result<T> = std::result::Result<T, ClientError>;
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct IntermediateEvent {
+    pub token: String,
+    pub token_id: u32,
+    pub logprob: f32,
+}
\ No newline at end of file
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 6c70afca..e8a939e8 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,6 +1,6 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, GeneratedText};
+use crate::{Batch, Client, GeneratedText, Intermediate};
 use futures::future::join_all;
 use futures::future::select_all;
 use tonic::transport::Uri;
@@ -41,7 +41,7 @@ impl ShardedClient {
     ///
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
-    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
@@ -59,7 +59,7 @@
     pub async fn generate_with_cache(
         &mut self,
         batches: Vec<Batch>,
-    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index ee83d899..a4059d13 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -5,9 +5,9 @@
 use axum::http::StatusCode;
 use axum::Json;
 use std::future::Future;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
+use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient, Intermediate};
 use thiserror::Error;
-use tokio::sync::{oneshot, Notify};
+use tokio::sync::{oneshot, Notify, mpsc};
 use tokio::time::Instant;
 use tracing::instrument;
@@ -50,6 +50,36 @@
         Self { db, shared }
     }
 
+    /// Add a new request to the database and return a future that will generate the text
+    pub(crate) fn infer_stream(
+        &self,
+        input_length: usize,
+        request: GenerateRequest,
+        intermediate_tx: mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>,
+        response_tx: oneshot::Sender<Result<InferResponse, ClientError>>,
+    ) {
+        // Try to append the request to the database
+        self.db.append(Entry {
+            request,
+            response_tx,
+            intermediate_tx: Some(intermediate_tx),
+            input_length,
+            time: Instant::now(),
+            batch_time: None,
+        });
+
+        // Notify the background task that we have a new entry in the database that needs
+        // to be batched
+        self.shared.batching_task.notify_one();
+
+        // // Await on the response from the background task
+        // // We can safely unwrap as the background task will never drop the sender
+        // response_rx
+        //     .await
+        //     .unwrap()
+        //     .map_err(|err| InferError::GenerationError(err.to_string()))
+    }
+
     /// Add a new request to the database and return a future that will generate the text
     pub(crate) async fn infer(
         &self,
@@ -63,6 +93,7 @@
         self.db.append(Entry {
             request,
             response_tx,
+            intermediate_tx: None,
             input_length,
             time: Instant::now(),
             batch_time: None,
@@ -153,13 +184,13 @@
 
 /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
-    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
+    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>), ClientError>>,
     request_ids: Vec<u64>,
     db: &Db,
 ) -> Option<Batch> {
     match future.await {
-        Ok((generated_texts, next_batch)) => {
-            send_generated(generated_texts, db);
+        Ok((generated_texts, next_batch, intermediates)) => {
+            send_generated(generated_texts, intermediates, db);
             next_batch
         }
         // If we have an error, we discard the whole batch
@@ -181,12 +212,29 @@
 }
 
 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
+fn send_generated(finished: Vec<GeneratedText>, intermediates: Vec<Intermediate>, db: &Db) {
+    // zip with intermediates
+    intermediates.into_iter().for_each(|intermediate| {
+        // We can `expect` here as the request id should always be in the DB
+        let guard = db.get_mutex_guard();
+        let entry = guard.entries.get(&intermediate.request_id).expect("ID not found in db. This is a bug.");
+
+
+        if let Some(tx) = &entry.intermediate_tx {
+            // unwrap_or is valid here as we don't care if the receiver is gone.
+            tx.send(Ok(Some(intermediate))).unwrap_or(());
+        }
+    });
+
     finished.into_iter().for_each(|output| {
         // We can `expect` here as the request id should always be in the DB
         let entry = db
             .remove(&output.request.unwrap().id)
             .expect("ID not found in db. This is a bug.");
+
+        if let Some(tx) = &entry.intermediate_tx {
+            tx.send(Ok(None)).unwrap_or(());
+        }
 
         let response = InferResponse {
             output_text: output.output_text,
diff --git a/router/src/db.rs b/router/src/db.rs
index 1d7df627..1c27fc4d 100644
--- a/router/src/db.rs
+++ b/router/src/db.rs
@@ -5,7 +5,7 @@ use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
 use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
 };
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
@@ -17,6 +17,8 @@ pub(crate) struct Entry {
     pub request: GenerateRequest,
     /// Response sender to communicate between the Batcher and the batching_task
     pub response_tx: Sender<Result<InferResponse, ClientError>>,
+    /// Intermediate sender to communicate between the Batcher and the batching_task
+    pub intermediate_tx: Option<tokio::sync::mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>>,
     /// Number of tokens in the input
     pub input_length: usize,
     /// Instant when this entry was created
@@ -39,9 +41,9 @@ pub struct Shared {
 
 /// Database State
 #[derive(Debug)]
-struct State {
+pub(crate) struct State {
     /// Database entries organized in a BTreeMap to be able to iterate over them in order
-    entries: BTreeMap<u64, Entry>,
+    pub(crate) entries: BTreeMap<u64, Entry>,
 
     /// Id of the next entry
     next_id: u64,
@@ -118,6 +120,10 @@ impl Db {
         state.entries.remove(id)
     }
 
+    pub(crate) fn get_mutex_guard(&self) -> parking_lot::MutexGuard<State> {
+        self.shared.state.lock()
+    }
+
     // Get the next batch
     pub(crate) fn next_batch(
         &self,
diff --git a/router/src/server.rs b/router/src/server.rs
index 623dd07c..14418a28 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -8,12 +8,17 @@ use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
 use std::sync::Arc;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ShardedClient, IntermediateEvent};
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tracing::instrument;
+use tokio::sync::{oneshot, mpsc};
+
+use axum::response::sse::{Event, KeepAlive, Sse};
+use std::convert::Infallible;
+use futures::stream::Stream;
 
 // Server shared state
 #[derive(Clone)]
@@ -62,6 +67,80 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
 
+#[derive(serde::Serialize)]
+struct StreamEvent {
+    is_end: bool,
+    event: Option<IntermediateEvent>,
+    generated_text: Option<GeneratedText>,
+}
+
+async fn generate_stream(
+    state: Extension<ServerState>,
+    req: Json<GenerateRequest>,
+) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
+    let (intermediate_tx, mut intermediate_rx) = mpsc::unbounded_channel();
+    let (response_tx, response_rx) = oneshot::channel();
+
+    let (input_length, validated_request) =
+        state.validation.validate(req.0).await.map_err(|err| {
+            tracing::error!("{}", err.to_string());
+            err
+        }).unwrap();
+
+    // Inference
+    state.batcher.infer_stream(input_length, validated_request, intermediate_tx, response_tx);
+
+    let stream = async_stream::stream! {
+        while let Some(item) = intermediate_rx.recv().await {
+            match item {
+                Ok(item) => {
+                    match item {
+                        Some(item) => {
+                            let event_data = IntermediateEvent {
+                                token: item.token,
+                                token_id: item.token_id,
+                                logprob: item.logprob,
+                            };
+                            let stream_event = StreamEvent {
+                                is_end: false,
+                                event: Some(event_data),
+                                generated_text: None,
+                            };
+                            yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
+                        }
+                        None => {
+                            break
+                        }
+                    }
+                }
+                Err(err) => {
+                    yield Ok(Event::default().data(err.to_string()));
+                }
+            }
+        }
+        let response = response_rx.await.unwrap();
+        match response {
+            Ok(response) => {
+                let response = GeneratedText {
+                    generated_text: response.output_text,
+                    details: None,
+                };
+                let stream_event = StreamEvent {
+                    is_end: true,
+                    event: None,
+                    generated_text: Some(response),
+                };
+                yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
+            }
+            Err(err) => {
+                yield Ok(Event::default().data(err.to_string()));
+            }
+        }
+    };
+
+    Sse::new(stream).keep_alive(KeepAlive::default())
+}
+
 /// Generate method
 #[instrument(
     skip(state),
@@ -197,6 +276,7 @@ pub async fn run(
     let app = Router::new()
         .route("/", post(generate))
         .route("/generate", post(generate))
+        .route("/generate_stream", post(generate_stream))
        .route("/", get(health))
        .route("/health", get(health))
        .layer(Extension(shared_state.clone()));
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 6e35b2ad..71224698 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
 from typing import Optional, Tuple, List, Type
 
 from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch
+from text_generation.models.types import GeneratedText, Batch, Intermediate
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria
 
@@ -314,6 +314,7 @@ class CausalLM(Model):
 
         # Finished requests
         generated_texts: List[GeneratedText] = []
+        intermediates: List[Intermediate] = []
 
         # Zipped iterator
         iterator = zip(
@@ -352,12 +353,23 @@
             next_token_logprob = logprobs[-1, next_token]
             all_logprobs = torch.cat([all_logprobs, next_token_logprob])
 
+            next_token_sq = next_token.squeeze()
+            decoded_next_token = self.tokenizer.decode(
+                next_token_sq, clean_up_tokenization_spaces=False
+            )
+            next_token_logprob = all_logprobs[-1]
+            intermediate = Intermediate(
+                request_id=request.id,
+                token=decoded_next_token,
+                logprob=next_token_logprob.item(),
+                token_id=next_token_sq.item(),
+            )
+            intermediates.append(intermediate.to_pb())
+
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
-                next_token.squeeze(),
-                self.tokenizer.decode(
-                    next_token.squeeze(), clean_up_tokenization_spaces=False
-                ),
+                next_token_sq,
+                decoded_next_token,
             )
             if stop:
                 # Decode generated tokens
@@ -399,7 +411,7 @@
 
         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:
-            return generated_texts, None
+            return generated_texts, None, intermediates
 
         next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
         # If we finished at least one generation, we need to evict the indices of the generations that finished
@@ -459,4 +471,4 @@
             max_sequence_length=next_batch_max_sequence_length,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
-        return generated_texts, next_batch
+        return generated_texts, next_batch, intermediates
diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py
index 6bf64e05..1874a8a8 100644
--- a/server/text_generation/models/types.py
+++ b/server/text_generation/models/types.py
@@ -50,3 +50,19 @@ class GeneratedText:
             logprobs=self.logprobs,
             finish_reason=self.reason,
         )
+
+@dataclass
+class Intermediate:
+    request_id: int
+    token_id: int
+    logprob: float
+    token: str
+
+    def to_pb(self) -> generate_pb2.Intermediate:
+        return generate_pb2.Intermediate(
+            request_id = self.request_id,
+            token_id = self.token_id,
+            logprob = self.logprob,
+            token = self.token,
+        )
+
diff --git a/server/text_generation/server.py b/server/text_generation/server.py
index 5fd3072e..8efd8c0f 100644
--- a/server/text_generation/server.py
+++ b/server/text_generation/server.py
@@ -32,7 +32,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.device
         )
 
-        generated_texts, next_batch = self.model.generate_token(batch)
+        generated_texts, next_batch, intermediates = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
         return generate_pb2.GenerateResponse(
@@ -40,6 +40,7 @@
                 generated_text.to_pb() for generated_text in generated_texts
             ],
             batch=next_batch.to_pb() if next_batch else None,
+            intermediates=intermediates,
         )
 
     async def GenerateWithCache(self, request, context):
@@ -58,7 +59,7 @@
         else:
             batch = batches[0]
 
-        generated_texts, next_batch = self.model.generate_token(batch)
+        generated_texts, next_batch, intermediates = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
         return generate_pb2.GenerateWithCacheResponse(
@@ -66,6 +67,7 @@
                 generated_text.to_pb() for generated_text in generated_texts
            ],
             batch=next_batch.to_pb() if next_batch else None,
+            intermediates=intermediates,
         )
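
For reviewers: a minimal sketch of how a client might consume the new /generate_stream route. The route answers with Server-Sent Events: each intermediate token arrives as a data: frame containing a serialized StreamEvent with is_end set to false and an event object (token, token_id, logprob), and the final frame has is_end set to true and carries the full generated_text. The router address used below (http://localhost:3000) and the request body shape (inputs plus parameters, mirroring the existing /generate route) are assumptions, not taken from this patch.

import json

import requests  # third-party HTTP client, used here only for illustration


def stream_generate(prompt: str, url: str = "http://localhost:3000/generate_stream") -> str:
    # Request body is assumed to mirror the existing /generate route.
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 20}}
    with requests.post(url, json=payload, stream=True) as resp:
        resp.raise_for_status()
        for raw in resp.iter_lines(decode_unicode=True):
            # SSE frames look like "data: {...}"; blank keep-alive lines are skipped.
            if not raw or not raw.startswith("data:"):
                continue
            event = json.loads(raw[len("data:"):].strip())
            if event["is_end"]:
                # Final frame carries the complete generation.
                return event["generated_text"]["generated_text"]
            # Intermediate frame: one decoded token plus its id and logprob.
            print(event["event"]["token"], end="", flush=True)
    return ""


if __name__ == "__main__":
    final = stream_generate("Hello, my name is")
    print("\nfinal:", final)

Note that on a generation error the handler yields the bare error string rather than JSON, so a hardened client would guard the json.loads call.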