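//! CLI entrypoint for the TensorRT-LLM backend of text-generation-inference.
//! It parses the configuration below, validates the token budgets, loads the
//! tokenizer, then hands off to the shared router HTTP server.
//!
//! Every flag can also be set through its matching environment variable
//! (via clap's `env` attribute). A hypothetical invocation, with placeholder
//! binary name and paths:
//!
//! ```text
//! text-generation-router-trtllm \
//!     --tokenizer-name meta-llama/Meta-Llama-3-8B-Instruct \
//!     --model-id /path/to/engines \
//!     --executor-worker /path/to/executorWorker \
//!     --port 3000
//! ```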
use clap::Parser;
use std::collections::HashMap;
use std::path::PathBuf;
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackend;
use text_generation_router::server;
use tokenizers::{FromPretrainedParameters, Tokenizer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(long, env, required = true)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(long, env)]
    model_id: String,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
    #[clap(long, env)]
    auth_token: Option<String>,
    #[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
    executor_worker: PathBuf,
}

#[tokio::main]
async fn main() -> Result<(), TensorRtLlmBackendError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        hostname,
        port,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        model_id,
        validation_workers,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        messages_api_enabled,
        max_client_batch_size,
        auth_token,
        executor_worker,
    } = args;
    // Initialize logging (tracing, optional OTLP export and JSON output)
    text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(TensorRtLlmBackendError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
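    // A single request's prompt must fit into one prefill batch.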
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"
        )));
    }
    if validation_workers == 0 {
        return Err(TensorRtLlmBackendError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
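    // `max_batch_total_tokens` is optional; when set it must dominate both the
    // prefill budget and any single request's total sequence length.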
    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"
            )));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"
            )));
        }
    }
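    // Fail fast when the orchestrator worker binary is missing on disk.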
    if !executor_worker.exists() {
        return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
            "`executor_worker` specified path doesn't exist: {}",
            executor_worker.display()
        )));
    }
    // Download and instantiate the tokenizer; `auth_token` is forwarded so
    // gated or private repositories on the Hugging Face Hub can be accessed.
    let tokenizer = Tokenizer::from_pretrained(
        tokenizer_name.clone(),
        Some(FromPretrainedParameters {
            revision: revision.clone().unwrap_or(String::from("main")),
            user_agent: HashMap::new(),
            auth_token,
        }),
    )
    .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
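    // Wire the tokenizer, model and orchestrator worker into the backend;
    // construction errors propagate as `TensorRtLlmBackendError`.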
    let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
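    // Run the shared router HTTP server. The bare `None`/`false`/`true`
    // literals pin optional router features that this backend does not expose
    // as flags (presumably ngrok tunneling and grammar support, among others);
    // see `server::run`'s signature for the exact parameter names.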
    server::run(
        backend,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        validation_workers,
        None,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        hostname,
        port,
        cors_allow_origin,
        false,
        None,
        None,
        messages_api_enabled,
        true,
        max_client_batch_size,
        false,
        false,
    )
    .await?;
    Ok(())
}