added streaming endpoint

Yannic Kilcher 2023-01-26 14:50:57 +01:00
parent 5c01e2544c
commit b1ef80583c
13 changed files with 214 additions and 28 deletions

Cargo.lock (generated): 2 changes
View File

@@ -1795,6 +1795,7 @@ version = "0.1.0"
dependencies = [
"futures",
"prost",
"serde",
"thiserror",
"tokio",
"tonic",
@@ -1823,6 +1824,7 @@ dependencies = [
name = "text-generation-router"
version = "0.1.0"
dependencies = [
"async-stream",
"axum",
"clap 4.0.22",
"futures",

View File

@@ -84,6 +84,13 @@ message GeneratedText {
string finish_reason = 7;
}
message Intermediate {
uint64 request_id = 1;
uint32 token_id = 2;
float logprob = 3;
string token = 4;
}
message GenerateRequest {
/// Batch
Batch batch = 1;
@@ -94,6 +101,8 @@ message GenerateResponse {
repeated GeneratedText generated_texts = 1;
/// Next batch (cached)
optional Batch batch = 2;
repeated Intermediate intermediates = 3;
}
message GenerateWithCacheRequest {
@@ -106,4 +115,6 @@ message GenerateWithCacheResponse {
repeated GeneratedText generated_texts = 1;
/// Next batch (cached)
optional Batch batch = 2;
repeated Intermediate intermediates = 3;
}
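
Note (not part of this diff): on the Rust side, tonic-build/prost generates a plain struct for the new Intermediate message. A rough sketch of what that generated type looks like, with field names and tags taken from the proto above; the exact derives and paths depend on the prost version pulled in by tonic 0.6:

#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Intermediate {
    /// Id of the request this token belongs to
    #[prost(uint64, tag = "1")]
    pub request_id: u64,
    /// Id of the generated token
    #[prost(uint32, tag = "2")]
    pub token_id: u32,
    /// Log probability of the generated token
    #[prost(float, tag = "3")]
    pub logprob: f32,
    /// Decoded token text
    #[prost(string, tag = "4")]
    pub token: ::prost::alloc::string::String,
}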

View File

@@ -22,7 +22,8 @@ serde = "1.0.145"
serde_json = "1.0.85"
thiserror = "1.0.37"
tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "net"] }
tracing = "0.1.36"
tracing-subscriber = { version = "0.3.15", features = ["json"] }
async-stream = "0.3.3"

View File

@@ -12,6 +12,7 @@ tonic = "^0.6"
tower = "^0.4"
tracing = "^0.1"
tracing-error = "^0.2"
serde = { version = "1.0", features = ["derive"] }
[build-dependencies]
tonic-build = "0.6.2"

View File

@@ -73,7 +73,7 @@ impl Client {
/// Returns a list of generated texts of requests that met their stopping criteria
/// and the next cached batch
#[instrument(skip(self))]
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
let response = self
.stub
@@ -81,7 +81,7 @@ impl Client {
.instrument(info_span!("generate"))
.await?
.into_inner();
Ok((response.generated_texts, response.batch))
Ok((response.generated_texts, response.batch, response.intermediates))
}
/// Generate one token for each request in the given cached batch
@@ -92,7 +92,7 @@ impl Client {
pub async fn generate_with_cache(
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
let request = tonic::Request::new(GenerateWithCacheRequest { batches });
let response = self
.stub
@@ -100,6 +100,6 @@ impl Client {
.instrument(info_span!("generate_with_cache"))
.await?
.into_inner();
Ok((response.generated_texts, response.batch))
Ok((response.generated_texts, response.batch, response.intermediates))
}
}

View File

@@ -7,7 +7,7 @@ mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::{
Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;
@@ -35,3 +35,10 @@ impl From<transport::Error> for ClientError {
}
pub type Result<T> = std::result::Result<T, ClientError>;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IntermediateEvent {
pub token: String,
pub token_id: u32,
pub logprob: f32,
}
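
A small conversion sketch (not part of this commit) from the protobuf Intermediate to the serializable IntermediateEvent above; request_id is dropped because the SSE stream that carries the event is already scoped to a single request:

impl From<Intermediate> for IntermediateEvent {
    fn from(intermediate: Intermediate) -> Self {
        // The request id is implicit in the per-request channel, so only the
        // token payload is forwarded to clients.
        Self {
            token: intermediate.token,
            token_id: intermediate.token_id,
            logprob: intermediate.logprob,
        }
    }
}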

View File

@@ -1,6 +1,6 @@
/// Multi shard Client
use crate::Result;
use crate::{Batch, Client, GeneratedText};
use crate::{Batch, Client, GeneratedText, Intermediate};
use futures::future::join_all;
use futures::future::select_all;
use tonic::transport::Uri;
@@ -41,7 +41,7 @@ impl ShardedClient {
///
/// Returns a list of generated texts of requests that met their stopping criteria
/// and the next cached batch
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
let futures: Vec<_> = self
.clients
.iter_mut()
@@ -59,7 +59,7 @@
pub async fn generate_with_cache(
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
let futures: Vec<_> = self
.clients
.iter_mut()

View File

@@ -5,9 +5,9 @@ use axum::http::StatusCode;
use axum::Json;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient, Intermediate};
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::sync::{oneshot, Notify, mpsc};
use tokio::time::Instant;
use tracing::instrument;
@@ -50,6 +50,36 @@ impl Batcher {
Self { db, shared }
}
/// Add a new request to the database and stream intermediate tokens and the final response through the provided channels
pub(crate) fn infer_stream(
&self,
input_length: usize,
request: GenerateRequest,
intermediate_tx: mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>,
response_tx: oneshot::Sender<Result<InferResponse, ClientError>>,
) {
// Try to append the request to the database
self.db.append(Entry {
request,
response_tx,
intermediate_tx: Some(intermediate_tx),
input_length,
time: Instant::now(),
batch_time: None,
});
// Notify the background task that we have a new entry in the database that needs
// to be batched
self.shared.batching_task.notify_one();
// // Await on the response from the background task
// // We can safely unwrap as the background task will never drop the sender
// response_rx
// .await
// .unwrap()
// .map_err(|err| InferError::GenerationError(err.to_string()))
}
/// Add a new request to the database and return a future that will generate the text
pub(crate) async fn infer(
&self,
@@ -63,6 +93,7 @@ impl Batcher {
self.db.append(Entry {
request,
response_tx,
intermediate_tx: None,
input_length,
time: Instant::now(),
batch_time: None,
@@ -153,13 +184,13 @@ async fn batching_task(
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
async fn wrap_future(
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>), ClientError>>,
request_ids: Vec<u64>,
db: &Db,
) -> Option<Batch> {
match future.await {
Ok((generated_texts, next_batch)) => {
send_generated(generated_texts, db);
Ok((generated_texts, next_batch, intermediates)) => {
send_generated(generated_texts, intermediates, db);
next_batch
}
// If we have an error, we discard the whole batch
@@ -181,12 +212,29 @@ fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
}
/// Send `generated_text` to the Batcher for all `finished`
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
fn send_generated(finished: Vec<GeneratedText>, intermediates: Vec<Intermediate>, db: &Db) {
// zip with intermediates
intermediates.into_iter().for_each(|intermediate| {
// We can `expect` here as the request id should always be in the DB
let guard = db.get_mutex_guard();
let entry = guard.entries.get(&intermediate.request_id).expect("ID not found in db. This is a bug.");
if let Some(tx) = &entry.intermediate_tx {
// unwrap_or is valid here as we don't care if the receiver is gone.
tx.send(Ok(Some(intermediate))).unwrap_or(());
}
});
finished.into_iter().for_each(|output| {
// We can `expect` here as the request id should always be in the DB
let entry = db
.remove(&output.request.unwrap().id)
.expect("ID not found in db. This is a bug.");
if let Some(tx) = &entry.intermediate_tx {
tx.send(Ok(None)).unwrap_or(());
}
let response = InferResponse {
output_text: output.output_text,
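
The intermediate channel used above follows a simple sentinel protocol: Ok(Some(intermediate)) carries one decoded token, Ok(None) signals that the request finished (the final text then arrives on response_tx), and Err propagates a client error. A minimal consumer sketch under those assumptions; the SSE handler in the router does the same thing with extra bookkeeping:

/// Drain one request's intermediate stream into a Vec of tokens.
/// `rx` is the receiving half of the channel handed to `infer_stream`.
async fn drain_intermediates(
    mut rx: tokio::sync::mpsc::UnboundedReceiver<Result<Option<Intermediate>, ClientError>>,
) -> Result<Vec<Intermediate>, ClientError> {
    let mut tokens = Vec::new();
    while let Some(msg) = rx.recv().await {
        match msg {
            // One more generated token for this request
            Ok(Some(intermediate)) => tokens.push(intermediate),
            // End-of-stream sentinel: generation finished
            Ok(None) => break,
            // A shard error aborts the stream
            Err(err) => return Err(err),
        }
    }
    Ok(tokens)
}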

View File

@@ -5,7 +5,7 @@ use parking_lot::Mutex;
use std::collections::BTreeMap;
use std::sync::Arc;
use text_generation_client::{
Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
};
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;
@@ -17,6 +17,8 @@ pub(crate) struct Entry {
pub request: GenerateRequest,
/// Response sender to communicate between the Batcher and the batching_task
pub response_tx: Sender<Result<InferResponse, ClientError>>,
/// Intermediate sender to communicate between the Batcher and the batching_task
pub intermediate_tx: Option<tokio::sync::mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>>,
/// Number of tokens in the input
pub input_length: usize,
/// Instant when this entry was created
@@ -39,9 +41,9 @@ pub struct Shared {
/// Database State
#[derive(Debug)]
struct State {
pub(crate) struct State {
/// Database entries organized in a BTreeMap to be able to iterate over them in order
entries: BTreeMap<u64, Entry>,
pub(crate) entries: BTreeMap<u64, Entry>,
/// Id of the next entry
next_id: u64,
@@ -118,6 +120,10 @@ impl Db {
state.entries.remove(id)
}
pub(crate) fn get_mutex_guard(&self) -> parking_lot::MutexGuard<State> {
self.shared.state.lock()
}
// Get the next batch
pub(crate) fn next_batch(
&self,

View File

@@ -8,12 +8,17 @@ use axum::routing::{get, post};
use axum::{Json, Router};
use std::net::SocketAddr;
use std::sync::Arc;
use text_generation_client::ShardedClient;
use text_generation_client::{ShardedClient, IntermediateEvent};
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::sync::Semaphore;
use tokio::time::Instant;
use tracing::instrument;
use tokio::sync::{oneshot, mpsc};
use axum::response::sse::{Event, KeepAlive, Sse};
use std::convert::Infallible;
use futures::stream::Stream;
// Server shared state
#[derive(Clone)]
@@ -62,6 +67,80 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
Ok(())
}
#[derive(serde::Serialize)]
struct StreamEvent {
is_end: bool,
event: Option<IntermediateEvent>,
generated_text: Option<GeneratedText>,
}
async fn generate_stream(
state: Extension<ServerState>,
req: Json<GenerateRequest>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let (intermediate_tx, mut intermediate_rx) = mpsc::unbounded_channel();
let (response_tx, response_rx) = oneshot::channel();
let (input_length, validated_request) =
state.validation.validate(req.0).await.map_err(|err| {
tracing::error!("{}", err.to_string());
err
}).unwrap();
// Inference
state.batcher.infer_stream(input_length, validated_request, intermediate_tx, response_tx);
let stream = async_stream::stream! {
while let Some(item) = intermediate_rx.recv().await {
match item {
Ok(item) => {
match item {
Some(item) => {
let event_data = IntermediateEvent {
token: item.token,
token_id: item.token_id,
logprob: item.logprob,
};
let stream_event = StreamEvent {
is_end: false,
event: Some(event_data),
generated_text: None,
};
yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
}
None => {
break
}
}
}
Err(err) => {
yield Ok(Event::default().data(err.to_string()));
}
}
}
let response = response_rx.await.unwrap();
match response {
Ok(response) => {
let response = GeneratedText {
generated_text: response.output_text,
details: None,
};
let stream_event = StreamEvent {
is_end: true,
event: None,
generated_text: Some(response),
};
yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
}
Err(err) => {
yield Ok(Event::default().data(err.to_string()));
}
}
};
Sse::new(stream).keep_alive(KeepAlive::default())
}
/// Generate method
#[instrument(
skip(state),
@@ -197,6 +276,7 @@ pub async fn run(
let app = Router::new()
.route("/", post(generate))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/", get(health))
.route("/health", get(health))
.layer(Extension(shared_state.clone()));
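
A quick way to exercise the new route, sketched rather than taken from this commit: the client below assumes the router listens on 127.0.0.1:3000, that the JSON body only needs an "inputs" field (an assumption about GenerateRequest's serde shape), and that reqwest (with its json and stream features), serde_json, tokio, and futures are available. It prints each SSE data payload as raw JSON; the final frame has is_end set to true and carries generated_text.

use futures::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    let response = client
        .post("http://127.0.0.1:3000/generate_stream")
        .json(&serde_json::json!({ "inputs": "Hello" }))
        .send()
        .await?;

    // SSE frames arrive as `data: <json>` lines separated by blank lines.
    let mut body = response.bytes_stream();
    while let Some(chunk) = body.next().await {
        for line in String::from_utf8_lossy(&chunk?).lines() {
            if let Some(payload) = line.strip_prefix("data:") {
                println!("{}", payload.trim());
            }
        }
    }
    Ok(())
}

A production client would buffer bytes across chunk boundaries before splitting lines; this sketch assumes each frame arrives within a single chunk.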

View File

@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch
from text_generation.models.types import GeneratedText, Batch, Intermediate
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria
@@ -314,6 +314,7 @@ class CausalLM(Model):
# Finished requests
generated_texts: List[GeneratedText] = []
intermediates: List[Intermediate] = []
# Zipped iterator
iterator = zip(
@@ -352,12 +353,23 @@
next_token_logprob = logprobs[-1, next_token]
all_logprobs = torch.cat([all_logprobs, next_token_logprob])
next_token_sq = next_token.squeeze()
decoded_next_token = self.tokenizer.decode(
next_token_sq, clean_up_tokenization_spaces=False
)
next_token_logprob = all_logprobs[-1]
intermediate = Intermediate(
request_id=request.id,
token=decoded_next_token,
logprob=next_token_logprob.item(),
token_id=next_token_sq.item(),
)
intermediates.append(intermediate.to_pb())
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token.squeeze(),
self.tokenizer.decode(
next_token.squeeze(), clean_up_tokenization_spaces=False
),
next_token_sq,
decoded_next_token,
)
if stop:
# Decode generated tokens
@@ -399,7 +411,7 @@
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generated_texts, None
return generated_texts, None, intermediates
next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
# If we finished at least one generation, we need to evict the indices of the generations that finished
@@ -459,4 +471,4 @@
max_sequence_length=next_batch_max_sequence_length,
keys_head_dim_last=batch.keys_head_dim_last,
)
return generated_texts, next_batch
return generated_texts, next_batch, intermediates

View File

@@ -50,3 +50,19 @@
logprobs=self.logprobs,
finish_reason=self.reason,
)
@dataclass
class Intermediate:
request_id: int
token_id: int
logprob: float
token: str
def to_pb(self) -> generate_pb2.Intermediate:
return generate_pb2.Intermediate(
request_id = self.request_id,
token_id = self.token_id,
logprob = self.logprob,
token = self.token,
)

View File

@@ -32,7 +32,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch, self.model.tokenizer, self.model.device
)
generated_texts, next_batch = self.model.generate_token(batch)
generated_texts, next_batch, intermediates = self.model.generate_token(batch)
self.cache.set(next_batch)
return generate_pb2.GenerateResponse(
@@ -40,6 +40,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
generated_text.to_pb() for generated_text in generated_texts
],
batch=next_batch.to_pb() if next_batch else None,
intermediates=intermediates,
)
async def GenerateWithCache(self, request, context):
@@ -58,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
else:
batch = batches[0]
generated_texts, next_batch = self.model.generate_token(batch)
generated_texts, next_batch, intermediates = self.model.generate_token(batch)
self.cache.set(next_batch)
return generate_pb2.GenerateWithCacheResponse(
@@ -66,6 +67,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
generated_text.to_pb() for generated_text in generated_texts
],
batch=next_batch.to_pb() if next_batch else None,
intermediates=intermediates,
)