Julien DELAVANDE 2025-09-03 09:25:18 +02:00 committed by GitHub
commit 6cbc0115bc
5 changed files with 70 additions and 1 deletion

View File

@@ -1450,6 +1450,13 @@
},
"nullable": true
},
"energy_mj": {
"type": "integer",
"format": "int64",
"example": 152,
"nullable": true,
"minimum": 0
},
"finish_reason": { "finish_reason": {
"$ref": "#/components/schemas/FinishReason" "$ref": "#/components/schemas/FinishReason"
}, },
@ -2176,6 +2183,13 @@
"input_length" "input_length"
], ],
"properties": { "properties": {
"energy_mj": {
"type": "integer",
"format": "int64",
"example": 152,
"nullable": true,
"minimum": 0
},
"finish_reason": { "finish_reason": {
"$ref": "#/components/schemas/FinishReason" "$ref": "#/components/schemas/FinishReason"
}, },
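For reference, a minimal client-side sketch of consuming the new field. The struct names and the example payload below are illustrative (only `generated_text`, `details`, `generated_tokens`, and `energy_mj` come from the schema), assuming `serde` with the derive feature and `serde_json`:

```rust
use serde::Deserialize;

// Illustrative subset of the response; `energy_mj` is nullable in the schema,
// so it maps to Option<u64> on the client side as well.
#[derive(Deserialize)]
struct Details {
    generated_tokens: u32,
    energy_mj: Option<u64>,
}

#[derive(Deserialize)]
struct GenerateResponse {
    generated_text: String,
    details: Option<Details>,
}

fn main() -> Result<(), serde_json::Error> {
    // Example payload shaped like the schema above; the values are made up.
    let body = r#"{
        "generated_text": "Hello!",
        "details": { "generated_tokens": 10, "energy_mj": 152 }
    }"#;
    let resp: GenerateResponse = serde_json::from_str(body)?;
    if let Some(d) = resp.details {
        println!("{} tokens, energy: {:?} mJ", d.generated_tokens, d.energy_mj);
    }
    Ok(())
}
```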

View File

@@ -65,6 +65,7 @@ csv = "1.3.0"
ureq = "=2.9"
pyo3 = { workspace = true }
chrono = "0.4.39"
nvml-wrapper = "0.11.0"
[build-dependencies]
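As a side note, `nvml-wrapper` is a safe Rust binding to NVIDIA's NVML management library, which is where the energy counter used later in this diff comes from. A tiny standalone probe (a sketch, not part of this commit) to check that the crate can reach the driver on a given host:

```rust
use nvml_wrapper::Nvml;

fn main() {
    // Nvml::init() fails cleanly on hosts without an NVIDIA driver.
    match Nvml::init() {
        Ok(nvml) => {
            let gpus = nvml.device_count().unwrap_or(0);
            println!("NVML initialized, {gpus} GPU(s) visible");
        }
        Err(e) => eprintln!("NVML unavailable: {e}"),
    }
}
```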

View File

@@ -412,6 +412,7 @@ mod tests {
generated_tokens: 10,
seed: None,
finish_reason: FinishReason::Length,
energy_mj: None,
}),
});
if let ChatEvent::Events(events) = events {

View File

@@ -23,6 +23,10 @@ use tracing::warn;
use utoipa::ToSchema;
use uuid::Uuid;
use validation::Validation;
use nvml_wrapper::Nvml;
use std::sync::OnceLock;
static NVML: OnceLock<Option<Nvml>> = OnceLock::new();
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
@@ -1468,6 +1472,9 @@ pub(crate) struct Details {
pub best_of_sequences: Option<Vec<BestOfSequence>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(nullable = true, example = 152)]
pub energy_mj: Option<u64>,
}
#[derive(Serialize, ToSchema)]
@@ -1498,6 +1505,9 @@ pub(crate) struct StreamDetails {
pub seed: Option<u64>,
#[schema(example = 1)]
pub input_length: u32,
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(nullable = true, example = 152)]
pub energy_mj: Option<u64>,
}
#[derive(Serialize, ToSchema, Clone)]
@@ -1546,6 +1556,34 @@ impl Default for ModelsInfo {
}
}
pub struct EnergyMonitor;
impl EnergyMonitor {
fn nvml() -> Option<&'static Nvml> {
NVML.get_or_init(|| Nvml::init().ok()).as_ref()
}
pub fn energy_mj(gpu_index: u32) -> Option<u64> {
let nvml = Self::nvml()?;
let device = nvml.device_by_index(gpu_index).ok()?;
device.total_energy_consumption().ok()
}
pub fn total_energy_mj() -> Option<u64> {
let nvml = Self::nvml()?;
let count = nvml.device_count().ok()?;
let mut total = 0;
for i in 0..count {
if let Ok(device) = nvml.device_by_index(i) {
if let Ok(energy) = device.total_energy_consumption() {
total += energy;
}
}
}
Some(total)
}
}
#[cfg(test)]
mod tests {
use super::*;
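The NVML handle is created at most once through the `OnceLock` above, and the counter NVML exposes is cumulative: total millijoules since the driver was last loaded. A per-request figure is therefore the difference between two readings. A minimal sketch of that call pattern, assuming the `EnergyMonitor` from this hunk is in scope; the `measure` helper itself is illustrative and not part of this commit:

```rust
/// Run `work` and report how many millijoules the visible GPUs consumed,
/// or None when NVML is unavailable (e.g. no NVIDIA driver on the host).
fn measure<T>(work: impl FnOnce() -> T) -> (T, Option<u64>) {
    let before = EnergyMonitor::total_energy_mj();
    let out = work();
    let after = EnergyMonitor::total_energy_mj();
    let delta = match (before, after) {
        // saturating_sub guards against a counter reset between readings.
        (Some(b), Some(a)) => Some(a.saturating_sub(b)),
        _ => None,
    };
    (out, delta)
}
```

This is the same bookkeeping the server code below adds around `generate_internal` and `generate_stream_internal`.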

View File

@@ -26,7 +26,7 @@ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, EnergyMonitor,
};
use crate::{ChatTokenizeResponse, JsonSchemaConfig};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
@@ -293,6 +293,7 @@ pub(crate) async fn generate_internal(
span: tracing::Span,
) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now();
let start_energy = EnergyMonitor::total_energy_mj();
metrics::counter!("tgi_request_count").increment(1);
// Do not log ultra long inputs, like image payloads.
@@ -317,6 +318,12 @@
}
_ => (infer.generate(req).await?, None),
};
let end_energy = EnergyMonitor::total_energy_mj();
let energy_mj = match (start_energy, end_energy) {
(Some(start), Some(end)) => Some(end.saturating_sub(start)),
_ => None,
};
// Token details
let input_length = response._input_length;
@@ -354,6 +361,7 @@
seed: response.generated_text.seed,
best_of_sequences,
top_tokens: response.top_tokens,
energy_mj,
})
}
false => None,
@@ -515,6 +523,7 @@ async fn generate_stream_internal(
impl Stream<Item = Result<StreamResponse, InferError>>,
) {
let start_time = Instant::now();
let start_energy = EnergyMonitor::total_energy_mj();
metrics::counter!("tgi_request_count").increment(1);
tracing::debug!("Input: {}", req.inputs);
@@ -590,6 +599,11 @@
queued,
top_tokens,
} => {
let end_energy = EnergyMonitor::total_energy_mj();
let energy_mj = match (start_energy, end_energy) {
(Some(start), Some(end)) => Some(end.saturating_sub(start)),
_ => None,
};
// Token details
let details = match details {
true => Some(StreamDetails {
@@ -597,6 +611,7 @@
generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed,
input_length,
energy_mj,
}),
false => None,
};