Added implementation that requires a new CLI argument

This commit is contained in:
Edwinhr716 2024-07-25 22:15:27 +00:00
parent 1d1b1efa01
commit c27075d349
3 changed files with 21 additions and 0 deletions

View File

@ -1162,6 +1162,7 @@ fn spawn_webserver(
max_input_tokens: usize, max_input_tokens: usize,
max_total_tokens: usize, max_total_tokens: usize,
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
startup_time: u64,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> { ) -> Result<Child, LauncherError> {
@ -1199,6 +1200,8 @@ fn spawn_webserver(
format!("{}-0", args.shard_uds_path), format!("{}-0", args.shard_uds_path),
"--tokenizer-name".to_string(), "--tokenizer-name".to_string(),
args.model_id, args.model_id,
"--startup-time".to_string(),
startup_time.to_string(),
]; ];
// Grammar support // Grammar support
@ -1341,6 +1344,7 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
fn main() -> Result<(), LauncherError> { fn main() -> Result<(), LauncherError> {
// Pattern match configuration // Pattern match configuration
let args: Args = Args::parse(); let args: Args = Args::parse();
let start_time = Instant::now();
// Filter events with LOG_LEVEL // Filter events with LOG_LEVEL
let varname = "LOG_LEVEL"; let varname = "LOG_LEVEL";
@ -1622,12 +1626,14 @@ fn main() -> Result<(), LauncherError> {
return Ok(()); return Ok(());
} }
let download_time = start_time.elapsed().as_secs();
let mut webserver = spawn_webserver( let mut webserver = spawn_webserver(
num_shard, num_shard,
args, args,
max_input_tokens, max_input_tokens,
max_total_tokens, max_total_tokens,
max_batch_prefill_tokens, max_batch_prefill_tokens,
download_time,
shutdown.clone(), shutdown.clone(),
&shutdown_receiver, &shutdown_receiver,
) )

View File

@ -87,6 +87,8 @@ struct Args {
disable_grammar_support: bool, disable_grammar_support: bool,
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_client_batch_size: usize, max_client_batch_size: usize,
#[clap(long, env)]
startup_time: u64,
} }
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
@ -129,6 +131,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
command, command,
startup_time,
} = args; } = args;
let print_schema_command = match command { let print_schema_command = match command {
@ -378,6 +381,8 @@ async fn main() -> Result<(), RouterError> {
} }
}; };
tracing::info!("start time of the model is {startup_time}");
// Run server // Run server
server::run( server::run(
master_shard_uds_path, master_shard_uds_path,
@ -409,6 +414,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support, disable_grammar_support,
max_client_batch_size, max_client_batch_size,
print_schema_command, print_schema_command,
startup_time,
) )
.await?; .await?;
Ok(()) Ok(())

View File

@ -54,6 +54,7 @@ use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument}; use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
use tokio::time::Duration;
/// Generate tokens if `stream == false` or a stream of token if `stream == true` /// Generate tokens if `stream == false` or a stream of token if `stream == true`
#[utoipa::path( #[utoipa::path(
@ -1433,6 +1434,7 @@ pub async fn run(
grammar_support: bool, grammar_support: bool,
max_client_batch_size: usize, max_client_batch_size: usize,
print_schema_command: bool, print_schema_command: bool,
start_time: u64,
) -> Result<(), WebServerError> { ) -> Result<(), WebServerError> {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -1512,6 +1514,7 @@ pub async fn run(
) )
)] )]
struct ApiDoc; struct ApiDoc;
let length_time = Instant::now();
// Create state // Create state
if print_schema_command { if print_schema_command {
@ -1892,6 +1895,12 @@ pub async fn run(
.layer(cors_layer); .layer(cors_layer);
tracing::info!("Connected"); tracing::info!("Connected");
let total_time = length_time.elapsed() + Duration::from_secs(start_time);
tracing::info!("total time for router to boot up and connect to model server {:?}", length_time.elapsed());
tracing::info!("the total time in secs of boot time is {:?}", total_time);
metrics::gauge!("tgi_model_load_time").set(total_time.as_secs_f64());
if ngrok { if ngrok {
#[cfg(feature = "ngrok")] #[cfg(feature = "ngrok")]