From d68edc4a2f5b9f5d52565eab9609b5a1fd6db6b4 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Mon, 9 Jun 2025 22:25:47 -0700 Subject: [PATCH] Qwen2 vl fix Signed-off-by: Wang, Yi A --- .../models/custom_modeling/qwen2_5_vl.py | 62 +++++++------------ .../models/custom_modeling/qwen2_vl.py | 57 ++++++++--------- .../models/flash_causal_lm.py | 15 ++++- .../models/flash_vlm_causal_lm.py | 8 +-- 4 files changed, 64 insertions(+), 78 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index 90cada78..7cd651db 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -45,6 +45,10 @@ from text_generation_server.layers.attention import ( from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2Model, ) +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py from typing import Union @@ -375,28 +379,6 @@ class Qwen2_5_VLConfig(PretrainedConfig): super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - tensor: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output - - class Qwen2_5VLAttention(nn.Module): def __init__(self, *, prefix, config, weights): super().__init__() @@ -426,7 +408,8 @@ class Qwen2_5VLAttention(nn.Module): self, hidden_state: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, max_seqlen: int, ) -> torch.Tensor: # apply the qkv linear layer to the hidden state @@ -444,12 +427,18 @@ class Qwen2_5VLAttention(nn.Module): query = query.view(*_shape) key = key.view(*_shape) value = value.view(*_shape) - # apply rotary positional embeddings - query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( - 0 - ) - key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + rotary_dim = cos.shape[-1] + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape)) + + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape)) # calc maximum sequence length for any batch query = query.contiguous() @@ -533,11 +522,9 @@ class Qwen2_5VLVisionBlock(nn.Module): weights=weights, ) - def forward( - self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen - ) 
-> torch.Tensor: + def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor: norm1_out, _ = self.norm1(hidden_states) - attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen) hidden_states = hidden_states + attn_out norm2_out, _ = self.norm2(hidden_states) mlp_out = self.mlp(norm2_out) @@ -736,6 +723,10 @@ class Qwen2_5VisionModel(nn.Module): ) rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device) + cos = rotary_pos_emb.cos() + sin = rotary_pos_emb.sin() + cos = torch.cat((cos, cos), dim=-1).unsqueeze(1) + sin = torch.cat((sin, sin), dim=-1).unsqueeze(1) cu_window_seqlens = torch.tensor( cu_window_seqlens, @@ -762,9 +753,7 @@ class Qwen2_5VisionModel(nn.Module): else: cu_seqlens_now = cu_window_seqlens - hidden_states = block( - hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen - ) + hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen) # apply the final patch merger to the hidden states hidden_states = self.merger(hidden_states) @@ -886,9 +875,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): full_llm_pos_ids_list = [ item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist ] - # import ipdb - - # ipdb.set_trace() max_s = full_llm_pos_ids_list[-1].max() + 1 final_text_len = input_ids_len - vision_ends[-1] if final_text_len > 0: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 2fddc0e2..d9c07f7d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -44,28 +44,10 @@ from text_generation_server.layers.attention import ( from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2Model, ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - tensor: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output +from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, +) class Qwen2VLAttention(nn.Module): @@ -96,7 +78,8 @@ class Qwen2VLAttention(nn.Module): self, hidden_state: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, max_seqlen: int, ) -> torch.Tensor: # apply the qkv linear layer to the hidden state @@ -116,10 +99,17 @@ class Qwen2VLAttention(nn.Module): value = value.view(*_shape) # apply rotary positional embeddings - query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( - 0 - ) - key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + rotary_dim = cos.shape[-1] + query_rot = query[..., :rotary_dim] + query_pass = query[..., rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, 
cos, sin, None, 0, rope_mode) + query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape)) + + key_rot = key[..., :rotary_dim] + key_pass = key[..., rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape)) # calc maximum sequence length for any batch query = query.contiguous() @@ -193,11 +183,9 @@ class Qwen2VLVisionBlock(nn.Module): weights=weights, ) - def forward( - self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen - ) -> torch.Tensor: + def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor: norm1_out, residual = self.norm1(hidden_states) - attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen) hidden_states = attn_out + residual norm2_out, residual = self.norm2(hidden_states) hidden_states = hidden_states + self.mlp(norm2_out) @@ -330,6 +318,11 @@ class Qwen2VisionModel(nn.Module): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype) + cos = rotary_pos_emb.cos() + sin = rotary_pos_emb.sin() + cos = torch.cat((cos, cos), dim=-1).unsqueeze(1) + sin = torch.cat((sin, sin), dim=-1).unsqueeze(1) + # create a cu_seqlens tensor to be used in the attention mask cu_seqlens = torch.repeat_interleave( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] @@ -338,7 +331,7 @@ class Qwen2VisionModel(nn.Module): max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) # iterately apply the blocks to the hidden states for block in self.blocks: - hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen) + hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen) # apply the final patch merger to the hidden states hidden_states = self.merger(hidden_states) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index d00494c9..a9b9f811 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1000,9 +1000,18 @@ class FlashCausalLMBatch(Batch): self.input_ids = F.pad( self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0 ) - self.position_ids = F.pad( - self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1 - ) + + if self.position_ids.dim() == 2: + # Qwen VL case + self.position_ids = F.pad( + self.position_ids, + (0, 0, 0, padded_bs - self.position_ids.shape[0]), + value=1, + ) + else: + self.position_ids = F.pad( + self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1 + ) self.input_lengths_tensor = F.pad( self.input_lengths_tensor, (0, padded_bs - self.input_lengths_tensor.shape[0]), diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index f0129013..086c05e7 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -619,9 +619,7 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch): if "image_grid_thw" in x[2] ] if image_grid_thw_list: - self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to( - self.input_ids.device - ) + self.image_grid_thw = torch.cat(image_grid_thw_list, 
dim=0) else: self.image_grid_thw = None @@ -898,7 +896,7 @@ class FlashVlmCausalLM(FlashCausalLM): image_sizes = None if "image_grid_thw" in image_input: - image_grid_thw = image_input["image_grid_thw"].to(device) + image_grid_thw = image_input["image_grid_thw"] else: image_grid_thw = None @@ -992,7 +990,7 @@ class FlashVlmCausalLM(FlashCausalLM): if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}: if position_ids.dim() == 1 and batch.prefilling: position_ids = self.model.get_position_ids( - input_ids, batch.image_grid_thw + input_ids.cpu(), batch.image_grid_thw ) batch.position_ids = position_ids
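
Note on the vision RoPE change in both qwen2_vl.py and qwen2_5_vl.py: the removed rotate_half-based apply_rotary_pos_emb_vision is replaced by the Gaudi HPU kernel apply_rotary_pos_emb in RotaryPosEmbeddingMode.BLOCKWISE, with cos/sin precomputed once per forward pass as torch.cat((c, c), dim=-1).unsqueeze(1). Below is a minimal pure-PyTorch sketch of the rotation the kernel is expected to apply, assuming blockwise mode matches the rotate-half convention of the deleted helper; the function name rope_reference is illustrative and not part of the patch.

    import torch

    def rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # x: [seq, num_heads, head_dim]; cos/sin: [seq, 1, head_dim], built from
        # freqs of shape [seq, head_dim // 2] via torch.cat((c, c), dim=-1).unsqueeze(1)
        # as in Qwen2VisionModel / Qwen2_5VisionModel above.
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        rotated = torch.cat((-x2, x1), dim=-1)  # rotate_half: (x1, x2) -> (-x2, x1)
        return x * cos + rotated * sin

Since cos here spans the full head_dim (two concatenated halves of the vision freqs), the query_pass/key_pass slices should be empty; the rotary_dim split is presumably kept so the same call pattern also covers partial-rotary configurations.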
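
Note on the flash_causal_lm.py padding change: Qwen2-VL / Qwen2.5-VL position ids are 2-D (one row per slot carrying the m-RoPE components), so padding the batch dimension needs the extra leading (0, 0) pair, because F.pad addresses the last dimension first. A small sketch with made-up values; the [bs, 3] shape is an assumption for illustration, not taken from a real batch.

    import torch
    import torch.nn.functional as F

    position_ids = torch.tensor([[5, 5, 5],
                                 [6, 6, 6]])  # assumed [bs=2, 3] m-RoPE ids, illustrative values
    padded_bs = 4

    # (0, 0) leaves the last dim untouched; (0, padded_bs - bs) appends rows filled with 1.
    padded = F.pad(position_ids, (0, 0, 0, padded_bs - position_ids.shape[0]), value=1)
    print(padded.shape)  # torch.Size([4, 3])

    # The 1-D (non-VL) branch keeps the original single-pair padding:
    flat = torch.tensor([5, 6])
    print(F.pad(flat, (0, padded_bs - flat.shape[0]), value=1))  # tensor([5, 6, 1, 1])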