From 6893eb3834898a658b1a81526c1a6b18ab10e79b Mon Sep 17 00:00:00 2001
From: drbh
Date: Mon, 27 Jan 2025 16:02:51 +0000
Subject: [PATCH] fix: adjust rotary init path

---
 launcher/src/main.rs                          | 51 +++----------
 .../text_generation_server/layers/rotary.py  | 22 ++++----
 .../models/custom_modeling/qwen2_vl.py       |  3 ++
 3 files changed, 22 insertions(+), 54 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index a09ceb31..6391f9eb 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -260,6 +260,11 @@ struct Config {
 
 impl Config {
     fn flop(&self) -> Option<u64> {
+        if self.vision_config.is_some() {
+            // VLMs are much harder to predict and their VRAM
+            // requirements are more complex.
+            return None;
+        }
         let num_heads = self.num_heads? as u64;
         let num_kv_heads = self.num_kv_heads? as u64;
         let head_dim = self.head_dim? as u64;
@@ -279,50 +284,8 @@ impl Config {
         let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
 
         let layer_flops = attn_layer_flops + gate_up_down_flops;
-        let text_flops = layer_flops * num_layers;
-
-        tracing::debug!("Text flops: {}", human_size(text_flops as usize, "flop"));
-
-        // text-only case
-        if self.vision_config.is_none() {
-            return Some(text_flops);
-        }
-
-        let vision_config = self.vision_config.as_ref().unwrap();
-
-        // estimate vision flops for specific model types
-        match self.model_type.as_deref() {
-            Some("qwen2_vl") => {
-                let in_chans = vision_config.in_chans? as u64;
-                let patch_size = vision_config.patch_size? as u64;
-                let embed_dim = vision_config.embed_dim? as u64;
-                let vision_depth = vision_config.depth? as u64;
-                let mlp_ratio = vision_config.mlp_ratio? as u64;
-                let temporal_patch_size = vision_config.temporal_patch_size? as u64;
-                // 1. patch embedding:
-                //    - conv3d operation: (t*h*w) * (k_t*k_h*k_w) * c_in * c_out * 2
-                //      where the 2 accounts for multiply-add
-                let patch_flops =
-                    2 * temporal_patch_size * patch_size.pow(2) * embed_dim * in_chans;
-                // 2. self-attention + mlp:
-                //    - qkv projections: 3 * d_model * d_model * 2
-                //    - attention: d_model * d_model * 2
-                //    - mlp: 2 * d_model * (mlp_ratio * d_model) * 2
-                //      simplified to: 2 * d_model * (4 + mlp_ratio * d_model)
-                let attn_flops = 2 * embed_dim * (4 + mlp_ratio * embed_dim);
-                // 3. add with layer norm flops for total vision layer flops
-                let layer_flops = patch_flops + attn_flops + 2 * embed_dim;
-                let vision_flops = layer_flops * vision_depth;
-                tracing::debug!(
-                    "Vision flops: {}",
-                    human_size(vision_flops as usize, "flop")
-                );
-                Some(text_flops + vision_flops)
-            }
-            // model has a vision config but is not supported for flops calculation
-            // we return None to avoid overestimating the memory requirements
-            _ => None,
-        }
+        let total = layer_flops * num_layers;
+        Some(total)
     }
 
     fn kv_vram_per_tok(&self) -> Option<usize> {
diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 9f1770ff..7b3500e3 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -101,6 +101,11 @@ class PositionRotaryEmbedding(nn.Module):
                 pass
             elif rope_type == "default":
                 pass
+            elif rope_type == "mrope":
+                mrope_section = rope_scaling["mrope_section"]
+                return RotaryPositionEmbeddingMultimodalSections(
+                    inv_freq, scaling_factor, mrope_section
+                )
             elif rope_type == "dynamic":
                 scaling_factor = rope_scaling["factor"]
                 return DynamicPositionRotaryEmbedding(
@@ -576,16 +581,6 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ):
-        # process multi-modal rotary embeddings
-        split_cos, split_sin = [
-            torch.split(t, self.sections, dim=-1) for t in (cos, sin)
-        ]
-        cos = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1).unsqueeze(
-            1
-        )
-        sin = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1).unsqueeze(
-            1
-        )
         # prepare input tensors
         q, k = [x.transpose(0, 1).unsqueeze(0) for x in (query, key)]
         rotary_dim = cos.shape[-1]
@@ -624,10 +619,17 @@ class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
             .unsqueeze(-1)
             .expand(-1, -1, self._cos_cached_exp.shape[-1])
         )
+        indices = indices.to(dtype=torch.int64)
         cos_c = torch.gather(self._cos_cached_exp, 1, indices)
         cos_c = torch.cat([cos_c, cos_c], dim=-1).unsqueeze(1)
+        split_cos = torch.split(cos_c, self.sections, dim=-1)
+        cos_c = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
+        cos_c = cos_c.unsqueeze(1)
         sin_c = torch.gather(self._sin_cached_exp, 1, indices)
         sin_c = torch.cat([sin_c, sin_c], dim=-1).unsqueeze(1)
+        split_sin = torch.split(sin_c, self.sections, dim=-1)
+        sin_c = torch.cat([m[i % 3] for i, m in enumerate(split_sin)], dim=-1)
+        sin_c = sin_c.unsqueeze(1)
 
         return cos_c, sin_c
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index e0ae19df..7e296b42 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -377,6 +377,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.config = config
         config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
+        # set rope_scaling.rope_type to "mrope" since AutoConfig.from_pretrained
+        # incorrectly returns rope_type == "default" for Qwen2-VL models at the moment
+        config.rope_scaling.update({"rope_type": "mrope"})
         self.hidden_size = config.hidden_size
         self.vision_start_token_id = config.vision_start_token_id
         self.image_token_id = config.image_token_id
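
Note (not part of the patch): after this change, Config::flop only multiplies per-layer text flops by the layer count. A minimal sketch of that arithmetic in Python, with assumed Llama-7B-like dimensions; attn_layer_flops is a placeholder here because its definition sits outside the visible hunk:

    # Sketch only: mirrors the surviving text-only estimate in Config::flop.
    # All dimensions below are assumed, not read from any real config.
    hidden_size = 4096
    intermediate_size = 11008
    num_layers = 32
    # placeholder for the attention term defined outside this hunk
    attn_layer_flops = 2 * 4 * hidden_size * hidden_size

    # gate/up/down projections: 3 matmuls, 2 flops per multiply-add
    gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size

    layer_flops = attn_layer_flops + gate_up_down_flops
    total = layer_flops * num_layers  # per-token flops across all layers
    print(f"{total / 1e9:.1f} GFLOPs per token")  # ~13.0 with these numbers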
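
Note (not part of the patch): the rotary change moves Qwen2-VL's multimodal section interleave out of the forward path and into the cos/sin cache lookup. A standalone sketch of that interleave, assuming self.sections is the mrope_section list repeated twice (consistent with splitting the doubled last dim above); the section widths are hypothetical:

    import torch

    # Assumed mrope section widths (temporal / height / width); their sum
    # is half the rotary dim.
    mrope_section = [16, 24, 24]
    seq_len = 10
    half_dim = sum(mrope_section)

    # One cos table per position-id axis (t, h, w), as produced by gathering
    # the cache with three sets of position ids, then doubled along the last
    # dim as in the patched code.
    cos = torch.randn(3, seq_len, half_dim)
    cos = torch.cat([cos, cos], dim=-1)  # (3, seq_len, 2 * half_dim)

    # Split the doubled dim with the section list repeated; chunk i takes its
    # values from axis i % 3, interleaving t/h/w frequencies per section.
    sections = mrope_section * 2
    split_cos = torch.split(cos, sections, dim=-1)
    cos = torch.cat([m[i % 3] for i, m in enumerate(split_cos)], dim=-1)
    print(cos.shape)  # torch.Size([10, 128]): per-token cos, axes merged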