mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 06:12:07 +00:00
launcher: correctly get the head dimension for VLMs
For most (?) VLMs, the head dimension is in the `text_config` configuration section. However, since we only queried the top-level `head_dim` (which typically doesn't exist in VLMs), we would never use flashinfer. This change adds a method that gets the head dimension from the top-level `Config` struct or `text_config` when that fails.
This commit is contained in:
parent
f91434e99b
commit
4727a3af67
@ -152,7 +152,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||
"flashdecoding"
|
||||
};
|
||||
|
||||
match config.head_dim {
|
||||
match config.get_head_dim() {
|
||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||
if lora_adapters.is_some() && prefix_caching.is_none() {
|
||||
tracing::info!("Disabling prefix caching because of lora adapters");
|
||||
@ -214,6 +214,7 @@ struct RawConfig {
|
||||
num_key_value_heads: Option<usize>,
|
||||
num_hidden_layers: Option<usize>,
|
||||
head_dim: Option<usize>,
|
||||
text_config: Option<TextConfig>,
|
||||
vision_config: Option<VisionConfig>,
|
||||
is_encoder_decoder: Option<bool>,
|
||||
#[serde(rename = "num_experts_per_tok")]
|
||||
@ -233,6 +234,11 @@ struct QuantizationConfig {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct VisionConfig {}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TextConfig {
|
||||
head_dim: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Config {
|
||||
max_position_embeddings: Option<usize>,
|
||||
@ -244,6 +250,7 @@ struct Config {
|
||||
intermediate_size: Option<usize>,
|
||||
hidden_size: Option<usize>,
|
||||
model_type: Option<String>,
|
||||
text_config: Option<TextConfig>,
|
||||
vision_config: Option<VisionConfig>,
|
||||
is_encoder_decoder: bool,
|
||||
num_experts_per_token: usize,
|
||||
@ -253,6 +260,14 @@ struct Config {
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn get_head_dim(&self) -> Option<usize> {
|
||||
self.head_dim.or_else(|| {
|
||||
self.text_config
|
||||
.as_ref()
|
||||
.and_then(|text_config| text_config.head_dim)
|
||||
})
|
||||
}
|
||||
|
||||
fn flop(&self) -> Option<u64> {
|
||||
if self.vision_config.is_some() {
|
||||
// VLM are much harder to predict and VRAM requirements
|
||||
@ -261,7 +276,7 @@ impl Config {
|
||||
}
|
||||
let num_heads = self.num_heads? as u64;
|
||||
let num_kv_heads = self.num_kv_heads? as u64;
|
||||
let head_dim = self.head_dim? as u64;
|
||||
let head_dim = self.get_head_dim()? as u64;
|
||||
let hidden_size = self.hidden_size? as u64;
|
||||
let intermediate_size = (self.intermediate_size?
|
||||
* (self.num_experts_per_token + self.num_shared_experts))
|
||||
@ -289,7 +304,7 @@ impl Config {
|
||||
}
|
||||
// 2 for key and values
|
||||
// 2 for f16 dtype?
|
||||
Some(self.num_kv_heads? * 2 * self.head_dim? * 2 * self.num_layers?)
|
||||
Some(self.num_kv_heads? * 2 * self.get_head_dim()? * 2 * self.num_layers?)
|
||||
}
|
||||
|
||||
fn mlp_vram_per_tok(&self) -> Option<usize> {
|
||||
@ -310,8 +325,8 @@ impl Config {
|
||||
}
|
||||
|
||||
fn model_vram(&self) -> Option<usize> {
|
||||
let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.head_dim?;
|
||||
let o_vram = self.num_heads? * self.head_dim? * self.hidden_size?;
|
||||
let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.get_head_dim()?;
|
||||
let o_vram = self.num_heads? * self.get_head_dim()? * self.hidden_size?;
|
||||
// gate + up + down = 3
|
||||
let mlp_vram = 3 * self.intermediate_size? * self.num_experts * self.hidden_size?;
|
||||
let layer_vram = mlp_vram + attn_vram + o_vram;
|
||||
@ -349,6 +364,7 @@ impl From<RawConfig> for Config {
|
||||
let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
|
||||
let intermediate_size = other.intermediate_size;
|
||||
let model_type = other.model_type;
|
||||
let text_config = other.text_config;
|
||||
let vision_config = other.vision_config;
|
||||
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
|
||||
let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
|
||||
@ -360,6 +376,7 @@ impl From<RawConfig> for Config {
|
||||
quantize,
|
||||
head_dim,
|
||||
model_type,
|
||||
text_config,
|
||||
vision_config,
|
||||
is_encoder_decoder,
|
||||
hidden_size,
|
||||
|
Loading…
Reference in New Issue
Block a user