Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-09 11:24:53 +00:00

Merge 2b6d0742c0 into 356de85c29
This commit is contained in: commit 6cbc0115bc
@@ -1450,6 +1450,13 @@
           },
           "nullable": true
         },
+        "energy_mj": {
+          "type": "integer",
+          "format": "int64",
+          "example": 152,
+          "nullable": true,
+          "minimum": 0
+        },
         "finish_reason": {
           "$ref": "#/components/schemas/FinishReason"
         },
@@ -2176,6 +2183,13 @@
           "input_length"
         ],
         "properties": {
+          "energy_mj": {
+            "type": "integer",
+            "format": "int64",
+            "example": 152,
+            "nullable": true,
+            "minimum": 0
+          },
           "finish_reason": {
             "$ref": "#/components/schemas/FinishReason"
           },
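The two OpenAPI hunks above declare energy_mj as an optional, nullable integer on the details objects. For illustration only, here is a small self-contained serde sketch (hypothetical stand-in type, not the crate's actual struct) of the pattern that yields such a schema entry: the field is an Option that is skipped entirely when no measurement is available, which is why the property is nullable rather than required.

use serde::Serialize;

// Hypothetical stand-in type, for illustration only.
#[derive(Serialize)]
struct DetailsSketch {
    finish_reason: &'static str,
    // Omitted from the JSON output when no energy reading is available.
    #[serde(skip_serializing_if = "Option::is_none")]
    energy_mj: Option<u64>,
}

fn main() {
    let with = DetailsSketch { finish_reason: "length", energy_mj: Some(152) };
    let without = DetailsSketch { finish_reason: "length", energy_mj: None };
    // {"finish_reason":"length","energy_mj":152}
    println!("{}", serde_json::to_string(&with).unwrap());
    // {"finish_reason":"length"}
    println!("{}", serde_json::to_string(&without).unwrap());
}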
@@ -65,6 +65,7 @@ csv = "1.3.0"
 ureq = "=2.9"
 pyo3 = { workspace = true }
 chrono = "0.4.39"
+nvml-wrapper = "0.11.0"


 [build-dependencies]
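The new nvml-wrapper dependency is what backs the energy readings introduced below. A minimal standalone sketch of the calls it provides (assuming an NVIDIA GPU and driver are present; this program is not part of the diff):

use nvml_wrapper::Nvml;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Initialize NVML and print the cumulative energy counter (in millijoules)
    // exposed by each visible GPU.
    let nvml = Nvml::init()?;
    for i in 0..nvml.device_count()? {
        let device = nvml.device_by_index(i)?;
        println!("GPU {i}: {} mJ", device.total_energy_consumption()?);
    }
    Ok(())
}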
@@ -412,6 +412,7 @@ mod tests {
                 generated_tokens: 10,
                 seed: None,
                 finish_reason: FinishReason::Length,
+                energy_mj: None,
             }),
         });
         if let ChatEvent::Events(events) = events {
@@ -23,6 +23,10 @@ use tracing::warn;
 use utoipa::ToSchema;
 use uuid::Uuid;
 use validation::Validation;
+use nvml_wrapper::Nvml;
+use std::sync::OnceLock;
+
+static NVML: OnceLock<Option<Nvml>> = OnceLock::new();

 #[allow(clippy::large_enum_variant)]
 #[derive(Clone)]
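The static above caches the outcome of NVML initialization: OnceLock::get_or_init runs its closure at most once, and storing an Option means a failed init is remembered as None instead of being retried on every request. A tiny sketch of the same pattern with a made-up value (EXAMPLE_VALUE is hypothetical):

use std::sync::OnceLock;

static CACHE: OnceLock<Option<String>> = OnceLock::new();

fn cached_value() -> Option<&'static String> {
    // The closure runs only on the first call; later calls reuse the stored
    // Option, whether it holds a value or None.
    CACHE
        .get_or_init(|| std::env::var("EXAMPLE_VALUE").ok())
        .as_ref()
}

fn main() {
    println!("{:?}", cached_value());
    println!("{:?}", cached_value()); // no second lookup
}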
@@ -1468,6 +1472,9 @@ pub(crate) struct Details {
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[schema(nullable = true, example = 152)]
+    pub energy_mj: Option<u64>,
 }

 #[derive(Serialize, ToSchema)]
@@ -1498,6 +1505,9 @@ pub(crate) struct StreamDetails {
     pub seed: Option<u64>,
     #[schema(example = 1)]
     pub input_length: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[schema(nullable = true, example = 152)]
+    pub energy_mj: Option<u64>,
 }

 #[derive(Serialize, ToSchema, Clone)]
@@ -1546,6 +1556,34 @@ impl Default for ModelsInfo {
     }
 }

+pub struct EnergyMonitor;
+
+impl EnergyMonitor {
+    fn nvml() -> Option<&'static Nvml> {
+        NVML.get_or_init(|| Nvml::init().ok()).as_ref()
+    }
+
+    pub fn energy_mj(gpu_index: u32) -> Option<u64> {
+        let nvml = Self::nvml()?;
+        let device = nvml.device_by_index(gpu_index).ok()?;
+        device.total_energy_consumption().ok()
+    }
+
+    pub fn total_energy_mj() -> Option<u64> {
+        let nvml = Self::nvml()?;
+        let count = nvml.device_count().ok()?;
+        let mut total = 0;
+        for i in 0..count {
+            if let Ok(device) = nvml.device_by_index(i) {
+                if let Ok(energy) = device.total_energy_consumption() {
+                    total += energy;
+                }
+            }
+        }
+        Some(total)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
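NVML reports total_energy_consumption as a cumulative per-device counter in millijoules, so a single reading is not meaningful on its own; the handler hunks further down sample it before and after a request and keep the difference. A rough self-contained sketch of that accounting, with hypothetical sample values standing in for EnergyMonitor::total_energy_mj():

fn delta_mj(start: Option<u64>, end: Option<u64>) -> Option<u64> {
    match (start, end) {
        // saturating_sub guards against the counter appearing to move backwards.
        (Some(start), Some(end)) => Some(end.saturating_sub(start)),
        // No NVML / no GPU: report nothing rather than a bogus zero.
        _ => None,
    }
}

fn main() {
    let start = Some(12_000); // hypothetical reading before the request
    let end = Some(12_152);   // hypothetical reading after the request
    assert_eq!(delta_mj(start, end), Some(152));
    assert_eq!(delta_mj(None, end), None);
}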
@@ -26,7 +26,7 @@ use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
     ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
     ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
-    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
+    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, EnergyMonitor,
 };
 use crate::{ChatTokenizeResponse, JsonSchemaConfig};
 use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
@@ -293,6 +293,7 @@ pub(crate) async fn generate_internal(
     span: tracing::Span,
 ) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
+    let start_energy = EnergyMonitor::total_energy_mj();
     metrics::counter!("tgi_request_count").increment(1);

     // Do not long ultra long inputs, like image payloads.
@@ -318,6 +319,12 @@ pub(crate) async fn generate_internal(
         _ => (infer.generate(req).await?, None),
     };

+    let end_energy = EnergyMonitor::total_energy_mj();
+    let energy_mj = match (start_energy, end_energy) {
+        (Some(start), Some(end)) => Some(end.saturating_sub(start)),
+        _ => None,
+    };
+
     // Token details
     let input_length = response._input_length;
     let details = match details {
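A side note on the hunk above, not part of the diff: the same compute-the-delta-or-None logic can be written more compactly with Option::zip, should that style be preferred.

fn energy_delta(start_energy: Option<u64>, end_energy: Option<u64>) -> Option<u64> {
    // Equivalent to the match written in the hunk above.
    start_energy
        .zip(end_energy)
        .map(|(start, end)| end.saturating_sub(start))
}

fn main() {
    assert_eq!(energy_delta(Some(10), Some(25)), Some(15));
    assert_eq!(energy_delta(None, Some(25)), None);
}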
@@ -354,6 +361,7 @@ pub(crate) async fn generate_internal(
                 seed: response.generated_text.seed,
                 best_of_sequences,
                 top_tokens: response.top_tokens,
+                energy_mj,
             })
         }
         false => None,
@@ -515,6 +523,7 @@ async fn generate_stream_internal(
     impl Stream<Item = Result<StreamResponse, InferError>>,
 ) {
     let start_time = Instant::now();
+    let start_energy = EnergyMonitor::total_energy_mj();
     metrics::counter!("tgi_request_count").increment(1);

     tracing::debug!("Input: {}", req.inputs);
@@ -590,6 +599,11 @@ async fn generate_stream_internal(
                         queued,
                         top_tokens,
                     } => {
+                        let end_energy = EnergyMonitor::total_energy_mj();
+                        let energy_mj = match (start_energy, end_energy) {
+                            (Some(start), Some(end)) => Some(end.saturating_sub(start)),
+                            _ => None,
+                        };
                         // Token details
                         let details = match details {
                             true => Some(StreamDetails {
@@ -597,6 +611,7 @@ async fn generate_stream_internal(
                                 generated_tokens: generated_text.generated_tokens,
                                 seed: generated_text.seed,
                                 input_length,
+                                energy_mj,
                             }),
                             false => None,
                         };