changes based on feedback

This commit is contained in:
erikkaum 2024-07-31 11:55:09 +02:00
parent 00478579e3
commit 4a0fdad1a7
3 changed files with 6 additions and 17 deletions

View File

@ -69,7 +69,7 @@ struct Args {
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_client_batch_size: usize, max_client_batch_size: usize,
#[clap(default_value = "on", long, env)] #[clap(default_value = "on", long, env)]
usage_stats: Option<String>, usage_stats: Option<usage_stats::UsageStatsLevel>,
} }
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
@ -125,18 +125,9 @@ async fn main() -> Result<(), RouterError> {
}; };
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
let usage_stats_level = usage_stats.unwrap_or(usage_stats::UsageStatsLevel::On);
// Validate args // Validate args
let usage_stats_level = match usage_stats.as_deref() {
Some("on") => usage_stats::UsageStatsLevel::On,
Some("off") => usage_stats::UsageStatsLevel::Off,
Some("no-stack") => usage_stats::UsageStatsLevel::NoStack,
Some(_) => {
return Err(RouterError::ArgumentValidation(
"`usage_stats_level` must be 'on' 'off' or 'no_stack'".to_string(),
))
}
None => usage_stats::UsageStatsLevel::On,
};
if max_input_tokens >= max_total_tokens { if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation( return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(), "`max_input_tokens` must be < `max_total_tokens`".to_string(),

View File

@ -1697,9 +1697,7 @@ pub async fn run(
// Only send usage stats when TGI is run in container and the function returns Some // Only send usage stats when TGI is run in container and the function returns Some
let is_container = matches!(usage_stats::is_container(), Ok(true)); let is_container = matches!(usage_stats::is_container(), Ok(true));
let user_agent = match (usage_stats_level, is_container) { let user_agent = match (usage_stats_level, is_container) {
(usage_stats::UsageStatsLevel::On | usage_stats::UsageStatsLevel::NoStack, true) => (usage_stats::UsageStatsLevel::On | usage_stats::UsageStatsLevel::NoStack, true) => {
{
_ => None
let reduced_args = usage_stats::Args::new( let reduced_args = usage_stats::Args::new(
config.clone(), config.clone(),
tokenizer_config.tokenizer_class.clone(), tokenizer_config.tokenizer_class.clone(),
@ -1723,7 +1721,6 @@ pub async fn run(
); );
Some(usage_stats::UserAgent::new(reduced_args)) Some(usage_stats::UserAgent::new(reduced_args))
} }
_ if !is_container => None,
_ => None, _ => None,
}; };

View File

@ -1,4 +1,5 @@
use crate::config::Config; use crate::config::Config;
use clap::ValueEnum;
use csv::ReaderBuilder; use csv::ReaderBuilder;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
use serde::Serialize; use serde::Serialize;
@ -13,7 +14,7 @@ use uuid::Uuid;
const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi"; const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi";
#[derive(Copy, Clone, Debug, Serialize)] #[derive(Copy, Clone, Debug, Serialize, ValueEnum)]
pub enum UsageStatsLevel { pub enum UsageStatsLevel {
On, On,
NoStack, NoStack,