Qwen2 VL fix

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-06-09 22:25:47 -07:00
parent 93e5e35f9d
commit d68edc4a2f
4 changed files with 64 additions and 78 deletions

View File

@@ -45,6 +45,10 @@ from text_generation_server.layers.attention import (
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2Model,
 )
+from habana_frameworks.torch.hpex.kernels import (
+    RotaryPosEmbeddingMode,
+    apply_rotary_pos_emb,
+)
 
 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
 from typing import Union
@@ -375,28 +379,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
 
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    tensor: torch.Tensor, freqs: torch.Tensor
-) -> torch.Tensor:
-    orig_dtype = tensor.dtype
-    tensor = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    output = (tensor * cos) + (rotate_half(tensor) * sin)
-    output = output.to(orig_dtype)
-    return output
-
-
 class Qwen2_5VLAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
@@ -426,7 +408,8 @@ class Qwen2_5VLAttention(nn.Module):
         self,
         hidden_state: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
         max_seqlen: int,
     ) -> torch.Tensor:
         # apply the qkv linear layer to the hidden state
@@ -444,12 +427,18 @@ class Qwen2_5VLAttention(nn.Module):
         query = query.view(*_shape)
         key = key.view(*_shape)
         value = value.view(*_shape)
 
         # apply rotary positional embeddings
-        query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
-            0
-        )
-        key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
+        rotary_dim = cos.shape[-1]
+        query_rot = query[..., :rotary_dim]
+        query_pass = query[..., rotary_dim:]
+        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
+        key_rot = key[..., :rotary_dim]
+        key_pass = key[..., rotary_dim:]
+        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
         # calc maximum sequence length for any batch
         query = query.contiguous()
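
Note on the rotary change above: the deleted apply_rotary_pos_emb_vision duplicated cos/sin inside the helper and applied the NeoX-style x * cos + rotate_half(x) * sin rotation, while the new path precomputes the duplicated tables once and hands them to the HPU kernel in blockwise mode, which is expected to compute the same rotation. A minimal CPU-only sketch of that equivalence; the tensor names and sizes are illustrative, not taken from the repository:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # NeoX-style rotation: swap the two halves and negate the second one.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Illustrative shapes: 8 patches, 4 heads, head_dim 16, freqs of width head_dim // 2.
seq, heads, head_dim = 8, 4, 16
freqs = torch.randn(seq, head_dim // 2)
query = torch.randn(seq, heads, head_dim)

# Old path: duplicate cos/sin inside the helper, then x * cos + rotate_half(x) * sin.
cos_old = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)
sin_old = freqs.sin().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)
old = (query.unsqueeze(0) * cos_old + rotate_half(query.unsqueeze(0)) * sin_old).squeeze(0)

# New path: cos/sin are duplicated once up front, mirroring the vision model's
# torch.cat((cos, cos), dim=-1).unsqueeze(1), and the same formula is applied.
cos_new = torch.cat((freqs.cos(), freqs.cos()), dim=-1).unsqueeze(1)
sin_new = torch.cat((freqs.sin(), freqs.sin()), dim=-1).unsqueeze(1)
new = query * cos_new + rotate_half(query) * sin_new

torch.testing.assert_close(old, new)

If, as the duplicated table suggests, cos spans the full head dimension, then rotary_dim equals head_dim, the query_pass/key_pass slices are empty, and the copy_ calls simply write the rotated tensors back in place.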
@@ -533,11 +522,9 @@ class Qwen2_5VLVisionBlock(nn.Module):
             weights=weights,
         )
 
-    def forward(
-        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
-    ) -> torch.Tensor:
+    def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
         norm1_out, _ = self.norm1(hidden_states)
-        attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
+        attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
         hidden_states = hidden_states + attn_out
         norm2_out, _ = self.norm2(hidden_states)
         mlp_out = self.mlp(norm2_out)
@@ -736,6 +723,10 @@ class Qwen2_5VisionModel(nn.Module):
         )
 
         rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
+        cos = rotary_pos_emb.cos()
+        sin = rotary_pos_emb.sin()
+        cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
+        sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
 
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
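
Hoisting cos/sin out of the per-block loop (here and where the blocks are called below) does the trigonometric work once per image instead of once per vision layer. A small shape sketch with made-up sizes; the real dimensions come from the vision config:

import torch

# Hypothetical sizes: 6 patches, rotary table of width head_dim // 2 = 20.
num_patches, half_dim = 6, 20
rotary_pos_emb = torch.randn(num_patches, half_dim)

cos = rotary_pos_emb.cos()                        # (num_patches, half_dim)
sin = rotary_pos_emb.sin()
cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)  # (num_patches, 1, 2 * half_dim)
sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)

# The singleton middle dimension broadcasts over the attention heads, so every head
# in every vision block reuses the same per-patch table.
print(cos.shape)  # torch.Size([6, 1, 40])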
@@ -762,9 +753,7 @@ class Qwen2_5VisionModel(nn.Module):
             else:
                 cu_seqlens_now = cu_window_seqlens
 
-            hidden_states = block(
-                hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
-            )
+            hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)
 
         # apply the final patch merger to the hidden states
         hidden_states = self.merger(hidden_states)
@@ -886,9 +875,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         full_llm_pos_ids_list = [
             item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
         ]
-        # import ipdb
-        # ipdb.set_trace()
-
         max_s = full_llm_pos_ids_list[-1].max() + 1
         final_text_len = input_ids_len - vision_ends[-1]
         if final_text_len > 0:

View File

@@ -44,28 +44,10 @@ from text_generation_server.layers.attention import (
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2Model,
 )
-
-
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    tensor: torch.Tensor, freqs: torch.Tensor
-) -> torch.Tensor:
-    orig_dtype = tensor.dtype
-    tensor = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    output = (tensor * cos) + (rotate_half(tensor) * sin)
-    output = output.to(orig_dtype)
-    return output
+from habana_frameworks.torch.hpex.kernels import (
+    RotaryPosEmbeddingMode,
+    apply_rotary_pos_emb,
+)
 
 
 class Qwen2VLAttention(nn.Module):
@@ -96,7 +78,8 @@ class Qwen2VLAttention(nn.Module):
         self,
         hidden_state: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
         max_seqlen: int,
     ) -> torch.Tensor:
         # apply the qkv linear layer to the hidden state
@@ -116,10 +99,17 @@ class Qwen2VLAttention(nn.Module):
         value = value.view(*_shape)
 
         # apply rotary positional embeddings
-        query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
-            0
-        )
-        key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
+        rotary_dim = cos.shape[-1]
+        query_rot = query[..., :rotary_dim]
+        query_pass = query[..., rotary_dim:]
+        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
+        key_rot = key[..., :rotary_dim]
+        key_pass = key[..., rotary_dim:]
+        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
         # calc maximum sequence length for any batch
         query = query.contiguous()
@@ -193,11 +183,9 @@ class Qwen2VLVisionBlock(nn.Module):
             weights=weights,
         )
 
-    def forward(
-        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
-    ) -> torch.Tensor:
+    def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
         norm1_out, residual = self.norm1(hidden_states)
-        attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
+        attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
         hidden_states = attn_out + residual
         norm2_out, residual = self.norm2(hidden_states)
         hidden_states = hidden_states + self.mlp(norm2_out)
@@ -330,6 +318,11 @@ class Qwen2VisionModel(nn.Module):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
+        cos = rotary_pos_emb.cos()
+        sin = rotary_pos_emb.sin()
+        cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
+        sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
+
         # create a cu_seqlens tensor to be used in the attention mask
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -338,7 +331,7 @@ class Qwen2VisionModel(nn.Module):
         max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
         # iterately apply the blocks to the hidden states
         for block in self.blocks:
-            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
+            hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)
 
         # apply the final patch merger to the hidden states
         hidden_states = self.merger(hidden_states)

View File

@@ -1000,9 +1000,18 @@ class FlashCausalLMBatch(Batch):
             self.input_ids = F.pad(
                 self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
             )
-            self.position_ids = F.pad(
-                self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
-            )
+            if self.position_ids.dim() == 2:
+                # Qwen VL case
+                self.position_ids = F.pad(
+                    self.position_ids,
+                    (0, 0, 0, padded_bs - self.position_ids.shape[0]),
+                    value=1,
+                )
+            else:
+                self.position_ids = F.pad(
+                    self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
+                )
             self.input_lengths_tensor = F.pad(
                 self.input_lengths_tensor,
                 (0, padded_bs - self.input_lengths_tensor.shape[0]),
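
The branch above works because F.pad reads its padding tuple starting from the last dimension: (0, 0, 0, n) leaves the trailing mrope dimension untouched and appends n padded rows along dim 0, whereas the 1-D case keeps the original single pair. A quick sketch with made-up batch contents; the width of 3 is only illustrative:

import torch
import torch.nn.functional as F

padded_bs = 4

# Regular case: 1-D position ids, one value per sequence in the batch.
position_ids_1d = torch.tensor([7, 3])
padded_1d = F.pad(position_ids_1d, (0, padded_bs - position_ids_1d.shape[0]), value=1)
print(padded_1d)        # tensor([7, 3, 1, 1])

# Qwen VL case: 2-D position ids, one row of mrope coordinates per sequence.
# (0, 0, 0, n) pads nothing on the last dimension and n rows at the end of dim 0.
position_ids_2d = torch.tensor([[7, 7, 7], [3, 3, 3]])
padded_2d = F.pad(
    position_ids_2d, (0, 0, 0, padded_bs - position_ids_2d.shape[0]), value=1
)
print(padded_2d.shape)  # torch.Size([4, 3])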

View File

@@ -619,9 +619,7 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
                 if "image_grid_thw" in x[2]
             ]
             if image_grid_thw_list:
-                self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to(
-                    self.input_ids.device
-                )
+                self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
             else:
                 self.image_grid_thw = None
@@ -898,7 +896,7 @@ class FlashVlmCausalLM(FlashCausalLM):
         image_sizes = None
 
         if "image_grid_thw" in image_input:
-            image_grid_thw = image_input["image_grid_thw"].to(device)
+            image_grid_thw = image_input["image_grid_thw"]
         else:
             image_grid_thw = None
@@ -992,7 +990,7 @@ class FlashVlmCausalLM(FlashCausalLM):
         if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
             if position_ids.dim() == 1 and batch.prefilling:
                 position_ids = self.model.get_position_ids(
-                    input_ids, batch.image_grid_thw
+                    input_ids.cpu(), batch.image_grid_thw
                 )
                 batch.position_ids = position_ids
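
The last two hunks stop moving image_grid_thw and the prefill input_ids to the device, which is consistent with building the multimodal position ids on the host before anything is transferred to the HPU. For orientation only, a sketch of the kind of (temporal, height, width) grid that Qwen2-VL-style mrope assigns to vision tokens; this is not the repository's get_position_ids, just plain-PyTorch index bookkeeping with illustrative sizes:

import torch

# One image after the processor: t frames of h x w patches (made-up numbers).
t, h, w = 1, 2, 3

# Each vision token gets a (temporal, height, width) coordinate triple.
tpos = torch.arange(t).view(t, 1, 1).expand(t, h, w).flatten()
hpos = torch.arange(h).view(1, h, 1).expand(t, h, w).flatten()
wpos = torch.arange(w).view(1, 1, w).expand(t, h, w).flatten()
vision_pos = torch.stack([tpos, hpos, wpos], dim=-1)  # (t * h * w, 3)

print(vision_pos)
# tensor([[0, 0, 0],
#         [0, 0, 1],
#         [0, 0, 2],
#         [0, 1, 0],
#         [0, 1, 1],
#         [0, 1, 2]])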