mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
wip
This commit is contained in:
parent
54fec93193
commit
432566d931
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -1829,6 +1829,7 @@ dependencies = [
|
||||
name = "text-generation-router"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"axum",
|
||||
"clap 4.0.22",
|
||||
"futures",
|
||||
@ -1841,6 +1842,7 @@ dependencies = [
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
@ -7,10 +7,10 @@ service TextGenerationService {
|
||||
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
|
||||
/// Empties batch cache
|
||||
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
|
||||
/// Generate tokens for a batch
|
||||
rpc Generate (GenerateRequest) returns (GenerateResponse);
|
||||
/// Generate tokens for a list of cached batches
|
||||
rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse);
|
||||
/// Prefill batch and decode first token
|
||||
rpc Prefill (PrefillRequest) returns (PrefillResponse);
|
||||
/// Decode token for a list of prefilled batches
|
||||
rpc Decode (DecodeRequest) returns (DecodeResponse);
|
||||
}
|
||||
|
||||
/// Empty request
|
||||
@ -70,44 +70,60 @@ message Batch {
|
||||
}
|
||||
|
||||
message GeneratedText {
|
||||
/// Request
|
||||
Request request = 1;
|
||||
/// Output
|
||||
string output_text = 2;
|
||||
string text = 1;
|
||||
/// Number of generated tokens
|
||||
uint32 generated_tokens = 3;
|
||||
/// Tokens
|
||||
repeated string tokens = 4;
|
||||
/// Token IDs
|
||||
repeated uint32 token_ids = 5;
|
||||
/// Logprobs
|
||||
repeated float logprobs = 6;
|
||||
uint32 generated_tokens = 2;
|
||||
/// Finish reason
|
||||
string finish_reason = 7;
|
||||
string finish_reason = 3;
|
||||
/// Seed
|
||||
optional uint64 seed = 8;
|
||||
optional uint64 seed = 4;
|
||||
}
|
||||
|
||||
message GenerateRequest {
|
||||
message PrefillTokens {
|
||||
/// Prefill Token IDs
|
||||
repeated uint32 ids = 1;
|
||||
/// Prefill Logprobs
|
||||
repeated float logprobs = 2;
|
||||
/// Prefill tokens
|
||||
repeated string texts = 3;
|
||||
}
|
||||
|
||||
message Generation {
|
||||
/// Request ID
|
||||
uint64 request_id = 1;
|
||||
/// Prefill tokens (optional)
|
||||
PrefillTokens prefill_tokens = 2;
|
||||
/// Token ID
|
||||
uint32 token_id = 3;
|
||||
/// Logprob
|
||||
float token_logprob = 4;
|
||||
/// Text
|
||||
string token_text = 5;
|
||||
/// Complete generated text
|
||||
GeneratedText generated_text = 6;
|
||||
}
|
||||
|
||||
message PrefillRequest {
|
||||
/// Batch
|
||||
Batch batch = 1;
|
||||
}
|
||||
|
||||
message GenerateResponse {
|
||||
/// Finished requests
|
||||
repeated GeneratedText generated_texts = 1;
|
||||
message PrefillResponse {
|
||||
/// Generation
|
||||
repeated Generation generations = 1;
|
||||
/// Next batch (cached)
|
||||
optional Batch batch = 2;
|
||||
}
|
||||
|
||||
message GenerateWithCacheRequest {
|
||||
message DecodeRequest {
|
||||
/// Cached batches
|
||||
repeated Batch batches = 1;
|
||||
}
|
||||
|
||||
message GenerateWithCacheResponse {
|
||||
/// Finished requests
|
||||
repeated GeneratedText generated_texts = 1;
|
||||
message DecodeResponse {
|
||||
/// Decodes
|
||||
repeated Generation generations = 1;
|
||||
/// Next batch (cached)
|
||||
optional Batch batch = 2;
|
||||
}
|
||||
}
|
@ -13,6 +13,7 @@ name = "text-generation-router"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-stream = "0.3.3"
|
||||
axum = { version = "0.5.16", features = ["json", "serde_json"] }
|
||||
text-generation-client = { path = "client" }
|
||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
||||
@ -25,6 +26,7 @@ serde_json = "1.0.85"
|
||||
thiserror = "1.0.37"
|
||||
tokenizers = "0.13.0"
|
||||
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.11"
|
||||
tracing = "0.1.36"
|
||||
tracing-subscriber = { version = "0.3.15", features = ["json"] }
|
||||
|
||||
|
@ -73,15 +73,15 @@ impl Client {
|
||||
/// Returns a list of generated texts of request that met their stopping criteria
|
||||
/// and the next cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
|
||||
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) });
|
||||
let response = self
|
||||
.stub
|
||||
.generate(request)
|
||||
.instrument(info_span!("generate"))
|
||||
.prefill(request)
|
||||
.instrument(info_span!("prefill"))
|
||||
.await?
|
||||
.into_inner();
|
||||
Ok((response.generated_texts, response.batch))
|
||||
Ok((response.generations, response.batch))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batch
|
||||
@ -89,17 +89,17 @@ impl Client {
|
||||
/// Returns a list of generated texts of request that met their stopping criteria
|
||||
/// and the next cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn generate_with_cache(
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<Batch>,
|
||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
let request = tonic::Request::new(GenerateWithCacheRequest { batches });
|
||||
) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||
let request = tonic::Request::new(DecodeRequest { batches });
|
||||
let response = self
|
||||
.stub
|
||||
.generate_with_cache(request)
|
||||
.instrument(info_span!("generate_with_cache"))
|
||||
.decode(request)
|
||||
.instrument(info_span!("decode"))
|
||||
.await?
|
||||
.into_inner();
|
||||
Ok((response.generated_texts, response.batch))
|
||||
Ok((response.generations, response.batch))
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,8 @@ mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v1::{
|
||||
Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
Batch, GeneratedText, Generation, NextTokenChooserParameters, Request,
|
||||
StoppingCriteriaParameters,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
use thiserror::Error;
|
||||
|
@ -1,6 +1,6 @@
|
||||
/// Multi shard Client
|
||||
use crate::Result;
|
||||
use crate::{Batch, Client, GeneratedText};
|
||||
use crate::{Batch, Client, Generation};
|
||||
use futures::future::join_all;
|
||||
use futures::future::select_all;
|
||||
use tonic::transport::Uri;
|
||||
@ -41,11 +41,11 @@ impl ShardedClient {
|
||||
///
|
||||
/// Returns a list of generated texts of request that met their stopping criteria
|
||||
/// and the next cached batch
|
||||
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.generate(batch.clone())))
|
||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||
.collect();
|
||||
// As soon as we receive one response, we can return as all shards will return the same
|
||||
let (result, _, _) = select_all(futures).await;
|
||||
@ -56,14 +56,14 @@ impl ShardedClient {
|
||||
///
|
||||
/// Returns a list of generated texts of request that met their stopping criteria
|
||||
/// and the next cached batch
|
||||
pub async fn generate_with_cache(
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<Batch>,
|
||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
) -> Result<(Vec<Generation>, Option<Batch>)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.generate_with_cache(batches.clone())))
|
||||
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||
.collect();
|
||||
// As soon as we receive one response, we can return as all shards will return the same
|
||||
let (result, _, _) = select_all(futures).await;
|
||||
|
@ -1,15 +1,17 @@
|
||||
/// Batching and inference logic
|
||||
use crate::{Db, Entry};
|
||||
use crate::{Db, Entry, Token};
|
||||
use crate::{ErrorResponse, GenerateRequest};
|
||||
use axum::http::StatusCode;
|
||||
use axum::Json;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
|
||||
use text_generation_client::{Batch, ClientError, GeneratedText, Generation, ShardedClient};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{oneshot, Notify};
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::instrument;
|
||||
|
||||
/// Batcher
|
||||
@ -51,14 +53,14 @@ impl Batcher {
|
||||
Self { db, shared }
|
||||
}
|
||||
|
||||
/// Add a new request to the database and return a future that will generate the text
|
||||
pub(crate) async fn infer(
|
||||
/// Add a new request to the database and return a stream of tokens
|
||||
pub(crate) fn infer_stream(
|
||||
&self,
|
||||
input_length: usize,
|
||||
request: GenerateRequest,
|
||||
) -> Result<InferResponse, InferError> {
|
||||
// One shot channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
) -> UnboundedReceiverStream<Result<InferStreamResponse, InferError>> {
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
|
||||
// Try to append the request to the database
|
||||
self.db.append(Entry {
|
||||
@ -73,12 +75,45 @@ impl Batcher {
|
||||
// to be batched
|
||||
self.shared.batching_task.notify_one();
|
||||
|
||||
// Await on the response from the background task
|
||||
// We can safely unwrap as the background task will never drop the sender
|
||||
response_rx
|
||||
.await
|
||||
.unwrap()
|
||||
.map_err(|err| InferError::GenerationError(err.to_string()))
|
||||
// Return stream
|
||||
UnboundedReceiverStream::new(response_rx)
|
||||
}
|
||||
|
||||
pub(crate) async fn infer(
|
||||
&self,
|
||||
input_length: usize,
|
||||
request: GenerateRequest,
|
||||
) -> Result<InferResponse, InferError> {
|
||||
let mut stream = self.infer_stream(input_length, request);
|
||||
|
||||
let mut result_tokens = Vec::new();
|
||||
let mut result_generated_text = None;
|
||||
let mut result_start = None;
|
||||
let mut result_queued = None;
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
match response? {
|
||||
InferStreamResponse::Prefill(prefill_tokens) => {
|
||||
result_tokens.extend(prefill_tokens)
|
||||
}
|
||||
InferStreamResponse::Token(token) => result_tokens.push(token),
|
||||
InferStreamResponse::End {
|
||||
generated_text,
|
||||
start,
|
||||
queued,
|
||||
} => {
|
||||
result_generated_text = Some(generated_text);
|
||||
result_start = Some(start);
|
||||
result_queued = Some(queued)
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(InferResponse {
|
||||
tokens: result_tokens,
|
||||
generated_text: result_generated_text.unwrap(),
|
||||
queued: result_queued.unwrap(),
|
||||
start: result_start.unwrap(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,7 +141,7 @@ async fn batching_task(
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the DB
|
||||
while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
|
||||
let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await;
|
||||
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
@ -132,7 +167,7 @@ async fn batching_task(
|
||||
{
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch =
|
||||
wrap_future(client.generate(new_batch), &mut new_entries).await;
|
||||
wrap_future(client.prefill(new_batch), &mut new_entries).await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
@ -143,7 +178,7 @@ async fn batching_task(
|
||||
}
|
||||
}
|
||||
|
||||
cached_batch = wrap_future(client.generate_with_cache(batches), &mut entries).await;
|
||||
cached_batch = wrap_future(client.decode(batches), &mut entries).await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
}
|
||||
@ -152,12 +187,12 @@ async fn batching_task(
|
||||
|
||||
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
|
||||
async fn wrap_future(
|
||||
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
|
||||
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<Batch> {
|
||||
match future.await {
|
||||
Ok((generated_texts, next_batch)) => {
|
||||
send_generated(generated_texts, entries);
|
||||
Ok((generations, next_batch)) => {
|
||||
send_generated(generations, entries);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
@ -172,47 +207,79 @@ async fn wrap_future(
|
||||
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry.response_tx.send(Err(error.clone())).unwrap_or(());
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(InferError::GenerationError(error.to_string())))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
/// Send `generated_text` to the Batcher for all `finished`
|
||||
fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
|
||||
finished.into_iter().for_each(|output| {
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
fn send_generated(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
generations.into_iter().for_each(|generation| {
|
||||
let entry = entries
|
||||
.remove(&output.request.unwrap().id)
|
||||
.get(&generation.request_id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
let response = InferResponse {
|
||||
output_text: output.output_text,
|
||||
generated_tokens: output.generated_tokens,
|
||||
token_ids: output.token_ids,
|
||||
tokens: output.tokens,
|
||||
logprobs: output.logprobs,
|
||||
finish_reason: output.finish_reason,
|
||||
seed: output.seed,
|
||||
queued: entry.time,
|
||||
start: entry.batch_time.unwrap(), // unwrap is always valid
|
||||
end: Instant::now(),
|
||||
};
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry.response_tx.send(Ok(response)).unwrap_or(());
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
let tokens = prefill_tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(prefill_tokens.logprobs.into_iter())
|
||||
.zip(prefill_tokens.texts.into_iter())
|
||||
.map(|((id, logprob), text)| Token(id, text, logprob))
|
||||
.collect();
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(tokens)))
|
||||
.unwrap_or(());
|
||||
}
|
||||
|
||||
let token = Token(
|
||||
generation.token_id,
|
||||
generation.token_text,
|
||||
generation.token_logprob,
|
||||
);
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Token(token)))
|
||||
.unwrap_or(());
|
||||
|
||||
if let Some(generated_text) = generation.generated_text {
|
||||
let entry = entries
|
||||
.remove(&generation.request_id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::End {
|
||||
generated_text,
|
||||
queued: entry.time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))
|
||||
.unwrap_or(());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum InferStreamResponse {
|
||||
Prefill(Vec<Token>),
|
||||
Token(Token),
|
||||
End {
|
||||
generated_text: GeneratedText,
|
||||
start: Instant,
|
||||
queued: Instant,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct InferResponse {
|
||||
pub(crate) output_text: String,
|
||||
pub(crate) generated_tokens: u32,
|
||||
pub(crate) token_ids: Vec<u32>,
|
||||
pub(crate) tokens: Vec<String>,
|
||||
pub(crate) logprobs: Vec<f32>,
|
||||
pub(crate) finish_reason: String,
|
||||
pub(crate) seed: Option<u64>,
|
||||
pub(crate) tokens: Vec<Token>,
|
||||
pub(crate) generated_text: GeneratedText,
|
||||
pub(crate) seed: Option<u64>
|
||||
pub(crate) queued: Instant,
|
||||
pub(crate) start: Instant,
|
||||
pub(crate) end: Instant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
@ -1,14 +1,15 @@
|
||||
use crate::batcher::InferError;
|
||||
/// This code is massively inspired by Tokio mini-redis
|
||||
use crate::InferResponse;
|
||||
use crate::batcher::InferStreamResponse;
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{
|
||||
Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
Batch, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use tokio::sync::oneshot::Sender;
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
use tokio::time::Instant;
|
||||
|
||||
/// Database entry
|
||||
@ -17,7 +18,7 @@ pub(crate) struct Entry {
|
||||
/// Request
|
||||
pub request: GenerateRequest,
|
||||
/// Response sender to communicate between the Batcher and the batching_task
|
||||
pub response_tx: Sender<Result<InferResponse, ClientError>>,
|
||||
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
|
||||
/// Number of tokens in the input
|
||||
pub input_length: usize,
|
||||
/// Instant when this entry was created
|
||||
|
@ -4,7 +4,7 @@ mod db;
|
||||
pub mod server;
|
||||
mod validation;
|
||||
|
||||
use batcher::{Batcher, InferResponse};
|
||||
use batcher::Batcher;
|
||||
use db::{Db, Entry};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use validation::Validation;
|
||||
@ -69,12 +69,15 @@ pub(crate) struct GenerateRequest {
|
||||
pub parameters: GenerateParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Token(u32, String, f32);
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(crate) struct Details {
|
||||
pub finish_reason: String,
|
||||
pub generated_tokens: u32,
|
||||
pub seed: Option<u64>,
|
||||
pub tokens: Vec<(u32, String, f32)>,
|
||||
pub tokens: Vec<Token>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
@ -1,11 +1,14 @@
|
||||
use crate::batcher::InferStreamResponse;
|
||||
use crate::{
|
||||
Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use axum::{BoxError, Json, Router};
|
||||
use futures::Stream;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::ShardedClient;
|
||||
@ -13,6 +16,7 @@ use tokenizers::Tokenizer;
|
||||
use tokio::signal;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::instrument;
|
||||
|
||||
// Server shared state
|
||||
@ -111,21 +115,12 @@ async fn generate(
|
||||
|
||||
// Token details
|
||||
let details = match details {
|
||||
true => {
|
||||
let tokens = response
|
||||
.token_ids
|
||||
.into_iter()
|
||||
.zip(response.tokens.into_iter())
|
||||
.zip(response.logprobs.into_iter())
|
||||
.map(|((id, text), logprob)| (id, text, logprob))
|
||||
.collect();
|
||||
Some(Details {
|
||||
seed: response.seed,
|
||||
finish_reason: response.finish_reason,
|
||||
generated_tokens: response.generated_tokens,
|
||||
tokens,
|
||||
})
|
||||
}
|
||||
true => Some(Details {
|
||||
finish_reason: response.generated_text.finish_reason,
|
||||
generated_tokens: response.generated_text.generated_tokens,
|
||||
tokens: response.tokens,
|
||||
seed: response.seed,
|
||||
}),
|
||||
false => None,
|
||||
};
|
||||
|
||||
@ -133,8 +128,8 @@ async fn generate(
|
||||
let total_time = start_time.elapsed();
|
||||
let validation_time = response.queued - start_time;
|
||||
let queue_time = response.start - response.queued;
|
||||
let inference_time = response.end - response.start;
|
||||
let time_per_token = inference_time / response.generated_tokens;
|
||||
let inference_time = Instant::now() - response.start;
|
||||
let time_per_token = inference_time / response.generated_text.generated_tokens;
|
||||
|
||||
// Headers
|
||||
let mut headers = HeaderMap::new();
|
||||
@ -166,16 +161,57 @@ async fn generate(
|
||||
tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
|
||||
tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
|
||||
tracing::Span::current().record("seed", format!("{:?}", response.seed));
|
||||
tracing::info!("Output: {}", response.output_text);
|
||||
tracing::info!("Output: {}", response.generated_text.text);
|
||||
|
||||
// Send response
|
||||
let response = vec![GeneratedText {
|
||||
generated_text: response.output_text,
|
||||
generated_text: response.generated_text.text,
|
||||
details,
|
||||
}];
|
||||
Ok((headers, Json(response)))
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
state: Extension<ServerState>,
|
||||
req: Json<GenerateRequest>,
|
||||
) -> Sse<impl Stream<Item = Result<Event, BoxError>>> {
|
||||
let stream = async_stream::stream! {
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
let _permit = state.limit_concurrent_requests.try_acquire().map_err(| err | {
|
||||
tracing::error!("Model is overloaded");
|
||||
err
|
||||
})?;
|
||||
|
||||
// Validate request
|
||||
let (input_length, validated_request) =
|
||||
state.validation.validate(req.0).await.map_err(|err| {
|
||||
tracing::error!("{}", err);
|
||||
err
|
||||
})?;
|
||||
|
||||
// Inference
|
||||
let mut response_stream = state
|
||||
.batcher
|
||||
.infer_stream(input_length, validated_request);
|
||||
|
||||
while let Some(response) = response_stream.next().await {
|
||||
match response {
|
||||
Ok(response) => {
|
||||
if let InferStreamResponse::Token(token) = response {
|
||||
yield Ok(Event::default().json_data(token).unwrap());
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("{}", err.to_string());
|
||||
yield Ok(Event::default().data(err.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||
}
|
||||
|
||||
/// Serving method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
@ -201,6 +237,7 @@ pub async fn run(
|
||||
let app = Router::new()
|
||||
.route("/", post(generate))
|
||||
.route("/generate", post(generate))
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
.route("/", get(health))
|
||||
.route("/health", get(health))
|
||||
.layer(Extension(shared_state.clone()));
|
||||
|
@ -1,6 +1,6 @@
|
||||
gen-server:
|
||||
# Compile protos
|
||||
pip install grpcio-tools==1.49.1 --no-cache-dir
|
||||
#pip install grpcio-tools==1.49.1 --no-cache-dir
|
||||
mkdir text_generation/pb || true
|
||||
python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
|
||||
find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||
|
@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
|
||||
from typing import Optional, Tuple, List, Type
|
||||
|
||||
from text_generation.models import Model
|
||||
from text_generation.models.types import GeneratedText, Batch
|
||||
from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
|
||||
@ -23,7 +23,6 @@ class CausalLMBatch(Batch):
|
||||
|
||||
# All tokens
|
||||
all_input_ids: List[torch.Tensor]
|
||||
all_logprobs: List[Optional[torch.Tensor]]
|
||||
|
||||
# Lengths of all generations present in the batch
|
||||
input_lengths: List[int]
|
||||
@ -48,16 +47,15 @@ class CausalLMBatch(Batch):
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
device: torch.device,
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
device: torch.device,
|
||||
) -> "CausalLMBatch":
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
input_lengths = []
|
||||
all_logprobs = []
|
||||
|
||||
# Parse batch
|
||||
for r in pb.requests:
|
||||
@ -67,7 +65,6 @@ class CausalLMBatch(Batch):
|
||||
stopping_criterias.append(
|
||||
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
|
||||
)
|
||||
all_logprobs.append(None)
|
||||
|
||||
pad_to_multiple_of = 8 if device.type == "cuda" else None
|
||||
tokenized_inputs = tokenizer(
|
||||
@ -89,7 +86,6 @@ class CausalLMBatch(Batch):
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
all_input_ids=all_input_ids,
|
||||
all_logprobs=all_logprobs,
|
||||
input_lengths=input_lengths,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
@ -107,7 +103,6 @@ class CausalLMBatch(Batch):
|
||||
requests = []
|
||||
input_lengths = []
|
||||
all_input_ids = []
|
||||
all_logprobs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
||||
@ -124,7 +119,6 @@ class CausalLMBatch(Batch):
|
||||
requests.extend(batch.requests)
|
||||
input_lengths.extend(batch.input_lengths)
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
all_logprobs.extend(batch.all_logprobs)
|
||||
next_token_choosers.extend(batch.next_token_choosers)
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
|
||||
@ -151,8 +145,8 @@ class CausalLMBatch(Batch):
|
||||
|
||||
# We need to slice the attention mask to remove padding from previous steps
|
||||
attention_mask[
|
||||
start_index:end_index, -batch.max_sequence_length :
|
||||
] = batch.attention_mask[:, -batch.max_sequence_length :]
|
||||
start_index:end_index, -batch.max_sequence_length:
|
||||
] = batch.attention_mask[:, -batch.max_sequence_length:]
|
||||
|
||||
# Create empty tensor
|
||||
# position_ids is always of shape [batch_size, 1]
|
||||
@ -198,22 +192,22 @@ class CausalLMBatch(Batch):
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
if batch.keys_head_dim_last:
|
||||
past_key_values[j][0][
|
||||
start_index:end_index,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1) :,
|
||||
:,
|
||||
] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||
start_index:end_index,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1):,
|
||||
:,
|
||||
] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
|
||||
else:
|
||||
past_key_values[j][0][
|
||||
start_index:end_index,
|
||||
:,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1) :,
|
||||
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
|
||||
start_index:end_index,
|
||||
:,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1):,
|
||||
] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
|
||||
|
||||
past_key_values[j][1][
|
||||
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
|
||||
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||
start_index:end_index, :, -(batch.max_sequence_length - 1):, :
|
||||
] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
|
||||
|
||||
start_index += batch.size
|
||||
|
||||
@ -225,7 +219,6 @@ class CausalLMBatch(Batch):
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
all_input_ids=all_input_ids,
|
||||
all_logprobs=all_logprobs,
|
||||
input_lengths=input_lengths,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
@ -234,6 +227,9 @@ class CausalLMBatch(Batch):
|
||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
|
||||
|
||||
class CausalLM(Model):
|
||||
def __init__(self, model_name: str, quantize=False):
|
||||
@ -275,7 +271,7 @@ class CausalLM(Model):
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Model Forward
|
||||
outputs = self.model.forward(
|
||||
@ -288,8 +284,8 @@ class CausalLM(Model):
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
def generate_token(
|
||||
self, batch: CausalLMBatch
|
||||
) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
|
||||
self, batch: CausalLMBatch
|
||||
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
|
||||
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
||||
context_manager = (
|
||||
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
|
||||
@ -309,14 +305,13 @@ class CausalLM(Model):
|
||||
next_batch_input_lengths = []
|
||||
next_batch_input_ids = []
|
||||
next_batch_all_input_ids = []
|
||||
next_batch_all_logprobs = []
|
||||
|
||||
# Metadata
|
||||
next_batch_size = 0
|
||||
next_batch_max_sequence_length = 0
|
||||
|
||||
# Finished requests
|
||||
generated_texts: List[GeneratedText] = []
|
||||
# Results
|
||||
results = []
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
@ -326,55 +321,42 @@ class CausalLM(Model):
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
batch.all_logprobs,
|
||||
)
|
||||
|
||||
# For each member of the batch
|
||||
for i, (
|
||||
request,
|
||||
input_length,
|
||||
logits,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
all_logprobs,
|
||||
request,
|
||||
input_length,
|
||||
logits,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
tokens, logprobs = next_token_chooser(all_input_ids, logits)
|
||||
next_token = tokens[-1].view(1, 1)
|
||||
next_token_id = tokens[-1].view(1, 1)
|
||||
|
||||
# Append next token to all tokens
|
||||
all_input_ids = torch.cat([all_input_ids, next_token])
|
||||
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
||||
new_input_length = input_length + 1
|
||||
|
||||
if all_logprobs is None:
|
||||
# logprobs of all prompt tokens (except the first one) and the generated token
|
||||
all_logprobs = logprobs.gather(1, all_input_ids[1:])
|
||||
else:
|
||||
# logprob of the generated token
|
||||
next_token_logprob = logprobs[-1, next_token]
|
||||
all_logprobs = torch.cat([all_logprobs, next_token_logprob])
|
||||
# Generated token
|
||||
next_token_logprob = logprobs[-1, next_token_id]
|
||||
next_token_id_squeezed = next_token_id.squeeze()
|
||||
next_token_text = self.decode(next_token_id.squeeze())
|
||||
|
||||
# Evaluate stopping criteria
|
||||
stop, reason = stopping_criteria(
|
||||
next_token.squeeze(),
|
||||
self.tokenizer.decode(
|
||||
next_token.squeeze(), clean_up_tokenization_spaces=False
|
||||
),
|
||||
next_token_id_squeezed,
|
||||
next_token_text,
|
||||
)
|
||||
|
||||
if stop:
|
||||
# Decode generated tokens
|
||||
generated_text = self.decode(
|
||||
all_input_ids[-stopping_criteria.current_tokens :, 0]
|
||||
all_input_ids[-stopping_criteria.current_tokens:, 0]
|
||||
)
|
||||
output_text = request.inputs + generated_text
|
||||
# Slice with input_length to remove padding
|
||||
token_ids = all_input_ids[-new_input_length:]
|
||||
tokens = self.tokenizer.batch_decode(token_ids)
|
||||
# Add NaN for the first prompt token
|
||||
logprobs = [float("nan")] + all_logprobs[-input_length:].squeeze(
|
||||
1
|
||||
).tolist()
|
||||
|
||||
# Get seed
|
||||
if isinstance(next_token_chooser.choice, Sampling):
|
||||
@ -382,39 +364,48 @@ class CausalLM(Model):
|
||||
else:
|
||||
seed = None
|
||||
|
||||
# Add to the list of finished generations with the original request
|
||||
generated_texts.append(
|
||||
GeneratedText(
|
||||
request=request,
|
||||
output_text=output_text,
|
||||
generated_tokens=stopping_criteria.current_tokens,
|
||||
tokens=tokens,
|
||||
token_ids=token_ids.squeeze(1).tolist(),
|
||||
logprobs=logprobs,
|
||||
reason=reason,
|
||||
seed=seed,
|
||||
)
|
||||
)
|
||||
# add to the next batch
|
||||
generated_text = GeneratedText(output_text, stopping_criteria.current_tokens, reason, seed)
|
||||
else:
|
||||
# Keep request in the batch
|
||||
generated_text = None
|
||||
next_batch_keep_indices.append(i)
|
||||
next_batch_input_ids.append(next_token)
|
||||
next_batch_input_ids.append(next_token_id)
|
||||
next_batch_all_input_ids.append(all_input_ids)
|
||||
next_batch_all_logprobs.append(all_logprobs)
|
||||
next_batch_size += 1
|
||||
next_batch_input_lengths.append(new_input_length)
|
||||
next_batch_max_sequence_length = max(
|
||||
next_batch_max_sequence_length, new_input_length
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 0:
|
||||
# Remove generated token to only have prefill and add nan for first prompt token
|
||||
prefill_logprobs = [float("nan")] + logprobs[-new_input_length:-1].gather(1, all_input_ids[
|
||||
-new_input_length:-1]).squeeze(
|
||||
1).tolist()
|
||||
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False)
|
||||
prefill_tokens = PrefillTokens(prefill_token_ids,
|
||||
prefill_logprobs,
|
||||
prefill_texts)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
result = Generation(request.id, prefill_tokens, next_token_id_squeezed, next_token_logprob, next_token_text,
|
||||
generated_text)
|
||||
|
||||
results.append(result)
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if not next_batch_keep_indices:
|
||||
return generated_texts, None
|
||||
return results, None
|
||||
|
||||
next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
|
||||
# If we finished at least one generation, we need to evict the indices of the generations that finished
|
||||
# from the values of the next batch
|
||||
if generated_texts:
|
||||
if len(next_batch_keep_indices) != len(batch):
|
||||
# Apply indices to attention mask, past key values and other items that need to be cached
|
||||
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
|
||||
next_batch_position_ids = batch.position_ids[next_batch_keep_indices]
|
||||
@ -461,7 +452,6 @@ class CausalLM(Model):
|
||||
position_ids=next_batch_position_ids,
|
||||
past_key_values=next_batch_past_key_values,
|
||||
all_input_ids=next_batch_all_input_ids,
|
||||
all_logprobs=next_batch_all_logprobs,
|
||||
input_lengths=next_batch_input_lengths,
|
||||
next_token_choosers=next_batch_next_token_choosers,
|
||||
stopping_criterias=next_batch_stopping_criterias,
|
||||
@ -469,4 +459,4 @@ class CausalLM(Model):
|
||||
max_sequence_length=next_batch_max_sequence_length,
|
||||
keys_head_dim_last=batch.keys_head_dim_last,
|
||||
)
|
||||
return generated_texts, next_batch
|
||||
return results, next_batch
|
||||
|
@ -17,10 +17,10 @@ class Batch(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
device: torch.device,
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
device: torch.device,
|
||||
) -> "Batch":
|
||||
raise NotImplementedError
|
||||
|
||||
@ -32,23 +32,49 @@ class Batch(ABC):
|
||||
|
||||
@dataclass
|
||||
class GeneratedText:
|
||||
request: generate_pb2.Request
|
||||
output_text: str
|
||||
text: str
|
||||
generated_tokens: int
|
||||
tokens: List[str]
|
||||
token_ids: List[int]
|
||||
logprobs: List[float]
|
||||
reason: str
|
||||
finish_reason: str
|
||||
seed: Optional[int]
|
||||
|
||||
def to_pb(self) -> generate_pb2.GeneratedText:
|
||||
return generate_pb2.GeneratedText(
|
||||
request=self.request,
|
||||
output_text=self.output_text,
|
||||
text=self.text,
|
||||
generated_tokens=self.generated_tokens,
|
||||
tokens=self.tokens,
|
||||
token_ids=self.token_ids,
|
||||
logprobs=self.logprobs,
|
||||
finish_reason=self.reason,
|
||||
seed=self.seed,
|
||||
finish_reason=self.finish_reason
|
||||
seed=self.seed
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrefillTokens:
|
||||
token_ids: List[int]
|
||||
logprobs: List[float]
|
||||
texts: List[str]
|
||||
|
||||
def to_pb(self) -> generate_pb2.PrefillTokens:
|
||||
return generate_pb2.PrefillTokens(
|
||||
ids=self.token_ids,
|
||||
logprobs=self.logprobs,
|
||||
texts=self.texts
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Generation:
|
||||
request_id: int
|
||||
prefill_tokens: Optional[PrefillTokens]
|
||||
token_id: int
|
||||
token_logprob: float
|
||||
token_text: str
|
||||
generated_text: Optional[GeneratedText]
|
||||
|
||||
def to_pb(self) -> generate_pb2.Generation:
|
||||
return generate_pb2.Generation(
|
||||
request_id=self.request_id,
|
||||
prefill_tokens=self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None,
|
||||
token_id=self.token_id,
|
||||
token_logprob=self.token_logprob,
|
||||
token_text=self.token_text,
|
||||
generated_text=self.generated_text.to_pb() if self.generated_text is not None else None,
|
||||
)
|
||||
|
@ -27,22 +27,22 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
self.cache.clear()
|
||||
return generate_pb2.ClearCacheResponse()
|
||||
|
||||
async def Generate(self, request, context):
|
||||
async def Prefill(self, request, context):
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch, self.model.tokenizer, self.model.device
|
||||
)
|
||||
|
||||
generated_texts, next_batch = self.model.generate_token(batch)
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.GenerateResponse(
|
||||
generated_texts=[
|
||||
generated_text.to_pb() for generated_text in generated_texts
|
||||
return generate_pb2.PrefillResponse(
|
||||
generations=[
|
||||
generation.to_pb() for generation in generations
|
||||
],
|
||||
batch=next_batch.to_pb() if next_batch else None,
|
||||
)
|
||||
|
||||
async def GenerateWithCache(self, request, context):
|
||||
async def Decode(self, request, context):
|
||||
if len(request.batches) == 0:
|
||||
raise ValueError("Must provide at least one batch")
|
||||
|
||||
@ -58,12 +58,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
else:
|
||||
batch = batches[0]
|
||||
|
||||
generated_texts, next_batch = self.model.generate_token(batch)
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.GenerateWithCacheResponse(
|
||||
generated_texts=[
|
||||
generated_text.to_pb() for generated_text in generated_texts
|
||||
return generate_pb2.DecodeResponse(
|
||||
generations=[
|
||||
generation.to_pb() for generation in generations
|
||||
],
|
||||
batch=next_batch.to_pb() if next_batch else None,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user