mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-29 03:10:18 +00:00
fix: adjust rotary init path
This commit is contained in:
parent
5f416f6e28
commit
6893eb3834
@ -260,6 +260,11 @@ struct Config {
|
|||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
fn flop(&self) -> Option<u64> {
|
fn flop(&self) -> Option<u64> {
|
||||||
|
if self.vision_config.is_some() {
|
||||||
|
// VLM are much harder to predict and VRAM requirements
|
||||||
|
// Are more complex.
|
||||||
|
return None;
|
||||||
|
}
|
||||||
let num_heads = self.num_heads? as u64;
|
let num_heads = self.num_heads? as u64;
|
||||||
let num_kv_heads = self.num_kv_heads? as u64;
|
let num_kv_heads = self.num_kv_heads? as u64;
|
||||||
let head_dim = self.head_dim? as u64;
|
let head_dim = self.head_dim? as u64;
|
||||||
@ -279,50 +284,8 @@ impl Config {
|
|||||||
let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
|
let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
|
||||||
|
|
||||||
let layer_flops = attn_layer_flops + gate_up_down_flops;
|
let layer_flops = attn_layer_flops + gate_up_down_flops;
|
||||||
let text_flops = layer_flops * num_layers;
|
let total = layer_flops * num_layers;
|
||||||
|
Some(total)
|
||||||
tracing::debug!("Text flops: {}", human_size(text_flops as usize, "flop"));
|
|
||||||
|
|
||||||
// text-only case
|
|
||||||
if self.vision_config.is_none() {
|
|
||||||
return Some(text_flops);
|
|
||||||
}
|
|
||||||
|
|
||||||
let vision_config = self.vision_config.as_ref().unwrap();
|
|
||||||
|
|
||||||
// estimate vision flops for specific model types
|
|
||||||
match self.model_type.as_deref() {
|
|
||||||
Some("qwen2_vl") => {
|
|
||||||
let in_chans = vision_config.in_chans? as u64;
|
|
||||||
let patch_size = vision_config.patch_size? as u64;
|
|
||||||
let embed_dim = vision_config.embed_dim? as u64;
|
|
||||||
let vision_depth = vision_config.depth? as u64;
|
|
||||||
let mlp_ratio = vision_config.mlp_ratio? as u64;
|
|
||||||
let temporal_patch_size = vision_config.temporal_patch_size? as u64;
|
|
||||||
// 1. patch embedding:
|
|
||||||
// - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2
|
|
||||||
// where the 2 accounts for multiply-add
|
|
||||||
let patch_flops =
|
|
||||||
2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans;
|
|
||||||
// 2. self-attention + mlp:
|
|
||||||
// - qkv projections: 3 * d_model * d_model * 2
|
|
||||||
// - attention: d_model * d_model * 2
|
|
||||||
// - mlp: 2 * d_model * (mlp_ratio * d_model) * 2
|
|
||||||
// simplified to: 2 * d_model * (4 + mlp_ratio * d_model)
|
|
||||||
let attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim);
|
|
||||||
// 3. add with layer norm flops for total vision layer flops
|
|
||||||
let layer_flops = patch_flops + attn_flops + 2 * embed_dim;
|
|
||||||
let vision_flops = layer_flops * vision_depth;
|
|
||||||
tracing::debug!(
|
|
||||||
"Vision flops: {}",
|
|
||||||
human_size(vision_flops as usize, "flop")
|
|
||||||
);
|
|
||||||
Some(text_flops + vision_flops)
|
|
||||||
}
|
|
||||||
// model has a vision config but is not supported for flops calculation
|
|
||||||
// we return None to avoid overestimating the memory requirements
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kv_vram_per_tok(&self) -> Option<usize> {
|
fn kv_vram_per_tok(&self) -> Option<usize> {
|
||||||
|
@ -101,6 +101,11 @@ class PositionRotaryEmbedding(nn.Module):
|
|||||||
pass
|
pass
|
||||||
elif rope_type == "default":
|
elif rope_type == "default":
|
||||||
pass
|
pass
|
||||||
|
elif rope_type == "mrope":
|
||||||
|
mrope_section = rope_scaling["mrope_section"]
|
||||||
|
return RotaryPositionEmbeddingMultimodalSections(
|
||||||
|
inv_freq, scaling_factor, mrope_section
|
||||||
|
)
|
||||||
elif rope_type == "dynamic":
|
elif rope_type == "dynamic":
|
||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
return DynamicPositionRotaryEmbedding(
|
return DynamicPositionRotaryEmbedding(
|
||||||
@ -576,16 +581,6 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
|||||||
cos: torch.Tensor,
|
cos: torch.Tensor,
|
||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
):
|
):
|
||||||
# process multi-modal rotary embeddings
|
|
||||||
split_cos, split_sin = [
|
|
||||||
torch.split(t, self.sections, dim=-1) for t in (cos, sin)
|
|
||||||
]
|
|
||||||
cos = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1).unsqueeze(
|
|
||||||
1
|
|
||||||
)
|
|
||||||
sin = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1).unsqueeze(
|
|
||||||
1
|
|
||||||
)
|
|
||||||
# prepare input tensors
|
# prepare input tensors
|
||||||
q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)]
|
q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)]
|
||||||
rotary_dim = cos.shape[-1]
|
rotary_dim = cos.shape[-1]
|
||||||
@ -624,10 +619,17 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
|||||||
.unsqueeze(-1)
|
.unsqueeze(-1)
|
||||||
.expand(-1, -1, self._cos_cached_exp.shape[-1])
|
.expand(-1, -1, self._cos_cached_exp.shape[-1])
|
||||||
)
|
)
|
||||||
|
indices = indices.to(dtype=torch.int64)
|
||||||
cos_c = torch.gather(self._cos_cached_exp, 1, indices)
|
cos_c = torch.gather(self._cos_cached_exp, 1, indices)
|
||||||
cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1)
|
cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1)
|
||||||
|
split_cos = torch.split(cos_c, self.sections, dim=-1)
|
||||||
|
cos_c = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
|
||||||
|
cos_c = cos_c.unsqueeze(1)
|
||||||
|
|
||||||
sin_c = torch.gather(self._sin_cached_exp, 1, indices)
|
sin_c = torch.gather(self._sin_cached_exp, 1, indices)
|
||||||
sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1)
|
sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1)
|
||||||
|
split_sin = torch.split(sin_c, self.sections, dim=-1)
|
||||||
|
sin_c = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
|
||||||
|
sin_c = sin_c.unsqueeze(1)
|
||||||
|
|
||||||
return cos_c, sin_c
|
return cos_c, sin_c
|
||||||
|
@ -377,6 +377,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
config.vision_config.quantize = None
|
config.vision_config.quantize = None
|
||||||
config.vision_config.speculator = config.speculator
|
config.vision_config.speculator = config.speculator
|
||||||
|
# set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
|
||||||
|
# returns rope_scaling.type == "default" for Qwen2-VL model at the moment
|
||||||
|
config.rope_scaling.update({"rope_type": "mrope"})
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.vision_start_token_id = config.vision_start_token_id
|
self.vision_start_token_id = config.vision_start_token_id
|
||||||
self.image_token_id = config.image_token_id
|
self.image_token_id = config.image_token_id
|
||||||
|
Loading…
Reference in New Issue
Block a user