Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
Qwen2 vl fix
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 93e5e35f9d
commit d68edc4a2f
@@ -45,6 +45,10 @@ from text_generation_server.layers.attention import (
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2Model,
 )
+from habana_frameworks.torch.hpex.kernels import (
+    RotaryPosEmbeddingMode,
+    apply_rotary_pos_emb,
+)
 
 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
 from typing import Union
@@ -375,28 +379,6 @@ class Qwen2_5_VLConfig(PretrainedConfig):
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
 
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    tensor: torch.Tensor, freqs: torch.Tensor
-) -> torch.Tensor:
-    orig_dtype = tensor.dtype
-    tensor = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    output = (tensor * cos) + (rotate_half(tensor) * sin)
-    output = output.to(orig_dtype)
-    return output
-
-
 class Qwen2_5VLAttention(nn.Module):
     def __init__(self, *, prefix, config, weights):
         super().__init__()
@@ -426,7 +408,8 @@ class Qwen2_5VLAttention(nn.Module):
         self,
         hidden_state: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
         max_seqlen: int,
     ) -> torch.Tensor:
         # apply the qkv linear layer to the hidden state
@@ -444,12 +427,18 @@ class Qwen2_5VLAttention(nn.Module):
         query = query.view(*_shape)
         key = key.view(*_shape)
         value = value.view(*_shape)
 
         # apply rotary positional embeddings
-        query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
-            0
-        )
-        key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
+        rotary_dim = cos.shape[-1]
+        query_rot = query[..., :rotary_dim]
+        query_pass = query[..., rotary_dim:]
+        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
+
+        key_rot = key[..., :rotary_dim]
+        key_pass = key[..., rotary_dim:]
+        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
-
         # calc maximum sequence length for any batch
         query = query.contiguous()
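As a sanity check on this hunk, the split/rotate/concatenate pattern it introduces can be reproduced in plain PyTorch. The sketch below is illustrative only: rotate_half, the tensor shapes, and the duplicated cos/sin tables are assumptions mirroring the removed apply_rotary_pos_emb_vision helper, not the exact semantics of the Habana apply_rotary_pos_emb kernel in BLOCKWISE mode.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # LLaMA-style "rotate half": swap the two halves of the last dim and negate one.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (num_patches, heads, head_dim); cos/sin: (num_patches, 1, rotary_dim).
    rotary_dim = cos.shape[-1]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x_rot = (x_rot * cos) + (rotate_half(x_rot) * sin)
    return torch.cat((x_rot, x_pass), dim=-1)

# Hypothetical sizes: 16 patches, 4 heads, head_dim 64, full-width rotation.
freqs = torch.randn(16, 32)
cos = torch.cat((freqs.cos(), freqs.cos()), dim=-1).unsqueeze(1)
sin = torch.cat((freqs.sin(), freqs.sin()), dim=-1).unsqueeze(1)
query = torch.randn(16, 4, 64)
print(apply_rope_reference(query, cos, sin).shape)  # torch.Size([16, 4, 64])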
@@ -533,11 +522,9 @@ class Qwen2_5VLVisionBlock(nn.Module):
             weights=weights,
         )
 
-    def forward(
-        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
-    ) -> torch.Tensor:
+    def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
         norm1_out, _ = self.norm1(hidden_states)
-        attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
+        attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
         hidden_states = hidden_states + attn_out
         norm2_out, _ = self.norm2(hidden_states)
         mlp_out = self.mlp(norm2_out)
@@ -736,6 +723,10 @@ class Qwen2_5VisionModel(nn.Module):
         )
 
         rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
+        cos = rotary_pos_emb.cos()
+        sin = rotary_pos_emb.sin()
+        cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
+        sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
 
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
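A quick numeric check (shapes assumed for illustration) that the tables precomputed above match what the removed apply_rotary_pos_emb_vision helper built internally, apart from the leading batch dimension it added:

import torch

freqs = torch.randn(16, 32)  # hypothetical (num_patches, head_dim // 2) rotary table

new_cos = torch.cat((freqs.cos(), freqs.cos()), dim=-1).unsqueeze(1)  # (16, 1, 64)
old_cos = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)       # (1, 16, 1, 64)

print(torch.equal(new_cos, old_cos.squeeze(0)))  # True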
@@ -762,9 +753,7 @@ class Qwen2_5VisionModel(nn.Module):
             else:
                 cu_seqlens_now = cu_window_seqlens
 
-            hidden_states = block(
-                hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
-            )
+            hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)
 
         # apply the final patch merger to the hidden states
         hidden_states = self.merger(hidden_states)
@@ -886,9 +875,6 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         full_llm_pos_ids_list = [
             item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
         ]
-        # import ipdb
-
-        # ipdb.set_trace()
         max_s = full_llm_pos_ids_list[-1].max() + 1
         final_text_len = input_ids_len - vision_ends[-1]
         if final_text_len > 0:
@@ -44,28 +44,10 @@ from text_generation_server.layers.attention import (
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2Model,
 )
-
-
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    tensor: torch.Tensor, freqs: torch.Tensor
-) -> torch.Tensor:
-    orig_dtype = tensor.dtype
-    tensor = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    output = (tensor * cos) + (rotate_half(tensor) * sin)
-    output = output.to(orig_dtype)
-    return output
+from habana_frameworks.torch.hpex.kernels import (
+    RotaryPosEmbeddingMode,
+    apply_rotary_pos_emb,
+)
 
 
 class Qwen2VLAttention(nn.Module):
@@ -96,7 +78,8 @@ class Qwen2VLAttention(nn.Module):
         self,
         hidden_state: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
         max_seqlen: int,
     ) -> torch.Tensor:
         # apply the qkv linear layer to the hidden state
@@ -116,10 +99,17 @@ class Qwen2VLAttention(nn.Module):
         value = value.view(*_shape)
 
         # apply rotary positional embeddings
-        query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
-            0
-        )
-        key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
+        rotary_dim = cos.shape[-1]
+        query_rot = query[..., :rotary_dim]
+        query_pass = query[..., rotary_dim:]
+        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
+
+        key_rot = key[..., :rotary_dim]
+        key_pass = key[..., rotary_dim:]
+        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
 
         # calc maximum sequence length for any batch
         query = query.contiguous()
@@ -193,11 +183,9 @@ class Qwen2VLVisionBlock(nn.Module):
             weights=weights,
         )
 
-    def forward(
-        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
-    ) -> torch.Tensor:
+    def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
         norm1_out, residual = self.norm1(hidden_states)
-        attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
+        attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)
         hidden_states = attn_out + residual
         norm2_out, residual = self.norm2(hidden_states)
         hidden_states = hidden_states + self.mlp(norm2_out)
@@ -330,6 +318,11 @@ class Qwen2VisionModel(nn.Module):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
 
+        cos = rotary_pos_emb.cos()
+        sin = rotary_pos_emb.sin()
+        cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
+        sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
+
         # create a cu_seqlens tensor to be used in the attention mask
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -338,7 +331,7 @@ class Qwen2VisionModel(nn.Module):
         max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
         # iterately apply the blocks to the hidden states
         for block in self.blocks:
-            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
+            hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)
 
         # apply the final patch merger to the hidden states
         hidden_states = self.merger(hidden_states)
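For context on the cu_seqlens and max_seqlen values the block loop consumes, here is a small worked sketch; the grid values and the zero-prepend layout are assumptions for illustration, not the exact code path.

import torch
import torch.nn.functional as F

# Hypothetical grids: two images with (t, h, w) patch counts.
grid_thw = torch.tensor([[1, 4, 6], [2, 3, 5]])

# One attention segment per temporal frame, each h*w patches long.
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))        # prepend 0 (assumed layout)
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])

print(cu_seqlens.tolist(), int(max_seqlen))          # [0, 24, 39, 54] 24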
@@ -1000,6 +1000,15 @@ class FlashCausalLMBatch(Batch):
             self.input_ids = F.pad(
                 self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=0
             )
+
+            if self.position_ids.dim() == 2:
+                # Qwen VL case
+                self.position_ids = F.pad(
+                    self.position_ids,
+                    (0, 0, 0, padded_bs - self.position_ids.shape[0]),
+                    value=1,
+                )
+            else:
                 self.position_ids = F.pad(
                     self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
                 )
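The hunk above distinguishes the Qwen VL case, where position_ids carries one row of multimodal rope coordinates per token, from the usual 1D case. A minimal sketch of how the two F.pad calls behave, with made-up shapes:

import torch
import torch.nn.functional as F

padded_bs = 4

# 1D case: pad the trailing slots with position 1.
pos_1d = torch.tensor([0, 1, 2])
print(F.pad(pos_1d, (0, padded_bs - pos_1d.shape[0]), value=1))
# tensor([0, 1, 2, 1])

# Qwen VL case (assumed layout: one coordinate row per token): pad whole rows, not columns.
pos_2d = torch.arange(9).reshape(3, 3)
print(F.pad(pos_2d, (0, 0, 0, padded_bs - pos_2d.shape[0]), value=1))
# original three rows, plus one extra row filled with 1s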
@@ -619,9 +619,7 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch):
             if "image_grid_thw" in x[2]
         ]
         if image_grid_thw_list:
-            self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to(
-                self.input_ids.device
-            )
+            self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
         else:
             self.image_grid_thw = None
 
@@ -898,7 +896,7 @@ class FlashVlmCausalLM(FlashCausalLM):
             image_sizes = None
 
         if "image_grid_thw" in image_input:
-            image_grid_thw = image_input["image_grid_thw"].to(device)
+            image_grid_thw = image_input["image_grid_thw"]
         else:
             image_grid_thw = None
 
@@ -992,7 +990,7 @@ class FlashVlmCausalLM(FlashCausalLM):
         if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
             if position_ids.dim() == 1 and batch.prefilling:
                 position_ids = self.model.get_position_ids(
-                    input_ids, batch.image_grid_thw
+                    input_ids.cpu(), batch.image_grid_thw
                 )
                 batch.position_ids = position_ids
 