This commit is contained in:
Edwin Hernandez 2024-07-29 11:22:36 -05:00 committed by GitHub
commit 1246e2193f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 0 deletions

View File

@ -1178,6 +1178,7 @@ fn spawn_webserver(
max_input_tokens: usize, max_input_tokens: usize,
max_total_tokens: usize, max_total_tokens: usize,
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
download_time: u64,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> { ) -> Result<Child, LauncherError> {
@ -1304,6 +1305,8 @@ fn spawn_webserver(
envs.push(("COMPUTE_TYPE".into(), compute_type.into())) envs.push(("COMPUTE_TYPE".into(), compute_type.into()))
} }
envs.push(("DOWNLOAD_TIME".into(), download_time.to_string().into()));
let mut webserver = match Command::new("text-generation-router") let mut webserver = match Command::new("text-generation-router")
.args(router_args) .args(router_args)
.envs(envs) .envs(envs)
@ -1370,6 +1373,7 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
fn main() -> Result<(), LauncherError> { fn main() -> Result<(), LauncherError> {
// Pattern match configuration // Pattern match configuration
let args: Args = Args::parse(); let args: Args = Args::parse();
let start_time = Instant::now();
// Filter events with LOG_LEVEL // Filter events with LOG_LEVEL
let varname = "LOG_LEVEL"; let varname = "LOG_LEVEL";
@ -1666,12 +1670,14 @@ fn main() -> Result<(), LauncherError> {
return Ok(()); return Ok(());
} }
let download_time = start_time.elapsed().as_secs();
let mut webserver = spawn_webserver( let mut webserver = spawn_webserver(
num_shard, num_shard,
args, args,
max_input_tokens, max_input_tokens,
max_total_tokens, max_total_tokens,
max_batch_prefill_tokens, max_batch_prefill_tokens,
download_time,
shutdown.clone(), shutdown.clone(),
&shutdown_receiver, &shutdown_receiver,
) )

View File

@ -55,6 +55,7 @@ use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument}; use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
use tokio::time::Duration;
/// Generate tokens if `stream == false` or a stream of token if `stream == true` /// Generate tokens if `stream == false` or a stream of token if `stream == true`
#[utoipa::path( #[utoipa::path(
@ -1509,6 +1510,8 @@ pub async fn run(
) )
)] )]
struct ApiDoc; struct ApiDoc;
let download_time = std::env::var("DOWNLOAD_TIME").unwrap_or("30".to_string()).parse::<u64>().unwrap_or(30);
let length_time = Instant::now();
// Create state // Create state
if print_schema_command { if print_schema_command {
@ -1916,6 +1919,10 @@ pub async fn run(
.layer(cors_layer); .layer(cors_layer);
tracing::info!("Connected"); tracing::info!("Connected");
let total_time = length_time.elapsed() + Duration::from_secs(download_time);
metrics::gauge!("tgi_model_load_time").set(total_time.as_secs_f64());
if ngrok { if ngrok {
#[cfg(feature = "ngrok")] #[cfg(feature = "ngrok")]