diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 38ccfe29..c1538b88 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -69,7 +69,7 @@ struct Args { #[clap(default_value = "4", long, env)] max_client_batch_size: usize, #[clap(default_value = "on", long, env)] - usage_stats: Option, + usage_stats: Option, } #[derive(Debug, Subcommand)] @@ -125,18 +125,9 @@ async fn main() -> Result<(), RouterError> { }; text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); + let usage_stats_level = usage_stats.unwrap_or(usage_stats::UsageStatsLevel::On); + // Validate args - let usage_stats_level = match usage_stats.as_deref() { - Some("on") => usage_stats::UsageStatsLevel::On, - Some("off") => usage_stats::UsageStatsLevel::Off, - Some("no-stack") => usage_stats::UsageStatsLevel::NoStack, - Some(_) => { - return Err(RouterError::ArgumentValidation( - "`usage_stats_level` must be 'on' 'off' or 'no_stack'".to_string(), - )) - } - None => usage_stats::UsageStatsLevel::On, - }; if max_input_tokens >= max_total_tokens { return Err(RouterError::ArgumentValidation( "`max_input_tokens` must be < `max_total_tokens`".to_string(), diff --git a/router/src/server.rs b/router/src/server.rs index 153d9fdd..e744fb21 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1697,9 +1697,7 @@ pub async fn run( // Only send usage stats when TGI is run in container and the function returns Some let is_container = matches!(usage_stats::is_container(), Ok(true)); let user_agent = match (usage_stats_level, is_container) { - (usage_stats::UsageStatsLevel::On | usage_stats::UsageStatsLevel::NoStack, true) => - { - _ => None + (usage_stats::UsageStatsLevel::On | usage_stats::UsageStatsLevel::NoStack, true) => { let reduced_args = usage_stats::Args::new( config.clone(), tokenizer_config.tokenizer_class.clone(), @@ -1723,7 +1721,6 @@ pub async fn run( ); Some(usage_stats::UserAgent::new(reduced_args)) } - _ if !is_container => None, _ => None, }; diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs index baf0ac10..0282ac63 100644 --- a/router/src/usage_stats.rs +++ b/router/src/usage_stats.rs @@ -1,4 +1,5 @@ use crate::config::Config; +use clap::ValueEnum; use csv::ReaderBuilder; use reqwest::header::HeaderMap; use serde::Serialize; @@ -13,7 +14,7 @@ use uuid::Uuid; const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi"; -#[derive(Copy, Clone, Debug, Serialize)] +#[derive(Copy, Clone, Debug, Serialize, ValueEnum)] pub enum UsageStatsLevel { On, NoStack,