Qwen2 VL fix: use the Habana fused rotary embedding kernel in the vision towers, pad 2D position ids correctly, and keep position-id inputs on CPU

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date: 2025-06-09 22:25:47 -07:00
parent 93e5e35f9d
commit d68edc4a2f
4 changed files with 64 additions and 78 deletions


@@ -45,6 +45,10 @@ from text_generation_server.layers.attention import (
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
from typing import Union
@@ -375,28 +379,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
class Qwen2_5VLAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
@@ -426,7 +408,8 @@ class Qwen2_5VLAttention(nn.Module):
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
@@ -444,12 +427,18 @@ class Qwen2_5VLAttention(nn.Module):
query = query.view(*_shape)
key = key.view(*_shape)
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
rotary_dim = cos.shape[-1]
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
# calc maximum sequence length for any batch
query = query.contiguous()
@@ -533,11 +522,9 @@ class Qwen2_5VLVisionBlock(nn.Module):
weights=weights,
)
def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
norm1_out, _ = self.norm1(hidden_states)
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
hidden_states = hidden_states + attn_out
norm2_out, _ = self.norm2(hidden_states)
mlp_out = self.mlp(norm2_out)
@@ -736,6 +723,10 @@ class Qwen2_5VisionModel(nn.Module):
)
rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
cos = rotary_pos_emb.cos()
sin = rotary_pos_emb.sin()
cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
@@ -762,9 +753,7 @@ class Qwen2_5VisionModel(nn.Module):
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = block(
hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
)
hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states)
@@ -886,9 +875,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
full_llm_pos_ids_list = [
item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
]
# import ipdb
# ipdb.set_trace()
max_s = full_llm_pos_ids_list[-1].max() + 1
final_text_len = input_ids_len - vision_ends[-1]
if final_text_len > 0:
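
For readers who want the rotary change above in one place: the sketch below condenses the new path, in which cos/sin are computed once in the vision model's forward and every attention layer calls the Habana fused kernel instead of the removed apply_rotary_pos_emb_vision helper. It is a minimal sketch, not the file contents; apply_vision_rope_hpu is a hypothetical helper name, the shapes in the comments are assumptions, and it only runs on a Gaudi build where habana_frameworks is importable.

import torch
from habana_frameworks.torch.hpex.kernels import (
    RotaryPosEmbeddingMode,
    apply_rotary_pos_emb,
)

def apply_vision_rope_hpu(query, key, rotary_pos_emb):
    # Assumed shapes: rotary_pos_emb (seq_len, head_dim // 2),
    # query/key (seq_len, num_heads, head_dim).
    cos = rotary_pos_emb.cos()
    sin = rotary_pos_emb.sin()
    # Duplicate along the last dim so cos/sin span the full rotary width.
    cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
    sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)

    rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
    rotary_dim = cos.shape[-1]

    # Rotate the first rotary_dim channels, pass the rest through unchanged.
    query_rot, query_pass = query[..., :rotary_dim], query[..., rotary_dim:]
    key_rot, key_pass = key[..., :rotary_dim], key[..., rotary_dim:]
    query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
    key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
    return (
        torch.cat((query_rot, query_pass), dim=-1),
        torch.cat((key_rot, key_pass), dim=-1),
    )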


@@ -44,28 +44,10 @@ from text_generation_server.layers.attention import (
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
class Qwen2VLAttention(nn.Module):
@@ -96,7 +78,8 @@ class Qwen2VLAttention(nn.Module):
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
@@ -116,10 +99,17 @@ class Qwen2VLAttention(nn.Module):
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
rotary_dim = cos.shape[-1]
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
# calc maximum sequence length for any batch
query = query.contiguous()
@@ -193,11 +183,9 @@ class Qwen2VLVisionBlock(nn.Module):
weights=weights,
)
def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
norm1_out, residual = self.norm1(hidden_states)
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
hidden_states = attn_out + residual
norm2_out, residual = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(norm2_out)
@@ -330,6 +318,11 @@ class Qwen2VisionModel(nn.Module):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
cos = rotary_pos_emb.cos()
sin = rotary_pos_emb.sin()
cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
# create a cu_seqlens tensor to be used in the attention mask
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -338,7 +331,7 @@ class Qwen2VisionModel(nn.Module):
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
# iteratively apply the blocks to the hidden states
for block in self.blocks:
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states)
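
The same kernel swap is applied to the plain Qwen2 VL vision tower above. One detail worth checking: the removed reference code built its cos/sin with unsqueeze(1).repeat(1, 1, 2), while the new code uses torch.cat((cos, cos), dim=-1).unsqueeze(1); these produce the same duplicated layout, and the reference rotation was tensor * cos + rotate_half(tensor) * sin, which is the rotate_half-style math the BLOCKWISE kernel mode is assumed to implement. A small CPU-only check of the layout equivalence, written against the removed helpers and involving no Habana kernel:

import torch

def rotate_half(x):
    # Same helper the commit removes; kept here only for the comparison.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

seq_len, head_dim = 16, 80
freqs = torch.randn(seq_len, head_dim // 2)

# Old layout: tile the half-width cos along the last dim via repeat().
cos_old = freqs.cos().unsqueeze(1).repeat(1, 1, 2)
# New layout: concatenate the half-width cos with itself, then add a broadcast dim.
cos_new = torch.cat((freqs.cos(), freqs.cos()), dim=-1).unsqueeze(1)
print(torch.equal(cos_old, cos_new))  # True: identical (seq_len, 1, head_dim) tensors

# Reference rotation the fused kernel is expected to reproduce on the
# rotated slice: t * cos + rotate_half(t) * sin.
t = torch.randn(seq_len, 4, head_dim)
sin_new = torch.cat((freqs.sin(), freqs.sin()), dim=-1).unsqueeze(1)
reference = t * cos_new + rotate_half(t) * sin_new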


@@ -1000,6 +1000,15 @@ class FlashCausalLMBatch(Batch):
self.input_ids = F.pad(
self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
)
if self.position_ids.dim() == 2:
# Qwen VL case
self.position_ids = F.pad(
self.position_ids,
(0, 0, 0, padded_bs - self.position_ids.shape[0]),
value=1,
)
else:
self.position_ids = F.pad(
self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
)
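
The new branch above pads 2D position ids (the Qwen VL case, where each position carries several rope coordinates) along the batch dimension only, instead of sending them through the 1D padding path. A small illustration of the F.pad semantics involved; the (rows, 3) shape is an assumption for the multi-section rope layout, and padded_bs is arbitrary:

import torch
import torch.nn.functional as F

padded_bs = 8

# 1D case: position ids of shape (batch,), pad on the right of the only dim.
pos_1d = torch.arange(5)
pos_1d = F.pad(pos_1d, (0, padded_bs - pos_1d.shape[0]), value=1)
print(pos_1d.shape)  # torch.Size([8])

# 2D case: F.pad reads pairs from the last dim backwards, so (0, 0, 0, n)
# leaves the coordinate dim alone and appends n padded rows.
pos_2d = torch.ones(5, 3, dtype=torch.long)
pos_2d = F.pad(pos_2d, (0, 0, 0, padded_bs - pos_2d.shape[0]), value=1)
print(pos_2d.shape)  # torch.Size([8, 3])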


@@ -619,9 +619,7 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
if "image_grid_thw" in x[2]
]
if image_grid_thw_list:
self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to(
self.input_ids.device
)
self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
else:
self.image_grid_thw = None
@@ -898,7 +896,7 @@ class FlashVlmCausalLM(FlashCausalLM):
image_sizes = None
if "image_grid_thw" in image_input:
image_grid_thw = image_input["image_grid_thw"].to(device)
image_grid_thw = image_input["image_grid_thw"]
else:
image_grid_thw = None
@@ -992,7 +990,7 @@ class FlashVlmCausalLM(FlashCausalLM):
if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids(
input_ids, batch.image_grid_thw
input_ids.cpu(), batch.image_grid_thw
)
batch.position_ids = position_ids
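
The last two hunks stop moving image_grid_thw to the device and hand get_position_ids a CPU copy of input_ids, so the multimodal position-id bookkeeping runs on host tensors. A hedged sketch of the resulting call site as a free-standing function; the function name and parameter list are hypothetical, and only the inner logic mirrors the diff:

def maybe_recompute_position_ids(model, batch, input_ids, position_ids):
    # Hypothetical stand-alone version of the call site changed above.
    if model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
        if position_ids.dim() == 1 and batch.prefilling:
            # image_grid_thw was kept on CPU when the batch was built, and
            # input_ids is handed over as a CPU copy, so the position-id
            # computation does not touch HPU tensors.
            position_ids = model.get_position_ids(
                input_ids.cpu(), batch.image_grid_thw
            )
        batch.position_ids = position_ids
    return position_ids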