This commit is contained in:
Julien DELAVANDE 2025-09-03 09:25:18 +02:00 committed by GitHub
commit 6cbc0115bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 70 additions and 1 deletions

View File

@ -1450,6 +1450,13 @@
},
"nullable": true
},
"energy_mj": {
"type": "integer",
"format": "int64",
"example": 152,
"nullable": true,
"minimum": 0
},
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
@ -2176,6 +2183,13 @@
"input_length"
],
"properties": {
"energy_mj": {
"type": "integer",
"format": "int64",
"example": 152,
"nullable": true,
"minimum": 0
},
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},

View File

@ -65,6 +65,7 @@ csv = "1.3.0"
ureq = "=2.9"
pyo3 = { workspace = true }
chrono = "0.4.39"
nvml-wrapper = "0.11.0"
[build-dependencies]

View File

@ -412,6 +412,7 @@ mod tests {
generated_tokens: 10,
seed: None,
finish_reason: FinishReason::Length,
energy_mj: None,
}),
});
if let ChatEvent::Events(events) = events {

View File

@ -23,6 +23,10 @@ use tracing::warn;
use utoipa::ToSchema;
use uuid::Uuid;
use validation::Validation;
use nvml_wrapper::Nvml;
use std::sync::OnceLock;
static NVML: OnceLock<Option<Nvml>> = OnceLock::new();
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
@ -1468,6 +1472,9 @@ pub(crate) struct Details {
pub best_of_sequences: Option<Vec<BestOfSequence>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(nullable = true, example = 152)]
pub energy_mj: Option<u64>,
}
#[derive(Serialize, ToSchema)]
@ -1498,6 +1505,9 @@ pub(crate) struct StreamDetails {
pub seed: Option<u64>,
#[schema(example = 1)]
pub input_length: u32,
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(nullable = true, example = 152)]
pub energy_mj: Option<u64>,
}
#[derive(Serialize, ToSchema, Clone)]
@ -1546,6 +1556,34 @@ impl Default for ModelsInfo {
}
}
/// Zero-sized handle for querying GPU energy counters via NVML.
///
/// NVML is initialized lazily on first use and cached in the file-level
/// `NVML` static. NOTE(review): if the first `Nvml::init()` fails (e.g. no
/// NVIDIA driver present), the failure is cached and NVML is never retried
/// for the lifetime of the process — confirm this is the intended behavior.
pub struct EnergyMonitor;

impl EnergyMonitor {
    /// Returns the lazily-initialized NVML handle, or `None` if
    /// initialization failed (no driver / no supported GPU).
    fn nvml() -> Option<&'static Nvml> {
        NVML.get_or_init(|| Nvml::init().ok()).as_ref()
    }

    /// Cumulative energy consumption of the GPU at `gpu_index`,
    /// in millijoules since the driver was last reloaded
    /// (per NVML's `nvmlDeviceGetTotalEnergyConsumption`).
    ///
    /// Returns `None` if NVML is unavailable, the index is out of range,
    /// or the device does not support the energy counter.
    pub fn energy_mj(gpu_index: u32) -> Option<u64> {
        let nvml = Self::nvml()?;
        let device = nvml.device_by_index(gpu_index).ok()?;
        device.total_energy_consumption().ok()
    }

    /// Sum of the energy counters (millijoules) across all visible GPUs.
    ///
    /// Devices that fail to enumerate or that do not expose the counter are
    /// silently skipped, so a machine where every device fails still yields
    /// `Some(0)`. Returns `None` only when NVML itself is unavailable.
    /// Saturating addition guards against (theoretical) u64 overflow.
    pub fn total_energy_mj() -> Option<u64> {
        let nvml = Self::nvml()?;
        let count = nvml.device_count().ok()?;
        let total = (0..count)
            .filter_map(|i| nvml.device_by_index(i).ok())
            .filter_map(|device| device.total_energy_consumption().ok())
            .fold(0u64, u64::saturating_add);
        Some(total)
    }
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -26,7 +26,7 @@ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, EnergyMonitor,
};
use crate::{ChatTokenizeResponse, JsonSchemaConfig};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
@ -293,6 +293,7 @@ pub(crate) async fn generate_internal(
span: tracing::Span,
) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now();
let start_energy = EnergyMonitor::total_energy_mj();
metrics::counter!("tgi_request_count").increment(1);
    // Do not log ultra long inputs, like image payloads.
@ -317,6 +318,12 @@ pub(crate) async fn generate_internal(
}
_ => (infer.generate(req).await?, None),
};
let end_energy = EnergyMonitor::total_energy_mj();
let energy_mj = match (start_energy, end_energy) {
(Some(start), Some(end)) => Some(end.saturating_sub(start)),
_ => None,
};
// Token details
let input_length = response._input_length;
@ -354,6 +361,7 @@ pub(crate) async fn generate_internal(
seed: response.generated_text.seed,
best_of_sequences,
top_tokens: response.top_tokens,
energy_mj,
})
}
false => None,
@ -515,6 +523,7 @@ async fn generate_stream_internal(
impl Stream<Item = Result<StreamResponse, InferError>>,
) {
let start_time = Instant::now();
let start_energy = EnergyMonitor::total_energy_mj();
metrics::counter!("tgi_request_count").increment(1);
tracing::debug!("Input: {}", req.inputs);
@ -590,6 +599,11 @@ async fn generate_stream_internal(
queued,
top_tokens,
} => {
let end_energy = EnergyMonitor::total_energy_mj();
let energy_mj = match (start_energy, end_energy) {
(Some(start), Some(end)) => Some(end.saturating_sub(start)),
_ => None,
};
// Token details
let details = match details {
true => Some(StreamDetails {
@ -597,6 +611,7 @@ async fn generate_stream_internal(
generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed,
input_length,
energy_mj,
}),
false => None,
};