2023-02-13 12:02:45 +00:00
/// Text Generation Inference webserver entrypoint
2023-02-17 17:22:00 +00:00
use axum ::http ::HeaderValue ;
2022-10-18 13:19:03 +00:00
use clap ::Parser ;
2023-02-13 12:02:45 +00:00
use opentelemetry ::sdk ::propagation ::TraceContextPropagator ;
use opentelemetry ::sdk ::trace ;
use opentelemetry ::sdk ::trace ::Sampler ;
use opentelemetry ::sdk ::Resource ;
use opentelemetry ::{ global , KeyValue } ;
use opentelemetry_otlp ::WithExportConfig ;
2022-10-17 16:27:33 +00:00
use std ::net ::{ IpAddr , Ipv4Addr , SocketAddr } ;
2023-03-06 13:39:36 +00:00
use std ::path ::Path ;
2023-05-09 11:19:31 +00:00
use std ::time ::Duration ;
2023-07-10 12:47:15 +00:00
use text_generation_client ::{ ClientError , ShardedClient } ;
2023-04-21 13:36:29 +00:00
use text_generation_router ::{ server , HubModelInfo } ;
2023-07-10 12:47:15 +00:00
use thiserror ::Error ;
2023-04-18 14:16:06 +00:00
use tokenizers ::{ FromPretrainedParameters , Tokenizer } ;
2023-02-17 17:22:00 +00:00
use tower_http ::cors ::AllowOrigin ;
2023-02-13 12:02:45 +00:00
use tracing_subscriber ::layer ::SubscriberExt ;
use tracing_subscriber ::util ::SubscriberInitExt ;
use tracing_subscriber ::{ EnvFilter , Layer } ;
2022-10-17 16:27:33 +00:00
/// App Configuration
// NOTE: field comments below use `//` on purpose — clap's derive turns `///`
// doc comments into `--help` text, which would change CLI output.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    // Hard cap on simultaneously admitted requests; enforcement happens in
    // `server::run` — not visible in this file.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    // Per-request cap on the `best_of` sampling parameter — presumably
    // enforced by the validation layer; not visible here.
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    // Per-request cap on the number of stop sequences.
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    // Per-request cap on `top_n_tokens` (token-level logprob reporting).
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    // Maximum prompt length in tokens; `main` rejects configs where this is
    // not strictly less than `max_total_tokens`.
    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,
    // Maximum prompt + generated tokens per request.
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    // Batching heuristic knob passed through to `server::run`; semantics live
    // in the scheduler — TODO confirm against the queue implementation.
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    // Token budget for a prefill batch; `main` requires it to be
    // >= `max_input_length` and <= `max_batch_total_tokens` (when set).
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    // Optional overall batch token budget. When the model shard cannot infer
    // one, `main` falls back to max(16000, max_total_tokens,
    // max_batch_prefill_tokens); for models that do infer it (flash
    // attention), this flag is deprecated and ignored with a warning.
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    // Scheduler patience knob forwarded to `server::run`.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    // Bind address; `main` falls back to 0.0.0.0 if it does not parse as an IP.
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    // TCP port to listen on.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    // Unix domain socket of the master model shard
    // (consumed by `ShardedClient::connect_uds`).
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    // Hub model id, or a local directory path (detected via `Path::is_dir`).
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    // Hub revision (branch/tag/commit); defaults to "main" downstream with a
    // warning when unset.
    #[clap(long, env)]
    revision: Option<String>,
    // Number of payload-validation workers; `main` rejects 0.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    // Emit logs as flattened JSON instead of text (see `init_logging`).
    #[clap(long, env)]
    json_output: bool,
    // Optional OpenTelemetry OTLP collector URL (see `init_logging`).
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    // CORS allowed origins; each entry must parse as an HTTP `HeaderValue`
    // (parse failure panics in `main` via `unwrap`).
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    // ngrok tunnel settings, forwarded verbatim to `server::run`.
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
}
2022-10-08 10:30:12 +00:00
2023-07-10 12:47:15 +00:00
/// Router entrypoint: parses CLI args, validates the token-budget invariants,
/// builds the tokenizer, then starts a Tokio runtime that connects to the
/// model shards, warms them up, and runs the Axum webserver until it exits.
fn main() -> Result<(), RouterError> {
    // Get args
    let args = Args::parse();

    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_length,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
    } = args;

    // Validate args: the prompt must leave room for at least one generated token.
    if max_input_length >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_length` must be < `max_total_tokens`".to_string(),
        ));
    }
    // A single maximal prompt must fit in one prefill batch.
    if max_input_length as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    // When the user pins a total batch budget, both the prefill budget and a
    // single maximal request must fit inside it.
    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }

    // CORS allowed origins
    // map to go inside the option and then map to parse from String to HeaderValue
    // Finally, convert to AllowOrigin
    // NOTE: an unparsable origin string panics here (`unwrap`), before the
    // server starts — a deliberate fail-fast on bad configuration.
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });

    // Parse Huggingface hub token
    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
        // Load local tokenizer; failures are swallowed into `None` and only
        // warned about after logging is initialized below.
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision: revision.clone().unwrap_or("main".to_string()),
            auth_token: authorization_token.clone(),
            ..Default::default()
        };
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
    };

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()?
        .block_on(async {
            // Logging must be initialized inside the runtime because the OTLP
            // exporter installs itself on the Tokio runtime.
            init_logging(otlp_endpoint, json_output);

            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
                );
                tracing::warn!("Rust input length validation and truncation is disabled");
            }

            // Get Model info: local models get a stub; hub models are queried,
            // falling back to the same stub on any failure.
            let model_info = match local_model {
                true => HubModelInfo {
                    model_id: tokenizer_name.clone(),
                    sha: None,
                    pipeline_tag: None,
                },
                false => get_model_info(&tokenizer_name, revision, authorization_token)
                    .await
                    .unwrap_or_else(|| {
                        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                        HubModelInfo {
                            model_id: tokenizer_name.to_string(),
                            sha: None,
                            pipeline_tag: None,
                        }
                    }),
            };

            // if pipeline-tag == text-generation we default to return_full_text = true
            let compat_return_full_text = match &model_info.pipeline_tag {
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
            };

            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .map_err(RouterError::Connection)?;
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
                .map_err(RouterError::Cache)?;
            // Get info from the shard
            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;

            // Warmup model: the shards may report the max batch total tokens
            // they can actually support.
            tracing::info!("Warming up model");
            let max_supported_batch_total_tokens = match sharded_client
                .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
                .await
                .map_err(RouterError::Warmup)?
            {
                // Older models do not support automatic max-batch-total-tokens
                None => {
                    // Default to 16000, but never below what a single request
                    // or a prefill batch may already need.
                    let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
                        16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
                    );
                    tracing::warn!("Model does not support automatic max batch total tokens");
                    max_batch_total_tokens
                }
                // Flash attention models return their max supported total tokens
                Some(max_supported_batch_total_tokens) => {
                    // Warn if user added his own max-batch-total-tokens as we will ignore it
                    if max_batch_total_tokens.is_some() {
                        tracing::warn!(
                            "`--max-batch-total-tokens` is deprecated for Flash \
                        Attention models."
                        );
                        tracing::warn!(
                            "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
                        );
                    }
                    // The inferred budget must still fit one maximal request.
                    if max_total_tokens as u32 > max_supported_batch_total_tokens {
                        return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
                    }

                    max_supported_batch_total_tokens
                }
            };
            tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
            tracing::info!("Connected");

            // Bind address: fall back to 0.0.0.0 when the hostname is not a
            // plain IP literal (DNS names are not resolved here).
            let addr = match hostname.parse() {
                Ok(ip) => SocketAddr::new(ip, port),
                Err(_) => {
                    tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
                }
            };

            // Run server (blocks until the webserver shuts down or errors)
            server::run(
                model_info,
                shard_info,
                compat_return_full_text,
                max_concurrent_requests,
                max_best_of,
                max_stop_sequences,
                max_top_n_tokens,
                max_input_length,
                max_total_tokens,
                waiting_served_ratio,
                max_batch_prefill_tokens,
                max_supported_batch_total_tokens,
                max_waiting_tokens,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
                cors_allow_origin,
                ngrok,
                ngrok_authtoken,
                ngrok_edge,
            )
            .await?;
            Ok(())
        })
}
2023-02-13 12:02:45 +00:00
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
fn init_logging ( otlp_endpoint : Option < String > , json_output : bool ) {
let mut layers = Vec ::new ( ) ;
// STDOUT/STDERR layer
let fmt_layer = tracing_subscriber ::fmt ::layer ( )
. with_file ( true )
. with_line_number ( true ) ;
let fmt_layer = match json_output {
true = > fmt_layer . json ( ) . flatten_event ( true ) . boxed ( ) ,
false = > fmt_layer . boxed ( ) ,
} ;
layers . push ( fmt_layer ) ;
// OpenTelemetry tracing layer
if let Some ( otlp_endpoint ) = otlp_endpoint {
global ::set_text_map_propagator ( TraceContextPropagator ::new ( ) ) ;
let tracer = opentelemetry_otlp ::new_pipeline ( )
. tracing ( )
. with_exporter (
opentelemetry_otlp ::new_exporter ( )
. tonic ( )
. with_endpoint ( otlp_endpoint ) ,
)
. with_trace_config (
trace ::config ( )
. with_resource ( Resource ::new ( vec! [ KeyValue ::new (
" service.name " ,
" text-generation-inference.router " ,
) ] ) )
. with_sampler ( Sampler ::AlwaysOn ) ,
)
. install_batch ( opentelemetry ::runtime ::Tokio ) ;
if let Ok ( tracer ) = tracer {
layers . push ( tracing_opentelemetry ::layer ( ) . with_tracer ( tracer ) . boxed ( ) ) ;
2023-09-27 08:40:18 +00:00
init_tracing_opentelemetry ::init_propagator ( ) . unwrap ( ) ;
2023-02-13 12:02:45 +00:00
} ;
}
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter ::try_from_env ( " LOG_LEVEL " ) . unwrap_or_else ( | _ | EnvFilter ::new ( " info " ) ) ;
tracing_subscriber ::registry ( )
. with ( env_filter )
. with ( layers )
. init ( ) ;
}
2023-04-18 14:16:06 +00:00
/// get model info from the Huggingface Hub
2023-05-09 11:19:31 +00:00
pub async fn get_model_info (
model_id : & str ,
2023-07-13 16:59:38 +00:00
revision : Option < String > ,
2023-05-09 11:19:31 +00:00
token : Option < String > ,
) -> Option < HubModelInfo > {
2023-07-13 16:59:38 +00:00
let revision = match revision {
None = > {
tracing ::warn! ( " `--revision` is not set " ) ;
tracing ::warn! ( " We strongly advise to set it to a known supported commit. " ) ;
" main " . to_string ( )
}
Some ( revision ) = > revision ,
} ;
2023-04-19 18:06:06 +00:00
let client = reqwest ::Client ::new ( ) ;
2023-05-04 13:14:28 +00:00
// Poor man's urlencode
2023-05-09 11:19:31 +00:00
let revision = revision . replace ( '/' , " %2F " ) ;
2023-05-04 13:14:28 +00:00
let url = format! ( " https://huggingface.co/api/models/ {model_id} /revision/ {revision} " ) ;
2023-05-09 11:19:31 +00:00
let mut builder = client . get ( url ) . timeout ( Duration ::from_secs ( 5 ) ) ;
2023-04-19 18:06:06 +00:00
if let Some ( token ) = token {
builder = builder . bearer_auth ( token ) ;
}
2023-05-09 11:19:31 +00:00
let response = builder . send ( ) . await . ok ( ) ? ;
if response . status ( ) . is_success ( ) {
2023-07-13 16:59:38 +00:00
let hub_model_info : HubModelInfo =
serde_json ::from_str ( & response . text ( ) . await . ok ( ) ? ) . ok ( ) ? ;
if let Some ( sha ) = & hub_model_info . sha {
tracing ::info! (
" Serving revision {sha} of model {} " ,
hub_model_info . model_id
) ;
}
Some ( hub_model_info )
} else {
None
2023-05-09 11:19:31 +00:00
}
2023-04-18 14:16:06 +00:00
}
2023-07-10 12:47:15 +00:00
/// Top-level error type returned by `main`; each variant corresponds to one
/// failure stage of router startup or serving. `thiserror` derives `Display`
/// from the `#[error]` attributes below.
#[derive(Debug, Error)]
enum RouterError {
    /// CLI argument combination rejected by the checks in `main`.
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    /// `ShardedClient::connect_uds` failed.
    #[error("Unable to connect to the Python model shards: {0}")]
    Connection(ClientError),
    /// `ShardedClient::clear_cache` failed.
    #[error("Unable to clear the Python model shards cache: {0}")]
    Cache(ClientError),
    /// `ShardedClient::info` failed.
    #[error("Unable to get the Python model shards info: {0}")]
    Info(ClientError),
    /// `ShardedClient::warmup` failed.
    #[error("Unable to warmup the Python model shards: {0}")]
    Warmup(ClientError),
    /// Tokio runtime construction failed (converted via `#[from]` by the `?`
    /// on `Builder::build` in `main`).
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
    /// `server::run` returned an error (converted via `#[from]`).
    #[error("Axum webserver failed: {0}")]
    Axum(#[from] axum::BoxError),
}