added streaming endpoint

Yannic Kilcher 2023-01-26 14:50:57 +01:00
parent 5c01e2544c
commit b1ef80583c
13 changed files with 214 additions and 28 deletions
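Summary (derived from the diff below): the Python server now emits a per-token Intermediate message alongside finished generations, the Rust client and batcher forward these through an mpsc channel, and the router gains a /generate_stream route that relays them as server-sent events. Each SSE data: line carries a JSON StreamEvent; the values below are illustrative, and the exact shape follows the StreamEvent struct added in the router further down:

    data: {"is_end":false,"event":{"token":" world","token_id":995,"logprob":-0.31},"generated_text":null}
    data: {"is_end":true,"event":null,"generated_text":{"generated_text":"Hello world","details":null}}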

Cargo.lock (generated)

@@ -1795,6 +1795,7 @@ version = "0.1.0"
 dependencies = [
  "futures",
  "prost",
+ "serde",
  "thiserror",
  "tokio",
  "tonic",
@@ -1823,6 +1824,7 @@ dependencies = [
 name = "text-generation-router"
 version = "0.1.0"
 dependencies = [
+ "async-stream",
  "axum",
  "clap 4.0.22",
  "futures",

@@ -84,6 +84,13 @@ message GeneratedText {
     string finish_reason = 7;
 }

+message Intermediate {
+    uint64 request_id = 1;
+    uint32 token_id = 2;
+    float logprob = 3;
+    string token = 4;
+}
+
 message GenerateRequest {
     /// Batch
     Batch batch = 1;
@@ -94,6 +101,8 @@ message GenerateResponse {
     repeated GeneratedText generated_texts = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
+    repeated Intermediate intermediates = 3;
 }

 message GenerateWithCacheRequest {
@@ -106,4 +115,6 @@ message GenerateWithCacheResponse {
     repeated GeneratedText generated_texts = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
+    repeated Intermediate intermediates = 3;
 }
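For orientation, this is roughly the Rust struct that prost generates for the new message and that the client code below works with (shape only; exact derives and attributes depend on the prost version and build configuration):

    #[derive(Clone, PartialEq, ::prost::Message)]
    pub struct Intermediate {
        #[prost(uint64, tag = "1")]
        pub request_id: u64,
        #[prost(uint32, tag = "2")]
        pub token_id: u32,
        #[prost(float, tag = "3")]
        pub logprob: f32,
        #[prost(string, tag = "4")]
        pub token: ::prost::alloc::string::String,
    }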

@@ -22,7 +22,8 @@ serde = "1.0.145"
 serde_json = "1.0.85"
 thiserror = "1.0.37"
 tokenizers = "0.13.0"
-tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "net"] }
 tracing = "0.1.36"
 tracing-subscriber = { version = "0.3.15", features = ["json"] }
+async-stream = "0.3.3"

@@ -12,6 +12,7 @@ tonic = "^0.6"
 tower = "^0.4"
 tracing = "^0.1"
 tracing-error = "^0.2"
+serde = { version = "1.0", features = ["derive"] }

 [build-dependencies]
 tonic-build = "0.6.2"

@@ -73,7 +73,7 @@ impl Client {
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
     #[instrument(skip(self))]
-    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
         let response = self
             .stub
@@ -81,7 +81,7 @@ impl Client {
             .instrument(info_span!("generate"))
             .await?
             .into_inner();
-        Ok((response.generated_texts, response.batch))
+        Ok((response.generated_texts, response.batch, response.intermediates))
     }

     /// Generate one token for each request in the given cached batch
@@ -92,7 +92,7 @@ impl Client {
     pub async fn generate_with_cache(
         &mut self,
         batches: Vec<Batch>,
-    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let request = tonic::Request::new(GenerateWithCacheRequest { batches });
         let response = self
             .stub
@@ -100,6 +100,6 @@ impl Client {
             .instrument(info_span!("generate_with_cache"))
             .await?
             .into_inner();
-        Ok((response.generated_texts, response.batch))
+        Ok((response.generated_texts, response.batch, response.intermediates))
     }
 }

@@ -7,7 +7,7 @@ mod sharded_client;
 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+    Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
@@ -35,3 +35,10 @@ impl From<transport::Error> for ClientError {
 }

 pub type Result<T> = std::result::Result<T, ClientError>;
+
+/// Token-level event exposed to HTTP clients of the streaming endpoint
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct IntermediateEvent {
+    pub token: String,
+    pub token_id: u32,
+    pub logprob: f32,
+}
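A minimal sketch (not part of the commit) of how this struct behaves on the wire with serde_json, which the router already depends on; it assumes the crate is available as a dependency and the values are made up:

    use text_generation_client::IntermediateEvent;

    fn main() {
        let event = IntermediateEvent {
            token: " world".to_string(),
            token_id: 995,
            logprob: -0.25,
        };

        // Serialize for the SSE payload...
        let json = serde_json::to_string(&event).unwrap();
        println!("{json}"); // e.g. {"token":" world","token_id":995,"logprob":-0.25}

        // ...and deserialize on a Rust client, which is why `Deserialize` is also derived.
        let back: IntermediateEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(back.token_id, 995);
    }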

@@ -1,6 +1,6 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, GeneratedText};
+use crate::{Batch, Client, GeneratedText, Intermediate};
 use futures::future::join_all;
 use futures::future::select_all;
 use tonic::transport::Uri;
@@ -41,7 +41,7 @@ impl ShardedClient {
     ///
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
-    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
@@ -59,7 +59,7 @@ impl ShardedClient {
     pub async fn generate_with_cache(
         &mut self,
         batches: Vec<Batch>,
-    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+    ) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()

@@ -5,9 +5,9 @@ use axum::http::StatusCode;
 use axum::Json;
 use std::future::Future;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
+use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient, Intermediate};
 use thiserror::Error;
-use tokio::sync::{oneshot, Notify};
+use tokio::sync::{oneshot, Notify, mpsc};
 use tokio::time::Instant;
 use tracing::instrument;
@@ -50,6 +50,36 @@ impl Batcher {
         Self { db, shared }
     }

+    /// Add a new request to the database; intermediate tokens are streamed through
+    /// `intermediate_tx` and the final response is sent through `response_tx`
+    pub(crate) fn infer_stream(
+        &self,
+        input_length: usize,
+        request: GenerateRequest,
+        intermediate_tx: mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>,
+        response_tx: oneshot::Sender<Result<InferResponse, ClientError>>,
+    ) {
+        // Try to append the request to the database
+        self.db.append(Entry {
+            request,
+            response_tx,
+            intermediate_tx: Some(intermediate_tx),
+            input_length,
+            time: Instant::now(),
+            batch_time: None,
+        });
+
+        // Notify the background task that we have a new entry in the database that needs
+        // to be batched
+        self.shared.batching_task.notify_one();
+
+        // Unlike `infer`, we do not await the response here: the caller listens on the
+        // intermediate and response channels instead.
+    }
+
     /// Add a new request to the database and return a future that will generate the text
     pub(crate) async fn infer(
         &self,
@@ -63,6 +93,7 @@ impl Batcher {
         self.db.append(Entry {
             request,
             response_tx,
+            intermediate_tx: None,
             input_length,
             time: Instant::now(),
             batch_time: None,
@@ -153,13 +184,13 @@ async fn batching_task(
 /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
-    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
+    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>), ClientError>>,
     request_ids: Vec<u64>,
     db: &Db,
 ) -> Option<Batch> {
     match future.await {
-        Ok((generated_texts, next_batch)) => {
-            send_generated(generated_texts, db);
+        Ok((generated_texts, next_batch, intermediates)) => {
+            send_generated(generated_texts, intermediates, db);
             next_batch
         }
         // If we have an error, we discard the whole batch
@@ -181,13 +212,30 @@ fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
 }

 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
+fn send_generated(finished: Vec<GeneratedText>, intermediates: Vec<Intermediate>, db: &Db) {
+    // Forward intermediate tokens to their streaming receivers
+    intermediates.into_iter().for_each(|intermediate| {
+        // We can `expect` here as the request id should always be in the DB
+        let guard = db.get_mutex_guard();
+        let entry = guard
+            .entries
+            .get(&intermediate.request_id)
+            .expect("ID not found in db. This is a bug.");
+        if let Some(tx) = &entry.intermediate_tx {
+            // unwrap_or is valid here as we don't care if the receiver is gone
+            tx.send(Ok(Some(intermediate))).unwrap_or(());
+        }
+    });
+
     finished.into_iter().for_each(|output| {
         // We can `expect` here as the request id should always be in the DB
         let entry = db
             .remove(&output.request.unwrap().id)
             .expect("ID not found in db. This is a bug.");
+
+        // Signal the end of the stream for this request, if it is being streamed
+        if let Some(tx) = &entry.intermediate_tx {
+            tx.send(Ok(None)).unwrap_or(());
+        }
+
         let response = InferResponse {
             output_text: output.output_text,
             generated_tokens: output.generated_tokens,
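To make the channel convention above explicit: `send_generated` sends each token as `Ok(Some(..))` and signals per-request completion with `Ok(None)`. Below is a minimal, self-contained sketch of that protocol (not from the commit; it assumes tokio with the rt and macros features, and the `Intermediate` struct here is a simplified stand-in for the prost-generated type):

    use tokio::sync::mpsc;

    #[derive(Debug)]
    struct Intermediate {
        request_id: u64,
        token: String,
    }

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::unbounded_channel::<Result<Option<Intermediate>, String>>();

        // Producer side: one message per generated token, then `Ok(None)` when the request finishes.
        tx.send(Ok(Some(Intermediate { request_id: 0, token: "Hello".into() }))).unwrap();
        tx.send(Ok(None)).unwrap();

        // Consumer side: stream tokens until the end-of-stream marker arrives.
        while let Some(msg) = rx.recv().await {
            match msg {
                Ok(Some(intermediate)) => println!("request {}: {}", intermediate.request_id, intermediate.token),
                Ok(None) => break,
                Err(err) => eprintln!("error: {err}"),
            }
        }
    }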

@@ -5,7 +5,7 @@ use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
 use text_generation_client::{
-    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
 };
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
@@ -17,6 +17,8 @@ pub(crate) struct Entry {
     pub request: GenerateRequest,
     /// Response sender to communicate between the Batcher and the batching_task
     pub response_tx: Sender<Result<InferResponse, ClientError>>,
+    /// Optional sender used to stream intermediate tokens back to the request handler
+    pub intermediate_tx: Option<tokio::sync::mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>>,
     /// Number of tokens in the input
     pub input_length: usize,
     /// Instant when this entry was created
@@ -39,9 +41,9 @@ pub struct Shared {
 /// Database State
 #[derive(Debug)]
-struct State {
+pub(crate) struct State {
     /// Database entries organized in a BTreeMap to be able to iterate over them in order
-    entries: BTreeMap<u64, Entry>,
+    pub(crate) entries: BTreeMap<u64, Entry>,

     /// Id of the next entry
     next_id: u64,
@@ -118,6 +120,10 @@ impl Db {
         state.entries.remove(id)
     }

+    /// Lock the internal state and return the guard, giving callers direct access to the entries
+    pub(crate) fn get_mutex_guard(&self) -> parking_lot::MutexGuard<State> {
+        self.shared.state.lock()
+    }
+
     // Get the next batch
     pub(crate) fn next_batch(
         &self,

@@ -8,12 +8,17 @@ use axum::routing::{get, post};
 use axum::{Json, Router};
 use std::net::SocketAddr;
 use std::sync::Arc;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ShardedClient, IntermediateEvent};
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;
 use tokio::time::Instant;
 use tracing::instrument;
+use tokio::sync::{oneshot, mpsc};
+use axum::response::sse::{Event, KeepAlive, Sse};
+use std::convert::Infallible;
+use futures::stream::Stream;

 // Server shared state
 #[derive(Clone)]
@@ -62,6 +67,80 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
     Ok(())
 }

+/// Event sent over the SSE stream: either an intermediate token or the final generated text
+#[derive(serde::Serialize)]
+struct StreamEvent {
+    is_end: bool,
+    event: Option<IntermediateEvent>,
+    generated_text: Option<GeneratedText>,
+}
+
+/// Streaming generate method: emits tokens as server-sent events while they are generated
+async fn generate_stream(
+    state: Extension<ServerState>,
+    req: Json<GenerateRequest>,
+) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
+    let (intermediate_tx, mut intermediate_rx) = mpsc::unbounded_channel();
+    let (response_tx, response_rx) = oneshot::channel();
+
+    // Validate inputs
+    let (input_length, validated_request) = state
+        .validation
+        .validate(req.0)
+        .await
+        .map_err(|err| {
+            tracing::error!("{}", err.to_string());
+            err
+        })
+        .unwrap();
+
+    // Queue the request; tokens arrive on `intermediate_rx`, the final text on `response_rx`
+    state
+        .batcher
+        .infer_stream(input_length, validated_request, intermediate_tx, response_tx);
+
+    let stream = async_stream::stream! {
+        // Forward intermediate tokens until the batcher signals the end with `None`
+        while let Some(item) = intermediate_rx.recv().await {
+            match item {
+                Ok(Some(item)) => {
+                    let event_data = IntermediateEvent {
+                        token: item.token,
+                        token_id: item.token_id,
+                        logprob: item.logprob,
+                    };
+                    let stream_event = StreamEvent {
+                        is_end: false,
+                        event: Some(event_data),
+                        generated_text: None,
+                    };
+                    yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
+                }
+                Ok(None) => break,
+                Err(err) => {
+                    yield Ok(Event::default().data(err.to_string()));
+                }
+            }
+        }
+
+        // Send the final generated text as the last event
+        match response_rx.await.unwrap() {
+            Ok(response) => {
+                let response = GeneratedText {
+                    generated_text: response.output_text,
+                    details: None,
+                };
+                let stream_event = StreamEvent {
+                    is_end: true,
+                    event: None,
+                    generated_text: Some(response),
+                };
+                yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
+            }
+            Err(err) => {
+                yield Ok(Event::default().data(err.to_string()));
+            }
+        }
+    };
+
+    Sse::new(stream).keep_alive(KeepAlive::default())
+}
+
 /// Generate method
 #[instrument(
     skip(state),
@@ -197,6 +276,7 @@ pub async fn run(
     let app = Router::new()
         .route("/", post(generate))
         .route("/generate", post(generate))
+        .route("/generate_stream", post(generate_stream))
         .route("/", get(health))
         .route("/health", get(health))
         .layer(Extension(shared_state.clone()));
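A hypothetical client-side sketch (not part of the commit) showing one way to consume the new /generate_stream route. It assumes reqwest with the json and stream features, futures, and tokio; the router address/port and the "inputs" field name in the request body are assumptions and should be checked against the actual GenerateRequest schema. For brevity it prints raw data: lines per chunk; a robust client would buffer across chunk boundaries or use a dedicated SSE client:

    use futures::StreamExt;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let client = reqwest::Client::new();

        // POST the generation request (field names are assumptions for illustration).
        let response = client
            .post("http://localhost:3000/generate_stream")
            .json(&serde_json::json!({ "inputs": "Hello" }))
            .send()
            .await?;

        // Read the SSE body; each event is a `data: {...}` line containing a serialized StreamEvent.
        let mut body = response.bytes_stream();
        while let Some(chunk) = body.next().await {
            let chunk = chunk?;
            for line in String::from_utf8_lossy(&chunk).lines() {
                if let Some(data) = line.strip_prefix("data:") {
                    println!("{}", data.trim());
                }
            }
        }
        Ok(())
    }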

@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
 from typing import Optional, Tuple, List, Type

 from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch
+from text_generation.models.types import GeneratedText, Batch, Intermediate
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria
@@ -314,6 +314,7 @@ class CausalLM(Model):
         # Finished requests
         generated_texts: List[GeneratedText] = []
+        intermediates: List[Intermediate] = []

         # Zipped iterator
         iterator = zip(
@@ -352,12 +353,23 @@ class CausalLM(Model):
             next_token_logprob = logprobs[-1, next_token]
             all_logprobs = torch.cat([all_logprobs, next_token_logprob])

+            # Decode the new token and emit an intermediate event for streaming
+            next_token_sq = next_token.squeeze()
+            decoded_next_token = self.tokenizer.decode(
+                next_token_sq, clean_up_tokenization_spaces=False
+            )
+            next_token_logprob = all_logprobs[-1]
+            intermediate = Intermediate(
+                request_id=request.id,
+                token=decoded_next_token,
+                logprob=next_token_logprob.item(),
+                token_id=next_token_sq.item(),
+            )
+            intermediates.append(intermediate.to_pb())
+
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
-                next_token.squeeze(),
-                self.tokenizer.decode(
-                    next_token.squeeze(), clean_up_tokenization_spaces=False
-                ),
+                next_token_sq,
+                decoded_next_token,
             )
             if stop:
                 # Decode generated tokens
@@ -399,7 +411,7 @@ class CausalLM(Model):
         # We finished all generations in the batch; there is no next batch
         if not next_batch_keep_indices:
-            return generated_texts, None
+            return generated_texts, None, intermediates

         next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
         # If we finished at least one generation, we need to evict the indices of the generations that finished
@@ -459,4 +471,4 @@ class CausalLM(Model):
             max_sequence_length=next_batch_max_sequence_length,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
-        return generated_texts, next_batch
+        return generated_texts, next_batch, intermediates

@@ -50,3 +50,19 @@ class GeneratedText:
             logprobs=self.logprobs,
             finish_reason=self.reason,
         )
+
+
+@dataclass
+class Intermediate:
+    request_id: int
+    token_id: int
+    logprob: float
+    token: str
+
+    def to_pb(self) -> generate_pb2.Intermediate:
+        return generate_pb2.Intermediate(
+            request_id=self.request_id,
+            token_id=self.token_id,
+            logprob=self.logprob,
+            token=self.token,
+        )

@@ -32,7 +32,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.device
         )

-        generated_texts, next_batch = self.model.generate_token(batch)
+        generated_texts, next_batch, intermediates = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.GenerateResponse(
@@ -40,6 +40,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 generated_text.to_pb() for generated_text in generated_texts
             ],
             batch=next_batch.to_pb() if next_batch else None,
+            intermediates=intermediates,
         )

     async def GenerateWithCache(self, request, context):
@@ -58,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         else:
             batch = batches[0]

-        generated_texts, next_batch = self.model.generate_token(batch)
+        generated_texts, next_batch, intermediates = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.GenerateWithCacheResponse(
@@ -66,6 +67,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 generated_text.to_pb() for generated_text in generated_texts
             ],
             batch=next_batch.to_pb() if next_batch else None,
+            intermediates=intermediates,
        )