diff --git a/router/Cargo.toml b/router/Cargo.toml
index 9326258d..99d06ca0 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -65,6 +65,7 @@ csv = "1.3.0"
 ureq = "=2.9"
 pyo3 = { workspace = true }
 chrono = "0.4.39"
+nvml-wrapper = "0.11.0"
 
 [build-dependencies]
 
diff --git a/router/src/chat.rs b/router/src/chat.rs
index d5824fea..93165c29 100644
--- a/router/src/chat.rs
+++ b/router/src/chat.rs
@@ -412,6 +412,7 @@ mod tests {
                 generated_tokens: 10,
                 seed: None,
                 finish_reason: FinishReason::Length,
+                energy_mj: None,
             }),
         });
         if let ChatEvent::Events(events) = events {
diff --git a/router/src/lib.rs b/router/src/lib.rs
index e5622fc2..c152b6cd 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -23,6 +23,10 @@ use tracing::warn;
 use utoipa::ToSchema;
 use uuid::Uuid;
 use validation::Validation;
+use nvml_wrapper::Nvml;
+use std::sync::OnceLock;
+
+static NVML: OnceLock<Option<Nvml>> = OnceLock::new();
 
 #[allow(clippy::large_enum_variant)]
 #[derive(Clone)]
@@ -1468,6 +1472,9 @@ pub(crate) struct Details {
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[schema(nullable = true, example = 152)]
+    pub energy_mj: Option<u64>,
 }
 
 #[derive(Serialize, ToSchema)]
@@ -1498,6 +1505,9 @@ pub(crate) struct StreamDetails {
     pub seed: Option<u64>,
     #[schema(example = 1)]
     pub input_length: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[schema(nullable = true, example = 152)]
+    pub energy_mj: Option<u64>,
 }
 
 #[derive(Serialize, ToSchema, Clone)]
@@ -1546,6 +1556,34 @@ impl Default for ModelsInfo {
     }
 }
 
+pub struct EnergyMonitor;
+
+impl EnergyMonitor {
+    fn nvml() -> Option<&'static Nvml> {
+        NVML.get_or_init(|| Nvml::init().ok()).as_ref()
+    }
+
+    pub fn energy_mj(gpu_index: u32) -> Option<u64> {
+        let nvml = Self::nvml()?;
+        let device = nvml.device_by_index(gpu_index).ok()?;
+        device.total_energy_consumption().ok()
+    }
+
+    pub fn total_energy_mj() -> Option<u64> {
+        let nvml = Self::nvml()?;
+        let count = nvml.device_count().ok()?;
+        let mut total = 0;
+        for i in 0..count {
+            if let Ok(device) = nvml.device_by_index(i) {
+                if let Ok(energy) = device.total_energy_consumption() {
+                    total += energy;
+                }
+            }
+        }
+        Some(total)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/router/src/server.rs b/router/src/server.rs
index 5fbe0403..51f9317e 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -26,7 +26,7 @@ use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
     ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
     ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
-    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
+    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, EnergyMonitor,
 };
 use crate::{ChatTokenizeResponse, JsonSchemaConfig};
 use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
@@ -293,6 +293,7 @@ pub(crate) async fn generate_internal(
     span: tracing::Span,
 ) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
+    let start_energy = EnergyMonitor::total_energy_mj();
     metrics::counter!("tgi_request_count").increment(1);
 
     // Do not long ultra long inputs, like image payloads.
@@ -317,6 +318,12 @@ pub(crate) async fn generate_internal(
         }
         _ => (infer.generate(req).await?, None),
     };
+
+    let end_energy = EnergyMonitor::total_energy_mj();
+    let energy_mj = match (start_energy, end_energy) {
+        (Some(start), Some(end)) => Some(end.saturating_sub(start)),
+        _ => None,
+    };
 
     // Token details
     let input_length = response._input_length;
@@ -354,6 +361,7 @@ pub(crate) async fn generate_internal(
                 seed: response.generated_text.seed,
                 best_of_sequences,
                 top_tokens: response.top_tokens,
+                energy_mj,
             })
         }
         false => None,
@@ -515,6 +523,7 @@ async fn generate_stream_internal(
     impl Stream<Item = Result<StreamResponse, InferError>>,
 ) {
     let start_time = Instant::now();
+    let start_energy = EnergyMonitor::total_energy_mj();
     metrics::counter!("tgi_request_count").increment(1);
 
     tracing::debug!("Input: {}", req.inputs);
@@ -590,6 +599,11 @@ async fn generate_stream_internal(
                         queued,
                         top_tokens,
                     } => {
+                        let end_energy = EnergyMonitor::total_energy_mj();
+                        let energy_mj = match (start_energy, end_energy) {
+                            (Some(start), Some(end)) => Some(end.saturating_sub(start)),
+                            _ => None,
+                        };
                         // Token details
                         let details = match details {
                             true => Some(StreamDetails {
@@ -597,6 +611,7 @@
                                 generated_tokens: generated_text.generated_tokens,
                                 seed: generated_text.seed,
                                 input_length,
+                                energy_mj,
                             }),
                             false => None,
                         };