/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::GenerateRequest;
use crate::{Entry, Queue, Token};
use nohash_hasher::IntMap;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{
    Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
};
use thiserror::Error;
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info_span, instrument, Instrument, Span};

/// Inference struct
#[derive(Clone)]
pub struct Infer {
    /// Validation
    validation: Validation,
    /// Request queue
    queue: Queue,
    /// Shared state
    shared: Arc<Shared>,
    /// Inference limit
    limit_concurrent_requests: Arc<Semaphore>,
}

/// Infer shared state
struct Shared {
    /// Batching background Tokio task notifier
    batching_task: Notify,
}

impl Infer {
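    /// Create a new `Infer` handle and spawn the background batching task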
    pub(crate) fn new(
        client: ShardedClient,
        validation: Validation,
        max_batch_size: usize,
        max_waiting_tokens: usize,
        max_concurrent_requests: usize,
    ) -> Self {
        // Infer shared state
        let queue = Queue::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        // Spawn batching background task that contains all the inference logic
        tokio::spawn(batching_task(
            client,
            max_batch_size,
            max_waiting_tokens,
            queue.clone(),
            shared.clone(),
        ));

        // Inference limit with a semaphore
        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

        Self {
            validation,
            queue,
            shared,
            limit_concurrent_requests: semaphore,
        }
    }

    /// Add a new request to the queue and return a stream of InferStreamResponse
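    ///
    /// A minimal consumption sketch (assuming an `Infer` handle named `infer` and a
    /// `GenerateRequest` named `request` built elsewhere; both are placeholders):
    ///
    /// ```ignore
    /// use tokio_stream::StreamExt;
    ///
    /// let mut stream = infer.generate_stream(request).await?;
    /// while let Some(response) = stream.next().await {
    ///     match response? {
    ///         InferStreamResponse::Prefill(tokens) => { /* prompt tokens */ }
    ///         InferStreamResponse::Token(token) => { /* one new token */ }
    ///         InferStreamResponse::End { generated_text, .. } => { /* final text */ }
    ///     }
    /// }
    /// ```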
    #[instrument(skip(self))]
    pub(crate) async fn generate_stream(
        &self,
        request: GenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
        // Limit concurrent requests by acquiring a permit from the semaphore
        // This permit will live as long as Entry
        let permit = self
            .clone()
            .limit_concurrent_requests
            .try_acquire_owned()
            .map_err(|err| {
                tracing::error!("{err}");
                err
            })?;

        // Validate request
        let valid_request = self.validation.validate(request).await?;

        // MPSC channel to communicate with the background batching task
        let (response_tx, response_rx) = mpsc::unbounded_channel();

        // Append the request to the queue
        self.queue.append(Entry {
            request: valid_request,
            response_tx,
            span: Span::current(),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            _permit: permit,
        });

        // Notify the background task that we have a new entry in the queue that needs
        // to be batched
        self.shared.batching_task.notify_one();

        // Return stream
        Ok(UnboundedReceiverStream::new(response_rx))
    }

    /// Add a new request to the queue and return an InferResponse
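    ///
    /// A minimal sketch of collecting the full response (same placeholder `infer` and
    /// `request` assumptions as above):
    ///
    /// ```ignore
    /// let response = infer.generate(request).await?;
    /// // `response.tokens` holds the generated tokens and `response.generated_text`
    /// // the decoded output, alongside the queueing and start timestamps.
    /// ```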
    #[instrument(skip(self))]
    pub(crate) async fn generate(
        &self,
        request: GenerateRequest,
    ) -> Result<InferResponse, InferError> {
        // Create stream
        let mut stream = self.generate_stream(request).await?;

        // Return values
        let mut result_prefill = Vec::new();
        let mut result_tokens = Vec::new();
        let mut result_generated_text = None;
        let mut result_start = None;
        let mut result_queued = None;

        // Iterate on stream
        while let Some(response) = stream.next().await {
            match response? {
                // Add prefill tokens
                InferStreamResponse::Prefill(tokens) => {
                    // Create Token objects
                    // We do that here instead of in the Python code as Rust for loops are faster
                    result_prefill = tokens
                        .ids
                        .into_iter()
                        .zip(tokens.logprobs.into_iter())
                        .zip(tokens.texts.into_iter())
                        .map(|((id, logprob), text)| Token { id, text, logprob })
                        .collect();
                }
                // Push last token
                InferStreamResponse::Token(token) => result_tokens.push(token),
                // Final message
                // Set return values
                InferStreamResponse::End {
                    token,
                    generated_text,
                    start,
                    queued,
                } => {
                    result_tokens.push(token);
                    result_generated_text = Some(generated_text);
                    result_start = Some(start);
                    result_queued = Some(queued)
                }
            }
        }

        // Check that we received an `InferStreamResponse::End` message
        if let (Some(generated_text), Some(queued), Some(start)) =
            (result_generated_text, result_queued, result_start)
        {
            Ok(InferResponse {
                prefill: result_prefill,
                tokens: result_tokens,
                generated_text,
                queued,
                start,
            })
        } else {
            let err = InferError::IncompleteGeneration;
            tracing::error!("{err}");
            Err(err)
        }
    }
}

/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
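///
/// `max_batch_size` bounds the number of requests in a single batch;
/// `max_waiting_tokens` is the number of decode iterations after which smaller
/// batches are accepted when extending the running batch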
async fn batching_task(
    mut client: ShardedClient,
    max_batch_size: usize,
    max_waiting_tokens: usize,
    queue: Queue,
    shared: Arc<Shared>,
) {
    // Minimum batch size after which we try to add more requests
    let limit_min_batch_size = (max_batch_size / 2) as u32;

    // Infinite loop
    loop {
        // Wait for a notification from the Infer struct
        shared.batching_task.notified().await;

        // Get the next batch from the queue
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the queue
        while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
            let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
                .instrument(span)
                .await;
            let mut waiting_tokens = 1;

            // We loop until we do not receive any cached batch from the inference server (== until
            // all requests have met their stopping criteria)
            while let Some(batch) = cached_batch {
                // Get current batch info
                let batch_size = batch.size;
                let mut batches = vec![batch];

                // If the current batch is too small, we try to add more requests to it
                if batch_size <= limit_min_batch_size {
                    let min_size = match waiting_tokens {
                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
                        // to add a new batch even though its size might be small
                        _ if waiting_tokens >= max_waiting_tokens => None,
                        // Minimum size criteria
                        _ => Some(limit_min_batch_size as usize),
                    };

                    // Try to get a new batch
                    if let Some((mut new_entries, new_batch, span)) = queue
                        .next_batch(min_size, max_batch_size - batch_size as usize)
                        .await
                    {
                        let new_batch_size = new_batch.size;
                        entries.iter_mut().for_each(|(_, entry)| {
                            // Create a new span to add the info that this entry is waiting
                            // because a new batch is being computed
                            let entry_waiting_span =
                                info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
                            // Add relationship
                            entry_waiting_span.follows_from(&span);
                            // Update entry
                            entry.temp_span = Some(entry_waiting_span);
                        });

                        // Generate one token for this new batch to have the attention past in cache
                        let new_cached_batch =
                            wrap_future(client.prefill(new_batch), &mut new_entries)
                                .instrument(span)
                                .await;
                        // Reset waiting counter
                        waiting_tokens = 1;
                        // Extend current batch with the new batch
                        if let Some(new_cached_batch) = new_cached_batch {
                            entries.extend(new_entries);
                            batches.push(new_cached_batch);
                        }
                    }
                }

                // Create span for this batch to add context to inference calls
                let next_batch_size = entries.len();
                let next_batch_span =
                    info_span!(parent: None, "batch", batch_size = next_batch_size);
                entries.iter_mut().for_each(|(_, entry)| {
                    // Create a new span to link the batch back to this entry
                    let entry_batch_span =
                        info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
                    // Add relationship
                    entry_batch_span.follows_from(&next_batch_span);
                    // Update entry
                    entry.temp_span = Some(entry_batch_span);
                });

                cached_batch = wrap_future(client.decode(batches), &mut entries)
                    .instrument(next_batch_span)
                    .await;
                waiting_tokens += 1;
            }
        }
    }
}

/// Wrap a future inside a match statement to handle errors and send the responses to Infer
#[instrument(skip_all)]
async fn wrap_future(
    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
    entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
    match future.await {
        Ok((generations, next_batch)) => {
            send_generations(generations, entries);
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            send_errors(err, entries);
            None
        }
    }
}

/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
    entries.drain().for_each(|(_, entry)| {
        // Create and enter a span to link this function back to the entry
        let _send_error_span = info_span!(
            parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."),
            "send_error"
        )
        .entered();
        let err = InferError::GenerationError(error.to_string());
        tracing::error!("{err}");

        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry
            .response_tx
            .send(Err(err))
            .unwrap_or(());
    });
}

/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
#[instrument(skip_all)]
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
    generations.into_iter().for_each(|generation| {
        // Get entry
        // We can `expect` here as the request id should always be in the entries
        let entry = entries
            .get(&generation.request_id)
            .expect("ID not found in entries. This is a bug.");

        // Create and enter a span to link this function back to the entry
        let _generation_span = info_span!(
            parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."),
            "send_generation",
            generation = ?generation
        )
        .entered();

        if let Some(prefill_tokens) = generation.prefill_tokens {
            // Send message
            // unwrap_or is valid here as we don't care if the receiver is gone.
            entry
                .response_tx
                .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))
                .unwrap_or(());
        }

        // Create last Token
        let token = Token {
            id: generation.token_id,
            text: generation.token_text,
            logprob: generation.token_logprob,
        };

        if let Some(generated_text) = generation.generated_text {
            // Remove entry as this is the last message
            // We can `expect` here as the request id should always be in the entries
            let entry = entries
                .remove(&generation.request_id)
                .expect("ID not found in entries. This is a bug.");

            // Send message
            // unwrap_or is valid here as we don't care if the receiver is gone.
            entry
                .response_tx
                .send(Ok(InferStreamResponse::End {
                    token,
                    generated_text,
                    queued: entry.queue_time,
                    start: entry.batch_time.unwrap(),
                }))
                .unwrap_or(());
        } else {
            // Send message
            // unwrap_or is valid here as we don't care if the receiver is gone.
            entry
                .response_tx
                .send(Ok(InferStreamResponse::Token(token)))
                .unwrap_or(());
        }
    });
}

#[derive(Debug)]
pub(crate) enum InferStreamResponse {
    // Optional first message
    Prefill(PrefillTokens),
    // Intermediate messages
    Token(Token),
    // Last message
    End {
        token: Token,
        generated_text: GeneratedText,
        start: Instant,
        queued: Instant,
    },
}

#[derive(Debug)]
pub(crate) struct InferResponse {
    pub(crate) prefill: Vec<Token>,
    pub(crate) tokens: Vec<Token>,
    pub(crate) generated_text: GeneratedText,
    pub(crate) queued: Instant,
    pub(crate) start: Instant,
}

#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
    #[error("Model is overloaded")]
    Overloaded(#[from] TryAcquireError),
    #[error("Input validation error: {0}")]
    ValidationError(#[from] ValidationError),
    #[error("Incomplete generation")]
    IncompleteGeneration,
}