mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
added streaming endpoint
This commit is contained in:
parent
5c01e2544c
commit
b1ef80583c
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -1795,6 +1795,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"prost",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tonic",
|
||||
@ -1823,6 +1824,7 @@ dependencies = [
|
||||
name = "text-generation-router"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"axum",
|
||||
"clap 4.0.22",
|
||||
"futures",
|
||||
|
@ -84,6 +84,13 @@ message GeneratedText {
|
||||
string finish_reason = 7;
|
||||
}
|
||||
|
||||
message Intermediate {
|
||||
uint64 request_id = 1;
|
||||
uint32 token_id = 2;
|
||||
float logprob = 3;
|
||||
string token = 4;
|
||||
}
|
||||
|
||||
message GenerateRequest {
|
||||
/// Batch
|
||||
Batch batch = 1;
|
||||
@ -94,6 +101,8 @@ message GenerateResponse {
|
||||
repeated GeneratedText generated_texts = 1;
|
||||
/// Next batch (cached)
|
||||
optional Batch batch = 2;
|
||||
|
||||
repeated Intermediate intermediates = 3;
|
||||
}
|
||||
|
||||
message GenerateWithCacheRequest {
|
||||
@ -106,4 +115,6 @@ message GenerateWithCacheResponse {
|
||||
repeated GeneratedText generated_texts = 1;
|
||||
/// Next batch (cached)
|
||||
optional Batch batch = 2;
|
||||
|
||||
repeated Intermediate intermediates = 3;
|
||||
}
|
||||
|
@ -22,7 +22,8 @@ serde = "1.0.145"
|
||||
serde_json = "1.0.85"
|
||||
thiserror = "1.0.37"
|
||||
tokenizers = "0.13.0"
|
||||
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "net"] }
|
||||
tracing = "0.1.36"
|
||||
tracing-subscriber = { version = "0.3.15", features = ["json"] }
|
||||
async-stream = "0.3.3"
|
||||
|
||||
|
@ -12,6 +12,7 @@ tonic = "^0.6"
|
||||
tower = "^0.4"
|
||||
tracing = "^0.1"
|
||||
tracing-error = "^0.2"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.6.2"
|
||||
|
@ -73,7 +73,7 @@ impl Client {
|
||||
/// Returns a list of generated texts of request that met their stopping criteria
|
||||
/// and the next cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
|
||||
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
|
||||
let response = self
|
||||
.stub
|
||||
@ -81,7 +81,7 @@ impl Client {
|
||||
.instrument(info_span!("generate"))
|
||||
.await?
|
||||
.into_inner();
|
||||
Ok((response.generated_texts, response.batch))
|
||||
Ok((response.generated_texts, response.batch, response.intermediates))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batch
|
||||
@ -92,7 +92,7 @@ impl Client {
|
||||
pub async fn generate_with_cache(
|
||||
&mut self,
|
||||
batches: Vec<Batch>,
|
||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
|
||||
let request = tonic::Request::new(GenerateWithCacheRequest { batches });
|
||||
let response = self
|
||||
.stub
|
||||
@ -100,6 +100,6 @@ impl Client {
|
||||
.instrument(info_span!("generate_with_cache"))
|
||||
.await?
|
||||
.into_inner();
|
||||
Ok((response.generated_texts, response.batch))
|
||||
Ok((response.generated_texts, response.batch, response.intermediates))
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v1::{
|
||||
Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
use thiserror::Error;
|
||||
@ -35,3 +35,10 @@ impl From<transport::Error> for ClientError {
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct IntermediateEvent {
|
||||
pub token: String,
|
||||
pub token_id: u32,
|
||||
pub logprob: f32,
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
/// Multi shard Client
|
||||
use crate::Result;
|
||||
use crate::{Batch, Client, GeneratedText};
|
||||
use crate::{Batch, Client, GeneratedText, Intermediate};
|
||||
use futures::future::join_all;
|
||||
use futures::future::select_all;
|
||||
use tonic::transport::Uri;
|
||||
@ -41,7 +41,7 @@ impl ShardedClient {
|
||||
///
|
||||
/// Returns a list of generated texts of request that met their stopping criteria
|
||||
/// and the next cached batch
|
||||
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
@ -59,7 +59,7 @@ impl ShardedClient {
|
||||
pub async fn generate_with_cache(
|
||||
&mut self,
|
||||
batches: Vec<Batch>,
|
||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
) -> Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
|
@ -5,9 +5,9 @@ use axum::http::StatusCode;
|
||||
use axum::Json;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
|
||||
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient, Intermediate};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{oneshot, Notify};
|
||||
use tokio::sync::{oneshot, Notify, mpsc};
|
||||
use tokio::time::Instant;
|
||||
use tracing::instrument;
|
||||
|
||||
@ -50,6 +50,36 @@ impl Batcher {
|
||||
Self { db, shared }
|
||||
}
|
||||
|
||||
/// Add a new request to the database and return a future that will generate the text
|
||||
pub(crate) fn infer_stream(
|
||||
&self,
|
||||
input_length: usize,
|
||||
request: GenerateRequest,
|
||||
intermediate_tx: mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>,
|
||||
response_tx: oneshot::Sender<Result<InferResponse, ClientError>>,
|
||||
) {
|
||||
// Try to append the request to the database
|
||||
self.db.append(Entry {
|
||||
request,
|
||||
response_tx,
|
||||
intermediate_tx: Some(intermediate_tx),
|
||||
input_length,
|
||||
time: Instant::now(),
|
||||
batch_time: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the database that needs
|
||||
// to be batched
|
||||
self.shared.batching_task.notify_one();
|
||||
|
||||
// // Await on the response from the background task
|
||||
// // We can safely unwrap as the background task will never drop the sender
|
||||
// response_rx
|
||||
// .await
|
||||
// .unwrap()
|
||||
// .map_err(|err| InferError::GenerationError(err.to_string()))
|
||||
}
|
||||
|
||||
/// Add a new request to the database and return a future that will generate the text
|
||||
pub(crate) async fn infer(
|
||||
&self,
|
||||
@ -63,6 +93,7 @@ impl Batcher {
|
||||
self.db.append(Entry {
|
||||
request,
|
||||
response_tx,
|
||||
intermediate_tx: None,
|
||||
input_length,
|
||||
time: Instant::now(),
|
||||
batch_time: None,
|
||||
@ -153,13 +184,13 @@ async fn batching_task(
|
||||
|
||||
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
|
||||
async fn wrap_future(
|
||||
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
|
||||
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>, Vec<Intermediate>), ClientError>>,
|
||||
request_ids: Vec<u64>,
|
||||
db: &Db,
|
||||
) -> Option<Batch> {
|
||||
match future.await {
|
||||
Ok((generated_texts, next_batch)) => {
|
||||
send_generated(generated_texts, db);
|
||||
Ok((generated_texts, next_batch, intermediates)) => {
|
||||
send_generated(generated_texts, intermediates, db);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
@ -181,12 +212,29 @@ fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
|
||||
}
|
||||
|
||||
/// Send `generated_text` to the Batcher for all `finished`
|
||||
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
|
||||
fn send_generated(finished: Vec<GeneratedText>, intermediates: Vec<Intermediate>, db: &Db) {
|
||||
// zip with intermediates
|
||||
intermediates.into_iter().for_each(|intermediate| {
|
||||
// We can `expect` here as the request id should always be in the DB
|
||||
let guard = db.get_mutex_guard();
|
||||
let entry = guard.entries.get(&intermediate.request_id).expect("ID not found in db. This is a bug.");
|
||||
|
||||
|
||||
if let Some(tx) = &entry.intermediate_tx {
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
tx.send(Ok(Some(intermediate))).unwrap_or(());
|
||||
}
|
||||
});
|
||||
|
||||
finished.into_iter().for_each(|output| {
|
||||
// We can `expect` here as the request id should always be in the DB
|
||||
let entry = db
|
||||
.remove(&output.request.unwrap().id)
|
||||
.expect("ID not found in db. This is a bug.");
|
||||
|
||||
if let Some(tx) = &entry.intermediate_tx {
|
||||
tx.send(Ok(None)).unwrap_or(());
|
||||
}
|
||||
|
||||
let response = InferResponse {
|
||||
output_text: output.output_text,
|
||||
|
@ -5,7 +5,7 @@ use parking_lot::Mutex;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{
|
||||
Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Intermediate,
|
||||
};
|
||||
use tokio::sync::oneshot::Sender;
|
||||
use tokio::time::Instant;
|
||||
@ -17,6 +17,8 @@ pub(crate) struct Entry {
|
||||
pub request: GenerateRequest,
|
||||
/// Response sender to communicate between the Batcher and the batching_task
|
||||
pub response_tx: Sender<Result<InferResponse, ClientError>>,
|
||||
/// Intermediate sender to communicate between the Batcher and the batching_task
|
||||
pub intermediate_tx: Option<tokio::sync::mpsc::UnboundedSender<Result<Option<Intermediate>, ClientError>>>,
|
||||
/// Number of tokens in the input
|
||||
pub input_length: usize,
|
||||
/// Instant when this entry was created
|
||||
@ -39,9 +41,9 @@ pub struct Shared {
|
||||
|
||||
/// Database State
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
pub(crate) struct State {
|
||||
/// Database entries organized in a BTreeMap to be able to iterate over them in order
|
||||
entries: BTreeMap<u64, Entry>,
|
||||
pub(crate) entries: BTreeMap<u64, Entry>,
|
||||
|
||||
/// Id of the next entry
|
||||
next_id: u64,
|
||||
@ -118,6 +120,10 @@ impl Db {
|
||||
state.entries.remove(id)
|
||||
}
|
||||
|
||||
pub(crate) fn get_mutex_guard(&self) -> parking_lot::MutexGuard<State> {
|
||||
self.shared.state.lock()
|
||||
}
|
||||
|
||||
// Get the next batch
|
||||
pub(crate) fn next_batch(
|
||||
&self,
|
||||
|
@ -8,12 +8,17 @@ use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::ShardedClient;
|
||||
use text_generation_client::{ShardedClient, IntermediateEvent};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::signal;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::time::Instant;
|
||||
use tracing::instrument;
|
||||
use tokio::sync::{oneshot, mpsc};
|
||||
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use std::convert::Infallible;
|
||||
use futures::stream::Stream;
|
||||
|
||||
// Server shared state
|
||||
#[derive(Clone)]
|
||||
@ -62,6 +67,80 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct StreamEvent {
|
||||
is_end: bool,
|
||||
event: Option<IntermediateEvent>,
|
||||
generated_text: Option<GeneratedText>,
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
state: Extension<ServerState>,
|
||||
req: Json<GenerateRequest>,
|
||||
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
||||
let (intermediate_tx, mut intermediate_rx) = mpsc::unbounded_channel();
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
|
||||
let (input_length, validated_request) =
|
||||
state.validation.validate(req.0).await.map_err(|err| {
|
||||
tracing::error!("{}", err.to_string());
|
||||
err
|
||||
}).unwrap();
|
||||
|
||||
// Inference
|
||||
state.batcher.infer_stream(input_length, validated_request, intermediate_tx, response_tx);
|
||||
|
||||
let stream = async_stream::stream! {
|
||||
while let Some(item) = intermediate_rx.recv().await {
|
||||
match item {
|
||||
Ok(item) => {
|
||||
match item {
|
||||
Some(item) => {
|
||||
let event_data = IntermediateEvent {
|
||||
token: item.token,
|
||||
token_id: item.token_id,
|
||||
logprob: item.logprob,
|
||||
};
|
||||
let stream_event = StreamEvent {
|
||||
is_end: false,
|
||||
event: Some(event_data),
|
||||
generated_text: None,
|
||||
};
|
||||
yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
|
||||
}
|
||||
None => {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
yield Ok(Event::default().data(err.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
let response = response_rx.await.unwrap();
|
||||
match response {
|
||||
Ok(response) => {
|
||||
let response = GeneratedText {
|
||||
generated_text: response.output_text,
|
||||
details: None,
|
||||
};
|
||||
let stream_event = StreamEvent {
|
||||
is_end: true,
|
||||
event: None,
|
||||
generated_text: Some(response),
|
||||
};
|
||||
yield Ok(Event::default().data(serde_json::to_string(&stream_event).unwrap()));
|
||||
}
|
||||
Err(err) => {
|
||||
yield Ok(Event::default().data(err.to_string()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||
}
|
||||
|
||||
/// Generate method
|
||||
#[instrument(
|
||||
skip(state),
|
||||
@ -197,6 +276,7 @@ pub async fn run(
|
||||
let app = Router::new()
|
||||
.route("/", post(generate))
|
||||
.route("/generate", post(generate))
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
.route("/", get(health))
|
||||
.route("/health", get(health))
|
||||
.layer(Extension(shared_state.clone()));
|
||||
|
@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
|
||||
from typing import Optional, Tuple, List, Type
|
||||
|
||||
from text_generation.models import Model
|
||||
from text_generation.models.types import GeneratedText, Batch
|
||||
from text_generation.models.types import GeneratedText, Batch, Intermediate
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.utils import NextTokenChooser, StoppingCriteria
|
||||
|
||||
@ -314,6 +314,7 @@ class CausalLM(Model):
|
||||
|
||||
# Finished requests
|
||||
generated_texts: List[GeneratedText] = []
|
||||
intermediates: List[Intermediate] = []
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
@ -352,12 +353,23 @@ class CausalLM(Model):
|
||||
next_token_logprob = logprobs[-1, next_token]
|
||||
all_logprobs = torch.cat([all_logprobs, next_token_logprob])
|
||||
|
||||
next_token_sq = next_token.squeeze()
|
||||
decoded_next_token = self.tokenizer.decode(
|
||||
next_token_sq, clean_up_tokenization_spaces=False
|
||||
)
|
||||
next_token_logprob = all_logprobs[-1]
|
||||
intermediate = Intermediate(
|
||||
request_id=request.id,
|
||||
token=decoded_next_token,
|
||||
logprob=next_token_logprob.item(),
|
||||
token_id=next_token_sq.item(),
|
||||
)
|
||||
intermediates.append(intermediate.to_pb())
|
||||
|
||||
# Evaluate stopping criteria
|
||||
stop, reason = stopping_criteria(
|
||||
next_token.squeeze(),
|
||||
self.tokenizer.decode(
|
||||
next_token.squeeze(), clean_up_tokenization_spaces=False
|
||||
),
|
||||
next_token_sq,
|
||||
decoded_next_token,
|
||||
)
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
@ -399,7 +411,7 @@ class CausalLM(Model):
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if not next_batch_keep_indices:
|
||||
return generated_texts, None
|
||||
return generated_texts, None, intermediates
|
||||
|
||||
next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
|
||||
# If we finished at least one generation, we need to evict the indices of the generations that finished
|
||||
@ -459,4 +471,4 @@ class CausalLM(Model):
|
||||
max_sequence_length=next_batch_max_sequence_length,
|
||||
keys_head_dim_last=batch.keys_head_dim_last,
|
||||
)
|
||||
return generated_texts, next_batch
|
||||
return generated_texts, next_batch, intermediates
|
||||
|
@ -50,3 +50,19 @@ class GeneratedText:
|
||||
logprobs=self.logprobs,
|
||||
finish_reason=self.reason,
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class Intermediate:
|
||||
request_id: int
|
||||
token_id: int
|
||||
logprob: float
|
||||
token: str
|
||||
|
||||
def to_pb(self) -> generate_pb2.Intermediate:
|
||||
return generate_pb2.Intermediate(
|
||||
request_id = self.request_id,
|
||||
token_id = self.token_id,
|
||||
logprob = self.logprob,
|
||||
token = self.token,
|
||||
)
|
||||
|
||||
|
@ -32,7 +32,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
request.batch, self.model.tokenizer, self.model.device
|
||||
)
|
||||
|
||||
generated_texts, next_batch = self.model.generate_token(batch)
|
||||
generated_texts, next_batch, intermediates = self.model.generate_token(batch)
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.GenerateResponse(
|
||||
@ -40,6 +40,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
generated_text.to_pb() for generated_text in generated_texts
|
||||
],
|
||||
batch=next_batch.to_pb() if next_batch else None,
|
||||
intermediates=intermediates,
|
||||
)
|
||||
|
||||
async def GenerateWithCache(self, request, context):
|
||||
@ -58,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
else:
|
||||
batch = batches[0]
|
||||
|
||||
generated_texts, next_batch = self.model.generate_token(batch)
|
||||
generated_texts, next_batch, intermediates = self.model.generate_token(batch)
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.GenerateWithCacheResponse(
|
||||
@ -66,6 +67,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
generated_text.to_pb() for generated_text in generated_texts
|
||||
],
|
||||
batch=next_batch.to_pb() if next_batch else None,
|
||||
intermediates=intermediates,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user