mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
launcher: correctly get the head dimension for VLMs
For most VLMs, the head dimension is specified in the `text_config` section of the model configuration. However, since we only queried the top-level `head_dim` field (which typically doesn't exist in VLMs), flashinfer was never selected for these models. This change adds a `get_head_dim` method that reads the head dimension from the top-level `Config` struct, falling back to `text_config.head_dim` when the top-level field is absent.
This commit is contained in:
parent
f91434e99b
commit
4727a3af67
@ -152,7 +152,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
|||||||
"flashdecoding"
|
"flashdecoding"
|
||||||
};
|
};
|
||||||
|
|
||||||
match config.head_dim {
|
match config.get_head_dim() {
|
||||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||||
if lora_adapters.is_some() && prefix_caching.is_none() {
|
if lora_adapters.is_some() && prefix_caching.is_none() {
|
||||||
tracing::info!("Disabling prefix caching because of lora adapters");
|
tracing::info!("Disabling prefix caching because of lora adapters");
|
||||||
@ -214,6 +214,7 @@ struct RawConfig {
|
|||||||
num_key_value_heads: Option<usize>,
|
num_key_value_heads: Option<usize>,
|
||||||
num_hidden_layers: Option<usize>,
|
num_hidden_layers: Option<usize>,
|
||||||
head_dim: Option<usize>,
|
head_dim: Option<usize>,
|
||||||
|
text_config: Option<TextConfig>,
|
||||||
vision_config: Option<VisionConfig>,
|
vision_config: Option<VisionConfig>,
|
||||||
is_encoder_decoder: Option<bool>,
|
is_encoder_decoder: Option<bool>,
|
||||||
#[serde(rename = "num_experts_per_tok")]
|
#[serde(rename = "num_experts_per_tok")]
|
||||||
@ -233,6 +234,11 @@ struct QuantizationConfig {
|
|||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct VisionConfig {}
|
struct VisionConfig {}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct TextConfig {
|
||||||
|
head_dim: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct Config {
|
struct Config {
|
||||||
max_position_embeddings: Option<usize>,
|
max_position_embeddings: Option<usize>,
|
||||||
@ -244,6 +250,7 @@ struct Config {
|
|||||||
intermediate_size: Option<usize>,
|
intermediate_size: Option<usize>,
|
||||||
hidden_size: Option<usize>,
|
hidden_size: Option<usize>,
|
||||||
model_type: Option<String>,
|
model_type: Option<String>,
|
||||||
|
text_config: Option<TextConfig>,
|
||||||
vision_config: Option<VisionConfig>,
|
vision_config: Option<VisionConfig>,
|
||||||
is_encoder_decoder: bool,
|
is_encoder_decoder: bool,
|
||||||
num_experts_per_token: usize,
|
num_experts_per_token: usize,
|
||||||
@ -253,6 +260,14 @@ struct Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
fn get_head_dim(&self) -> Option<usize> {
|
||||||
|
self.head_dim.or_else(|| {
|
||||||
|
self.text_config
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|text_config| text_config.head_dim)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn flop(&self) -> Option<u64> {
|
fn flop(&self) -> Option<u64> {
|
||||||
if self.vision_config.is_some() {
|
if self.vision_config.is_some() {
|
||||||
// VLM are much harder to predict and VRAM requirements
|
// VLM are much harder to predict and VRAM requirements
|
||||||
@ -261,7 +276,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
let num_heads = self.num_heads? as u64;
|
let num_heads = self.num_heads? as u64;
|
||||||
let num_kv_heads = self.num_kv_heads? as u64;
|
let num_kv_heads = self.num_kv_heads? as u64;
|
||||||
let head_dim = self.head_dim? as u64;
|
let head_dim = self.get_head_dim()? as u64;
|
||||||
let hidden_size = self.hidden_size? as u64;
|
let hidden_size = self.hidden_size? as u64;
|
||||||
let intermediate_size = (self.intermediate_size?
|
let intermediate_size = (self.intermediate_size?
|
||||||
* (self.num_experts_per_token + self.num_shared_experts))
|
* (self.num_experts_per_token + self.num_shared_experts))
|
||||||
@ -289,7 +304,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
// 2 for key and values
|
// 2 for key and values
|
||||||
// 2 for f16 dtype?
|
// 2 for f16 dtype?
|
||||||
Some(self.num_kv_heads? * 2 * self.head_dim? * 2 * self.num_layers?)
|
Some(self.num_kv_heads? * 2 * self.get_head_dim()? * 2 * self.num_layers?)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mlp_vram_per_tok(&self) -> Option<usize> {
|
fn mlp_vram_per_tok(&self) -> Option<usize> {
|
||||||
@ -310,8 +325,8 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn model_vram(&self) -> Option<usize> {
|
fn model_vram(&self) -> Option<usize> {
|
||||||
let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.head_dim?;
|
let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.get_head_dim()?;
|
||||||
let o_vram = self.num_heads? * self.head_dim? * self.hidden_size?;
|
let o_vram = self.num_heads? * self.get_head_dim()? * self.hidden_size?;
|
||||||
// gate + up + down = 3
|
// gate + up + down = 3
|
||||||
let mlp_vram = 3 * self.intermediate_size? * self.num_experts * self.hidden_size?;
|
let mlp_vram = 3 * self.intermediate_size? * self.num_experts * self.hidden_size?;
|
||||||
let layer_vram = mlp_vram + attn_vram + o_vram;
|
let layer_vram = mlp_vram + attn_vram + o_vram;
|
||||||
@ -349,6 +364,7 @@ impl From<RawConfig> for Config {
|
|||||||
let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
|
let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
|
||||||
let intermediate_size = other.intermediate_size;
|
let intermediate_size = other.intermediate_size;
|
||||||
let model_type = other.model_type;
|
let model_type = other.model_type;
|
||||||
|
let text_config = other.text_config;
|
||||||
let vision_config = other.vision_config;
|
let vision_config = other.vision_config;
|
||||||
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
|
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
|
||||||
let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
|
let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
|
||||||
@ -360,6 +376,7 @@ impl From<RawConfig> for Config {
|
|||||||
quantize,
|
quantize,
|
||||||
head_dim,
|
head_dim,
|
||||||
model_type,
|
model_type,
|
||||||
|
text_config,
|
||||||
vision_config,
|
vision_config,
|
||||||
is_encoder_decoder,
|
is_encoder_decoder,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
|
Loading…
Reference in New Issue
Block a user