mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Add implementation that requires a new CLI argument
This commit is contained in:
parent
1d1b1efa01
commit
c27075d349
@ -1162,6 +1162,7 @@ fn spawn_webserver(
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
max_batch_prefill_tokens: u32,
|
||||
startup_time: u64,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
) -> Result<Child, LauncherError> {
|
||||
@ -1199,6 +1200,8 @@ fn spawn_webserver(
|
||||
format!("{}-0", args.shard_uds_path),
|
||||
"--tokenizer-name".to_string(),
|
||||
args.model_id,
|
||||
"--startup-time".to_string(),
|
||||
startup_time.to_string(),
|
||||
];
|
||||
|
||||
// Grammar support
|
||||
@ -1341,6 +1344,7 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
|
||||
fn main() -> Result<(), LauncherError> {
|
||||
// Pattern match configuration
|
||||
let args: Args = Args::parse();
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Filter events with LOG_LEVEL
|
||||
let varname = "LOG_LEVEL";
|
||||
@ -1622,12 +1626,14 @@ fn main() -> Result<(), LauncherError> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let download_time = start_time.elapsed().as_secs();
|
||||
let mut webserver = spawn_webserver(
|
||||
num_shard,
|
||||
args,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_prefill_tokens,
|
||||
download_time,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
)
|
||||
|
@ -87,6 +87,8 @@ struct Args {
|
||||
disable_grammar_support: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
#[clap(long, env)]
|
||||
startup_time: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
@ -129,6 +131,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
command,
|
||||
startup_time,
|
||||
} = args;
|
||||
|
||||
let print_schema_command = match command {
|
||||
@ -378,6 +381,8 @@ async fn main() -> Result<(), RouterError> {
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!("start time of the model is {startup_time}");
|
||||
|
||||
// Run server
|
||||
server::run(
|
||||
master_shard_uds_path,
|
||||
@ -409,6 +414,7 @@ async fn main() -> Result<(), RouterError> {
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
print_schema_command,
|
||||
startup_time,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
@ -54,6 +54,7 @@ use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
use tracing::{info_span, instrument, Instrument};
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
use tokio::time::Duration;
|
||||
|
||||
/// Generate tokens if `stream == false` or a stream of tokens if `stream == true`
|
||||
#[utoipa::path(
|
||||
@ -1433,6 +1434,7 @@ pub async fn run(
|
||||
grammar_support: bool,
|
||||
max_client_batch_size: usize,
|
||||
print_schema_command: bool,
|
||||
start_time: u64,
|
||||
) -> Result<(), WebServerError> {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
@ -1512,6 +1514,7 @@ pub async fn run(
|
||||
)
|
||||
)]
|
||||
struct ApiDoc;
|
||||
let length_time = Instant::now();
|
||||
|
||||
// Create state
|
||||
if print_schema_command {
|
||||
@ -1892,6 +1895,12 @@ pub async fn run(
|
||||
.layer(cors_layer);
|
||||
|
||||
tracing::info!("Connected");
|
||||
let total_time = length_time.elapsed() + Duration::from_secs(start_time);
|
||||
tracing::info!("total time for router to boot up and connect to model server {:?}", length_time.elapsed());
|
||||
tracing::info!("the total time in secs of boot time is {:?}", total_time);
|
||||
metrics::gauge!("tgi_model_load_time").set(total_time.as_secs_f64());
|
||||
|
||||
|
||||
|
||||
if ngrok {
|
||||
#[cfg(feature = "ngrok")]
|
||||
|
Loading…
Reference in New Issue
Block a user