mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00

Attempt at automatic max batch prefill.

This commit is contained in:
parent b57f370386 · commit 54d3c8157c
@@ -29,6 +29,26 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
 mod env_runtime;
 mod gpu;
 
+fn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Option<usize> {
+    if let (Some(config), Some(compute)) = (config, compute) {
+        if let (Some(f16_max_compute), Some(model_compute)) = (compute.f16_flop(), config.flop()) {
+            tracing::debug!("Max compute {f16_max_compute} model compute {model_compute}");
+            let optimal_size = (f16_max_compute / model_compute) as usize;
+            if optimal_size > 100 {
+                // Ignore calculations that are too low
+                // Most likely an error
+                Some(optimal_size)
+            } else {
+                None
+            }
+        } else {
+            None
+        }
+    } else {
+        None
+    }
+}
+
 fn get_config(
     model_id: &str,
     revision: &Option<String>,
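The heuristic divides the card's peak f16 throughput (FLOP/s) by the model's per-token forward cost (FLOPs), so `optimal_size` is roughly the number of prefill tokens that keep the card busy for one second. A minimal sketch with illustrative numbers, assuming an L4 (121 TFLOP/s, the one card the lookup below knows) and the ~7B-model cost worked out under `Config::flop()` further down:

    // Sketch of the sizing heuristic; both numbers are illustrative.
    fn main() {
        let f16_max_compute: u64 = 121 * 10u64.pow(12); // card peak f16, FLOP/s
        let model_compute: u64 = 12_952_010_752; // ~7B model, FLOPs per token
        let optimal_size = (f16_max_compute / model_compute) as usize;
        assert_eq!(optimal_size, 9342); // ≈ one second's worth of prefill tokens
        // values of 100 or less are discarded above as likely errors
    }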
@@ -144,7 +164,10 @@ struct RawConfig {
     quantization_config: Option<QuantizationConfig>,
     n_embd: Option<usize>,
     hidden_size: Option<usize>,
+    intermediate_size: Option<usize>,
     num_attention_heads: Option<usize>,
+    num_key_value_heads: Option<usize>,
+    num_hidden_layers: Option<usize>,
     head_dim: Option<usize>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: Option<bool>,
@@ -155,19 +178,42 @@ struct QuantizationConfig {
     quant_method: Option<Quantization>,
 }
 
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize)]
 struct VisionConfig {}
 
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
     quantize: Option<Quantization>,
     head_dim: Option<usize>,
+    num_heads: Option<usize>,
+    num_kv_heads: Option<usize>,
+    num_layers: Option<usize>,
+    intermediate_size: Option<usize>,
+    hidden_size: Option<usize>,
     model_type: Option<String>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: bool,
 }
 
+impl Config {
+    fn flop(&self) -> Option<u64> {
+        let num_heads = self.num_heads? as u64;
+        let num_kv_heads = self.num_kv_heads? as u64;
+        let head_dim = self.head_dim? as u64;
+        let hidden_size = self.hidden_size? as u64;
+        let intermediate_size = self.intermediate_size? as u64;
+        let num_layers = self.num_layers? as u64;
+
+        let attn_flops = 2 * (num_heads + 2 * num_kv_heads) * head_dim * hidden_size;
+        let o_flops = 2 * num_kv_heads * head_dim * hidden_size;
+        let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
+        let layer_flops = attn_flops + o_flops + gate_up_down_flops;
+        let total = layer_flops * num_layers;
+        Some(total)
+    }
+}
+
 impl From<RawConfig> for Config {
     fn from(other: RawConfig) -> Self {
         let max_position_embeddings = other
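`flop()` prices one decoder layer at roughly 2 FLOPs per weight per token: the q/k/v projections (`attn_flops`), the output projection (`o_flops`, written in terms of `num_kv_heads`, which coincides with `num_heads` in the MHA case below), and the three gated-MLP matrices (`gate_up_down_flops`); embeddings and attention-score math are ignored. Plugging in Llama-2-7B-style dimensions (illustrative values, not read from a real config):

    // Worked example of Config::flop() with assumed 7B-class dimensions.
    fn main() {
        let (num_heads, num_kv_heads, head_dim): (u64, u64, u64) = (32, 32, 128);
        let (hidden_size, intermediate_size, num_layers): (u64, u64, u64) = (4096, 11008, 32);

        let attn_flops = 2 * (num_heads + 2 * num_kv_heads) * head_dim * hidden_size; // 100_663_296
        let o_flops = 2 * num_kv_heads * head_dim * hidden_size; // 33_554_432
        let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size; // 270_532_608
        let total = (attn_flops + o_flops + gate_up_down_flops) * num_layers;
        assert_eq!(total, 12_952_010_752); // ≈ 1.3e10 FLOPs/token, ≈ 2 × the weight count
    }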
@@ -175,22 +221,21 @@ impl From<RawConfig> for Config {
             .or(other.max_seq_len)
             .or(other.n_positions);
         let quantize = other.quantization_config.and_then(|q| q.quant_method);
-        let head_dim = other.head_dim.or_else(|| {
-            match (other.hidden_size, other.n_embd, other.num_attention_heads) {
-                (Some(hidden_size), _, Some(num_attention_heads))
-                    if hidden_size % num_attention_heads == 0 =>
-                {
-                    Some(hidden_size / num_attention_heads)
-                }
-                // Legacy
-                (_, Some(hidden_size), Some(num_attention_heads))
+        let hidden_size = other.hidden_size.or(other.n_embd);
+        let head_dim = other
+            .head_dim
+            .or_else(|| match (hidden_size, other.num_attention_heads) {
+                (Some(hidden_size), Some(num_attention_heads))
                     if hidden_size % num_attention_heads == 0 =>
                 {
                     Some(hidden_size / num_attention_heads)
                 }
                 _ => None,
-            }
-        });
+            });
+        let num_heads = other.num_attention_heads;
+        let num_layers = other.num_hidden_layers;
+        let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
+        let intermediate_size = other.intermediate_size;
         let model_type = other.model_type;
         let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
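When a config omits `head_dim`, it falls back to hidden size over attention-head count, now computed from the unified `hidden_size` (which itself falls back to the legacy `n_embd` key). For example, under the same illustrative 7B dimensions:

    // head_dim inference fallback (assumed dimensions).
    fn main() {
        let (hidden_size, num_attention_heads) = (4096usize, 32usize);
        assert_eq!(hidden_size % num_attention_heads, 0); // the guard in the match arm
        assert_eq!(hidden_size / num_attention_heads, 128); // inferred head_dim
    }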
@@ -201,6 +246,11 @@ impl From<RawConfig> for Config {
             model_type,
             vision_config,
             is_encoder_decoder,
+            hidden_size,
+            num_heads,
+            num_kv_heads,
+            intermediate_size,
+            num_layers,
         }
     }
 }
@@ -1439,7 +1489,32 @@ fn spawn_shards(
     Ok(())
 }
 
-fn compute_type(num_shard: usize) -> Option<String> {
+#[derive(Debug)]
+struct ComputeType {
+    count: usize,
+    card: String,
+}
+
+impl ComputeType {
+    fn f16_flop(&self) -> Option<u64> {
+        match &self.card[..] {
+            // https://www.nvidia.com/en-us/data-center/l4/
+            "nvidia-l4" => Some(121 * 10u64.pow(12)),
+            card => {
+                tracing::warn!("Unknown compute for card {card}");
+                None
+            }
+        }
+    }
+}
+
+impl From<ComputeType> for OsString {
+    fn from(value: ComputeType) -> Self {
+        format!("{}-{}", value.count, value.card).into()
+    }
+}
+
+fn compute_type(num_shard: usize) -> Option<ComputeType> {
     let output = Command::new("nvidia-smi")
         .args(["--query-gpu=gpu_name", "--format=csv"])
         .output()
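`f16_flop()` is a lookup keyed on the normalized card name, and so far knows only the L4. Extending it means adding match arms; a sketch, where the A100 figure is taken from NVIDIA's public datasheet and the exact card string is an assumption to verify against real `nvidia-smi` output:

    // Hypothetical extension of the per-card peak-f16 table.
    fn f16_flop(card: &str) -> Option<u64> {
        match card {
            // https://www.nvidia.com/en-us/data-center/l4/
            "nvidia-l4" => Some(121 * 10u64.pow(12)),
            // https://www.nvidia.com/en-us/data-center/a100/ (312 TFLOPS dense f16)
            "nvidia-a100-sxm4-80gb" => Some(312 * 10u64.pow(12)),
            card => {
                eprintln!("Unknown compute for card {card}"); // tracing::warn! in the launcher
                None
            }
        }
    }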
@@ -1447,8 +1522,10 @@ fn compute_type(num_shard: usize) -> Option<String> {
     let output = String::from_utf8(output.stdout).ok()?;
     let fullname = output.split('\n').nth(1)?;
     let cardname = fullname.replace(' ', "-").to_lowercase();
-    let compute_type = format!("{num_shard}-{cardname}");
-    Some(compute_type)
+    Some(ComputeType {
+        count: num_shard,
+        card: cardname,
+    })
 }
 
 fn spawn_webserver(
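`--format=csv` emits a header row (`name`) before the data rows, which is why `.nth(1)` is the first GPU's name; a multi-GPU host is represented only through `count`, and any second card model is ignored. A sketch of the normalization, assuming typical `nvidia-smi` output:

    // Assumes typical `nvidia-smi --query-gpu=gpu_name --format=csv` output.
    fn main() {
        let output = "name\nNVIDIA L4\n";
        let fullname = output.split('\n').nth(1).unwrap(); // skip the csv header
        let cardname = fullname.replace(' ', "-").to_lowercase();
        assert_eq!(cardname, "nvidia-l4"); // the key f16_flop() matches on
    }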
@@ -1700,26 +1777,22 @@ fn main() -> Result<(), LauncherError> {
     let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
     let quantize = config.as_ref().and_then(|c| c.quantize);
     // Quantization usually means you're even more RAM constrained.
-    let max_default = 4096;
-
-    let max_position_embeddings = if let Some(config) = &config {
-        if let Some(max_position_embeddings) = config.max_position_embeddings {
-            if max_position_embeddings > max_default {
-                max_default
-            } else {
-                max_position_embeddings
-            }
-        } else {
-            max_default
-        }
-    } else {
-        max_default
-    };
     let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
     tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
     std::env::set_var("PREFIX_CACHING", prefix_caching);
     std::env::set_var("ATTENTION", attention);
+
+    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
+    if num_shard > 1 {
+        if matches!(args.quantize, Some(Quantization::Exl2)) {
+            return Err(LauncherError::ArgumentValidation(
+                "Sharding is currently not supported with `exl2` quantization".into(),
+            ));
+        }
+        tracing::info!("Sharding model on {num_shard} processes");
+    }
 
     let max_input_tokens = {
         match (args.max_input_tokens, args.max_input_length) {
             (Some(max_input_tokens), Some(max_input_length)) => {
@@ -1739,9 +1812,19 @@ fn main() -> Result<(), LauncherError> {
             Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
             None => {
                 // TODO figure out hardware optimal value
-                let value = 4096.min(max_position_embeddings as u32);
+                let compute_type = compute_type(num_shard);
+                tracing::info!("Compute type {compute_type:?}");
+                tracing::info!("Config {config:?}");
+                let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
+                let default = compute_optimal.unwrap_or(4096);
+                let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);
+                let value = if let Some(max_position_embeddings) = max_position_embeddings {
+                    default.min(max_position_embeddings)
+                } else {
+                    default
+                };
                 tracing::info!("Default `max_batch_prefill_tokens` to {value}");
-                value
+                value as u32
             }
         }
     };
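Net effect when `--max-batch-prefill-tokens` is not set: take the compute-derived budget when both card and model are recognized, clamp it to the model's context length, and fall back to 4096 otherwise. Continuing the illustrative L4 + 7B numbers:

    // Walk-through of the new default (assumed numbers, not measured).
    fn main() {
        let compute_optimal: Option<usize> = Some(9342); // from the sketch above
        let default = compute_optimal.unwrap_or(4096);
        let max_position_embeddings: Option<usize> = Some(4096); // e.g. a 4k-context model
        let value = max_position_embeddings.map_or(default, |m| default.min(m));
        assert_eq!(value as u32, 4096); // context length caps the compute-derived budget
    }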
@@ -1796,16 +1879,6 @@ fn main() -> Result<(), LauncherError> {
         );
     }
 
-    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
-    if num_shard > 1 {
-        if matches!(args.quantize, Some(Quantization::Exl2)) {
-            return Err(LauncherError::ArgumentValidation(
-                "Sharding is currently not supported with `exl2` quantization".into(),
-            ));
-        }
-        tracing::info!("Sharding model on {num_shard} processes");
-    }
-
     if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
         if let Some(max_total_tokens) = max_total_tokens {
             if max_total_tokens as u32 > *max_batch_total_tokens {