mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-08 19:04:52 +00:00

commit 6cbc0115bc: Merge 2b6d0742c0 into 356de85c29
@@ -1450,6 +1450,13 @@
           },
           "nullable": true
         },
+        "energy_mj": {
+          "type": "integer",
+          "format": "int64",
+          "example": 152,
+          "nullable": true,
+          "minimum": 0
+        },
         "finish_reason": {
           "$ref": "#/components/schemas/FinishReason"
         },
@@ -2176,6 +2183,13 @@
         "input_length"
       ],
       "properties": {
+        "energy_mj": {
+          "type": "integer",
+          "format": "int64",
+          "example": 152,
+          "nullable": true,
+          "minimum": 0
+        },
         "finish_reason": {
           "$ref": "#/components/schemas/FinishReason"
         },
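Note: the two hunks above add the same nullable int64 energy_mj property to two response schemas. Judging by the #[schema(nullable = true, example = 152)] annotations added to the Details and StreamDetails structs further down in this diff, the spec appears to be regenerated from the utoipa derives rather than edited by hand.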
@@ -65,6 +65,7 @@ csv = "1.3.0"
 ureq = "=2.9"
 pyo3 = { workspace = true }
 chrono = "0.4.39"
+nvml-wrapper = "0.11.0"


 [build-dependencies]
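The new nvml-wrapper dependency is what backs the energy readings. As a minimal sketch of the underlying API (not part of the diff, and assuming an NVIDIA GPU with a working driver is visible to the process):

use nvml_wrapper::Nvml;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let nvml = Nvml::init()?;              // loads and initializes libnvidia-ml
    let device = nvml.device_by_index(0)?; // first visible GPU
    // Millijoules consumed since the driver was last reloaded (a cumulative,
    // monotonically increasing counter).
    let energy_mj = device.total_energy_consumption()?;
    println!("GPU 0 cumulative energy: {energy_mj} mJ");
    Ok(())
}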
@@ -412,6 +412,7 @@ mod tests {
             generated_tokens: 10,
             seed: None,
             finish_reason: FinishReason::Length,
+            energy_mj: None,
         }),
     });
     if let ChatEvent::Events(events) = events {
@@ -23,6 +23,10 @@ use tracing::warn;
 use utoipa::ToSchema;
 use uuid::Uuid;
 use validation::Validation;
+use nvml_wrapper::Nvml;
+use std::sync::OnceLock;
+
+static NVML: OnceLock<Option<Nvml>> = OnceLock::new();

 #[allow(clippy::large_enum_variant)]
 #[derive(Clone)]
@@ -1468,6 +1472,9 @@ pub(crate) struct Details {
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[schema(nullable = true, example = 152)]
+    pub energy_mj: Option<u64>,
 }

 #[derive(Serialize, ToSchema)]
@@ -1498,6 +1505,9 @@ pub(crate) struct StreamDetails {
     pub seed: Option<u64>,
     #[schema(example = 1)]
     pub input_length: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[schema(nullable = true, example = 152)]
+    pub energy_mj: Option<u64>,
 }

 #[derive(Serialize, ToSchema, Clone)]
@@ -1546,6 +1556,34 @@ impl Default for ModelsInfo {
     }
 }

+pub struct EnergyMonitor;
+
+impl EnergyMonitor {
+    fn nvml() -> Option<&'static Nvml> {
+        NVML.get_or_init(|| Nvml::init().ok()).as_ref()
+    }
+
+    pub fn energy_mj(gpu_index: u32) -> Option<u64> {
+        let nvml = Self::nvml()?;
+        let device = nvml.device_by_index(gpu_index).ok()?;
+        device.total_energy_consumption().ok()
+    }
+
+    pub fn total_energy_mj() -> Option<u64> {
+        let nvml = Self::nvml()?;
+        let count = nvml.device_count().ok()?;
+        let mut total = 0;
+        for i in 0..count {
+            if let Ok(device) = nvml.device_by_index(i) {
+                if let Ok(energy) = device.total_energy_consumption() {
+                    total += energy;
+                }
+            }
+        }
+        Some(total)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
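EnergyMonitor caches the NVML handle in the NVML OnceLock above: initialization is attempted exactly once, and a failure is cached as None, so on machines without NVML every later call returns None cheaply instead of retrying. Because the NVML counter is cumulative, callers obtain per-request energy as the difference of two readings. A minimal sketch of the intended usage, assuming EnergyMonitor is in scope (the hypothetical measure helper is not part of the diff):

// Hypothetical caller: measure energy across one unit of work, mirroring
// the pattern generate_internal uses below.
fn measure<F: FnOnce()>(work: F) -> Option<u64> {
    let start = EnergyMonitor::total_energy_mj()?; // None if NVML is unavailable
    work();
    let end = EnergyMonitor::total_energy_mj()?;
    Some(end.saturating_sub(start)) // clamp to 0 if the counter reset mid-request
}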
@@ -26,7 +26,7 @@ use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
     ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
     ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
-    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
+    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, EnergyMonitor,
 };
 use crate::{ChatTokenizeResponse, JsonSchemaConfig};
 use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
@@ -293,6 +293,7 @@ pub(crate) async fn generate_internal(
     span: tracing::Span,
 ) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
+    let start_energy = EnergyMonitor::total_energy_mj();
     metrics::counter!("tgi_request_count").increment(1);

     // Do not log ultra long inputs, like image payloads.
@@ -317,6 +318,12 @@
         }
         _ => (infer.generate(req).await?, None),
     };

+    let end_energy = EnergyMonitor::total_energy_mj();
+    let energy_mj = match (start_energy, end_energy) {
+        (Some(start), Some(end)) => Some(end.saturating_sub(start)),
+        _ => None,
+    };
+
     // Token details
     let input_length = response._input_length;
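The subtraction is guarded with saturating_sub: NVML's counter is monotonic in normal operation, but if it were reset between the two readings (for example by a driver reload), end would be smaller than start, and a plain u64 subtraction would panic in debug builds and wrap in release. A toy illustration with hypothetical values:

fn main() {
    // As if the counter reset between the two readings: end < start.
    let (start, end): (u64, u64) = (500_000, 100);
    assert_eq!(end.saturating_sub(start), 0); // reported as 0 mJ, never a wrapped value
}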
@@ -354,6 +361,7 @@
                 seed: response.generated_text.seed,
                 best_of_sequences,
                 top_tokens: response.top_tokens,
+                energy_mj,
             })
         }
         false => None,
@@ -515,6 +523,7 @@ async fn generate_stream_internal(
     impl Stream<Item = Result<StreamResponse, InferError>>,
 ) {
     let start_time = Instant::now();
+    let start_energy = EnergyMonitor::total_energy_mj();
     metrics::counter!("tgi_request_count").increment(1);

     tracing::debug!("Input: {}", req.inputs);
@@ -590,6 +599,11 @@
                     queued,
                     top_tokens,
                 } => {
+                    let end_energy = EnergyMonitor::total_energy_mj();
+                    let energy_mj = match (start_energy, end_energy) {
+                        (Some(start), Some(end)) => Some(end.saturating_sub(start)),
+                        _ => None,
+                    };
                     // Token details
                     let details = match details {
                         true => Some(StreamDetails {
@@ -597,6 +611,7 @@
                             generated_tokens: generated_text.generated_tokens,
                             seed: generated_text.seed,
                             input_length,
+                            energy_mj,
                         }),
                         false => None,
                     };