diff --git a/router/src/lib.rs b/router/src/lib.rs index fbbca8bb..e0336190 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -57,7 +57,7 @@ pub enum ChatTemplateVersions { use std::path::Path; -#[derive(Debug, Clone, Serialize ,Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct HubTokenizerConfig { pub chat_template: Option, pub completion_template: Option, diff --git a/router/src/main.rs b/router/src/main.rs index 323a8742..9feda62d 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -14,6 +14,7 @@ use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use text_generation_router::config::Config; +use text_generation_router::usage_stats; use text_generation_router::{ server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, }; @@ -23,7 +24,6 @@ use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; -use text_generation_router::usage_stats; /// App Configuration #[derive(Parser, Debug)] @@ -137,7 +137,7 @@ async fn main() -> Result<(), RouterError> { disable_crash_reports, command, } = args; - + let print_schema_command = match command { Some(Commands::PrintSchema) => true, None => { @@ -411,18 +411,18 @@ async fn main() -> Result<(), RouterError> { disable_crash_reports, ); Some(usage_stats::UserAgent::new(reducded_args)) - } - else { + } else { None }; - + if let Some(ref ua) = user_agent { - let start_event = usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start); + let start_event = + usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start); tokio::spawn(async move { start_event.send().await; }); - }; - + }; + // Run server let result = server::run( master_shard_uds_path, @@ -456,11 +456,12 @@ async fn main() -> Result<(), RouterError> { print_schema_command, ) .await; - + match result { Ok(_) => { if let Some(ref ua) = user_agent { - let stop_event = usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Stop); + let stop_event = + usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Stop); stop_event.send().await; }; Ok(()) @@ -468,7 +469,10 @@ async fn main() -> Result<(), RouterError> { Err(e) => { if let Some(ref ua) = user_agent { if !disable_crash_reports { - let error_event = usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Error(e.to_string())); + let error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error(e.to_string()), + ); error_event.send().await; } }; diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs index 8bb6ddc1..47aee6f3 100644 --- a/router/src/usage_stats.rs +++ b/router/src/usage_stats.rs @@ -1,8 +1,8 @@ use crate::config::Config; use reqwest::header::HeaderMap; use serde::Serialize; -use uuid::Uuid; use std::{fmt, process::Command, time::Duration}; +use uuid::Uuid; const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi"; @@ -24,7 +24,7 @@ impl UserAgent { } #[derive(Serialize, Debug)] -pub enum EventType { +pub enum EventType { Start, Stop, Error(String), @@ -48,7 +48,8 @@ impl UsageStatsEvent { headers.insert("Content-Type", "application/json".parse().unwrap()); let body = serde_json::to_string(&self).unwrap(); let client = reqwest::Client::new(); - let _ = client.post(TELEMETRY_URL) + let _ = client + .post(TELEMETRY_URL) .body(body) .timeout(Duration::from_secs(5)) .send() @@ -56,7 +57,6 @@ impl UsageStatsEvent { } } - #[derive(Debug, Clone, Serialize)] pub struct Args { model_config: Option, @@ -164,7 +164,8 @@ impl SystemInfo { let cpu_type = system.cpus()[0].brand().to_string(); let total_memory = system.total_memory(); let architecture = std::env::consts::ARCH.to_string(); - let platform = format!("{}-{}-{}", + let platform = format!( + "{}-{}-{}", std::env::consts::OS, std::env::consts::FAMILY, std::env::consts::ARCH