From 4c8bf7f5b84144565eac2b268430446b714eb475 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Fri, 24 Jan 2025 18:10:12 +0100 Subject: [PATCH] fix: add telemetry regular pings and fix unhandled errors avoid not sending telemetry stop events. --- router/src/server.rs | 55 +++++++++++++++++++++++++++++++-------- router/src/usage_stats.rs | 3 ++- 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index aef0f812..9ab415d3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -54,6 +54,9 @@ use std::fs::File; use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; use thiserror::Error; use tokio::select; use tokio::signal; @@ -1819,9 +1822,9 @@ pub async fn run( HubTokenizerConfig::default() }); - let tokenizer: Tokenizer = { + let tokenizer: Result = { use pyo3::prelude::*; - pyo3::Python::with_gil(|py| -> PyResult<()> { + match Python::with_gil(|py| -> PyResult<()> { py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?; Ok(()) }) @@ -1831,17 +1834,23 @@ pub async fn run( .or_else(|err| { let out = legacy_tokenizer_handle(config_filename.as_ref()); out.ok_or(err) - }) - .expect("We cannot load a tokenizer"); + }) { + Ok(_) => {} + Err(_) => { + return Err(WebServerError::Tokenizer( + "Unable to load tokenizer.".to_string(), + )); + } + } let filename = "out/tokenizer.json"; if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) { - Tokenizer::Rust(tok) + Ok(Tokenizer::Rust(tok)) } else { - Tokenizer::Python { + Ok(Tokenizer::Python { tokenizer_name: tokenizer_name.clone(), revision: revision.clone(), trust_remote_code, - } + }) } }; @@ -1901,11 +1910,27 @@ pub async fn run( _ => None, }; - if let Some(ref ua) = user_agent { + let stop_usage_thread = Arc::new(AtomicBool::new(false)); + let stop_usage_thread_clone = stop_usage_thread.clone(); + if let Some(ua) = user_agent.clone() { let start_event = usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None); tokio::spawn(async move { + // send start event start_event.send().await; + let mut last_report = Instant::now(); + while !stop_usage_thread_clone.load(Ordering::Relaxed) { + if last_report.elapsed() > Duration::from_secs(3600) { + let report_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Ping, + None, + ); + report_event.send().await; + last_report = Instant::now(); + } + tokio::time::sleep(Duration::from_secs(1)).await; + } }); }; let compat_return_full_text = match &model_info.pipeline_tag { @@ -1926,7 +1951,7 @@ pub async fn run( validation_workers, api_key, config, - (tokenizer, tokenizer_config), + (tokenizer?, tokenizer_config), (preprocessor_config, processor_config), hostname, port, @@ -1943,6 +1968,7 @@ pub async fn run( .await; if let Some(ua) = user_agent { + stop_usage_thread.store(true, Ordering::Relaxed); match result { Ok(_) => { let stop_event = usage_stats::UsageStatsEvent::new( @@ -2419,8 +2445,13 @@ async fn start( } } else { // Run server - - let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); + let listener = match tokio::net::TcpListener::bind(&addr).await { + Ok(listener) => listener, + Err(e) => { + tracing::error!("Failed to bind to {addr}: {e}"); + return Err(WebServerError::Axum(Box::new(e))); + } + }; axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) .await @@ -2535,4 +2566,6 @@ impl From for Event { pub enum WebServerError { #[error("Axum error: {0}")] Axum(#[from] axum::BoxError), + #[error("Tokenizer error: {0}")] + Tokenizer(String), } diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs index e9d98327..4139c4c5 100644 --- a/router/src/usage_stats.rs +++ b/router/src/usage_stats.rs @@ -43,6 +43,7 @@ pub enum EventType { Start, Stop, Error, + Ping, } #[derive(Debug, Serialize)] @@ -70,7 +71,7 @@ impl UsageStatsEvent { .post(TELEMETRY_URL) .headers(headers) .body(body) - .timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(10)) .send() .await; }