mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
feat: add support for qwen2 vl model
This commit is contained in:
parent
98330df65e
commit
d96eef2a02
@ -89,6 +89,8 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
|
||||
if rope_type == "linear":
|
||||
pass
|
||||
elif rope_type == "default":
|
||||
pass
|
||||
elif rope_type == "dynamic":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
return DynamicPositionRotaryEmbedding(
|
||||
@ -275,6 +277,32 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
|
||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
def get_cos_sin_hack(
|
||||
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
|
||||
):
|
||||
# TODO: avoid always computing, use the cache and update it if necessary
|
||||
inv_freq_expanded = (
|
||||
self.inv_freq[None, None, :, None]
|
||||
.float()
|
||||
.expand(3, position_ids.shape[1], -1, 1)
|
||||
)
|
||||
|
||||
position_ids_expanded = position_ids[
|
||||
:, :, None, :
|
||||
].float() # shape (3, bs, 1, positions)
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
|
||||
2, 3
|
||||
)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos().to(dtype)
|
||||
sin = emb.sin().to(dtype)
|
||||
|
||||
# Update cached values
|
||||
self._cos_cached = cos
|
||||
self._sin_cached = sin
|
||||
|
||||
return cos, sin
|
||||
|
||||
|
||||
class SuRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(
|
||||
|
@ -144,7 +144,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
)
|
||||
if bias:
|
||||
raise NotImplementedError("packed_qkv only implemented for baichuan")
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
linear = get_linear(weight, bias)
|
||||
|
@ -146,6 +146,9 @@ try:
|
||||
from text_generation_server.models.custom_modeling.idefics2 import (
|
||||
Idefics2ForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.qwen2_vl import (
|
||||
Qwen2VLForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
||||
except ImportError as e:
|
||||
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
|
||||
@ -275,6 +278,11 @@ class ModelType(enum.Enum):
|
||||
"name": "Qwen 2",
|
||||
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
|
||||
}
|
||||
QWEN2_VL = {
|
||||
"type": "qwen2_vl",
|
||||
"name": "Qwen 2 VL",
|
||||
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
|
||||
}
|
||||
OPT = {
|
||||
"type": "opt",
|
||||
"name": "Opt",
|
||||
@ -1193,6 +1201,18 @@ def get_model(
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
if model_type == QWEN2_VL:
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=Qwen2VLForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
if model_type == MLLAMA:
|
||||
if FLASH_ATTENTION:
|
||||
return MllamaCausalLM(
|
||||
|
@ -49,6 +49,13 @@ def _load_gqa(config, prefix: str, weights):
|
||||
)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
class Qwen2Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -61,6 +68,7 @@ class Qwen2Attention(torch.nn.Module):
|
||||
config.sliding_window if config.sliding_window is not None else -1
|
||||
)
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.mrope_section = config.rope_scaling.get("mrope_section", None)
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
@ -122,7 +130,28 @@ class Qwen2Attention(torch.nn.Module):
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
# TODO: correctly handle the multimodal case
|
||||
if False:
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
else:
|
||||
# multimodal rotary
|
||||
unsqueeze_dim = 1
|
||||
mrope_section = self.mrope_section * 2
|
||||
cos = torch.cat(
|
||||
[m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
|
||||
dim=-1,
|
||||
).unsqueeze(unsqueeze_dim)
|
||||
sin = torch.cat(
|
||||
[m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))],
|
||||
dim=-1,
|
||||
).unsqueeze(unsqueeze_dim)
|
||||
|
||||
_query = query.transpose(0, 1).unsqueeze(0)
|
||||
_key = torch.select(kv, dim=1, index=0).transpose(0, 1).unsqueeze(0)
|
||||
q_embed = (_query * cos) + (rotate_half(_query) * sin)
|
||||
k_embed = (_key * cos) + (rotate_half(_key) * sin)
|
||||
query = q_embed.squeeze(0).transpose(0, 1)
|
||||
kv[:, 0] = k_embed.squeeze(0).transpose(0, 1)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
@ -306,12 +335,20 @@ class Qwen2Model(torch.nn.Module):
|
||||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# if inputs_embeds are supplied from an external model (vision model) then avoid embedding input_ids
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds.squeeze(0)
|
||||
else:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
# TODO: fix how N-D position_ids are handled
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin_hack(
|
||||
position_ids, true_max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
|
544
server/text_generation_server/models/custom_modeling/qwen2_vl.py
Normal file
544
server/text_generation_server/models/custom_modeling/qwen2_vl.py
Normal file
@ -0,0 +1,544 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Qwen2 VL model."""
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
if SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
else:
|
||||
import flash_attn_2_cuda
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
import torch.nn.functional as F
|
||||
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
FastLinear,
|
||||
)
|
||||
from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||
Qwen2Model,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(
|
||||
tensor: torch.Tensor, freqs: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = tensor.dtype
|
||||
tensor = tensor.float()
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
||||
output = output.to(orig_dtype)
|
||||
return output
|
||||
|
||||
|
||||
class Qwen2VLSdpaAttention(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.embed_dim
|
||||
self.head_dim = config.hidden_size // config.num_heads
|
||||
self.num_heads = config.num_heads // weights.process_group.size()
|
||||
|
||||
self.qkv = TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.qkv",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_heads,
|
||||
)
|
||||
|
||||
self.proj = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.proj",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# apply the qkv linear layer to the hidden state
|
||||
qkv = self.qkv(hidden_state)
|
||||
query, key, value = qkv.split(
|
||||
[self.embed_dim, self.embed_dim, self.embed_dim], dim=1
|
||||
)
|
||||
|
||||
# reshape the query, key, and value tensors
|
||||
_shape = (
|
||||
hidden_state.shape[0],
|
||||
self.num_heads,
|
||||
self.embed_dim // self.num_heads,
|
||||
)
|
||||
query = query.view(*_shape)
|
||||
key = key.view(*_shape)
|
||||
value = value.view(*_shape)
|
||||
|
||||
# apply rotary positional embeddings
|
||||
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
|
||||
0
|
||||
)
|
||||
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
# TODO: make use of existing RotatoryPositionEmbedding class
|
||||
|
||||
# create the attention mask
|
||||
attention_mask = torch.zeros(
|
||||
[1, hidden_state.shape[0], hidden_state.shape[0]],
|
||||
device=hidden_state.device,
|
||||
dtype=torch.bool,
|
||||
)
|
||||
# TODO: avoid creating the mask in the forward pass, instead define the largest possible mask and slice it
|
||||
|
||||
# apply the cu_seqlens to the attention mask
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[
|
||||
...,
|
||||
cu_seqlens[i - 1] : cu_seqlens[i],
|
||||
cu_seqlens[i - 1] : cu_seqlens[i],
|
||||
] = True
|
||||
|
||||
# transpose for the attention mechanism (batch, seqlen, hidden_dim) -> (seqlen, batch, hidden_dim)
|
||||
query = query.transpose(0, 1)
|
||||
key = key.transpose(0, 1)
|
||||
value = value.transpose(0, 1)
|
||||
|
||||
# apply attention
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query, key, value, attention_mask, dropout_p=0.0
|
||||
)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
|
||||
# TODO: prefer flash attention
|
||||
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2VLVisionMLP(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2VLVisionBlock(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.attn = Qwen2VLSdpaAttention(
|
||||
prefix=f"{prefix}.attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.norm1 = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.norm1",
|
||||
weights=weights,
|
||||
eps=1e-6,
|
||||
)
|
||||
self.norm2 = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.norm2",
|
||||
weights=weights,
|
||||
eps=1e-6,
|
||||
)
|
||||
self.mlp = Qwen2VLVisionMLP(
|
||||
prefix=f"{prefix}.mlp",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
|
||||
hidden_states_post_norm1, res = self.norm1(hidden_states)
|
||||
hidden_states = hidden_states + self.attn(
|
||||
hidden_states_post_norm1, cu_seqlens, rotary_pos_emb
|
||||
)
|
||||
hidden_states_post_norm2, res = self.norm2(hidden_states)
|
||||
hidden_states = hidden_states + self.mlp(hidden_states_post_norm2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2VLPatchMerger(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
context_dim = 2560
|
||||
spatial_merge_size: int = 2
|
||||
self.hidden_size = 5120 # context_dim * (spatial_merge_size**2)
|
||||
self.patch_merger_ln_q = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.ln_q",
|
||||
weights=weights,
|
||||
eps=1e-6,
|
||||
)
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, grid_thw) -> torch.Tensor:
|
||||
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = F.gelu(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2VisionModel(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.spatial_merge_size = config.spatial_merge_size
|
||||
kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_channels=config.in_chans,
|
||||
out_channels=config.embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
bias=False,
|
||||
)
|
||||
self.patch_embedding.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
|
||||
)
|
||||
head_dim = config.embed_dim // config.num_heads
|
||||
# TODO: replace with static positional embeddings once implemented
|
||||
theta = 10000.0
|
||||
dim = head_dim // 2
|
||||
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Qwen2VLVisionBlock(
|
||||
prefix=f"{prefix}.blocks.{i}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for i in range(config.depth)
|
||||
]
|
||||
)
|
||||
self.merger = Qwen2VLPatchMerger(
|
||||
prefix=f"{prefix}.merger",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.temporal_patch_size = config.temporal_patch_size
|
||||
self.spatial_patch_size = config.spatial_patch_size
|
||||
self.in_channels = config.in_channels
|
||||
self.embed_dim = config.embed_dim
|
||||
|
||||
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, _, hidden_size = hidden_state.shape
|
||||
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
|
||||
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
|
||||
return hidden_state
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
grid_thw: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# reshape the input tensor for processing
|
||||
shape = (
|
||||
-1,
|
||||
self.in_channels,
|
||||
self.temporal_patch_size,
|
||||
self.spatial_patch_size,
|
||||
self.spatial_patch_size,
|
||||
)
|
||||
pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
|
||||
hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
|
||||
# TODO: revisit to see if we can avoid some of these reshapes
|
||||
|
||||
# find the position ids for the input tensor based on the grid_thw
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
hpos_ids = hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
||||
hpos_ids = hpos_ids.flatten()
|
||||
|
||||
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||
wpos_ids = wpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
||||
wpos_ids = wpos_ids.flatten()
|
||||
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
|
||||
# apply the positional embeddings to the position ids
|
||||
seq = torch.arange(
|
||||
max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
|
||||
)
|
||||
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
|
||||
|
||||
# create a cu_seqlens tensor to be used in the attention mask
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
# iterately apply the blocks to the hidden states
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb)
|
||||
|
||||
# apply the final patch merger to the hidden states
|
||||
hidden_states = self.merger(hidden_states, grid_thw)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = None
|
||||
config.vision_config.speculator = config.speculator
|
||||
self.hidden_size = config.hidden_size
|
||||
self.vision_start_token_id = config.vision_start_token_id
|
||||
self.image_token_id = config.image_token_id
|
||||
self.video_token_id = config.video_token_id
|
||||
self.spatial_merge_size = config.vision_config.spatial_merge_size
|
||||
|
||||
self.visual = Qwen2VisionModel(
|
||||
prefix="visual", config=config.vision_config, weights=weights
|
||||
)
|
||||
self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor],
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
image_indices=None,
|
||||
):
|
||||
|
||||
# make an attention_mask that is the same size as the input_ids
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
||||
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
|
||||
# apply the visual model to the pixel values if they are provided
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(
|
||||
inputs_embeds.device, inputs_embeds.dtype
|
||||
)
|
||||
# input embeddings are masked with image embeddings
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
# handle the position_ids taking the multimodal inputs into account
|
||||
mrope_position_deltas = []
|
||||
if image_grid_thw is not None or video_grid_thw is not None:
|
||||
total_input_ids = input_ids
|
||||
position_ids = torch.ones(
|
||||
3,
|
||||
input_ids.shape[0],
|
||||
input_ids.shape[1],
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
image_index, video_index = 0, 0
|
||||
|
||||
for i, input_ids in enumerate(total_input_ids):
|
||||
if attention_mask is not None:
|
||||
input_ids = input_ids[attention_mask[i] == 1]
|
||||
image_nums, video_nums = 0, 0
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_ids == self.vision_start_token_id
|
||||
).squeeze(1)
|
||||
vision_tokens = input_ids[vision_start_indices + 1]
|
||||
# determine the number of images and videos in the input
|
||||
image_nums = (vision_tokens == self.image_token_id).sum()
|
||||
video_nums = (vision_tokens == self.video_token_id).sum()
|
||||
input_tokens = input_ids.tolist()
|
||||
llm_pos_ids_list: list = []
|
||||
st = 0
|
||||
remain_images, remain_videos = image_nums, video_nums
|
||||
# process each input based on it's token type and grid size
|
||||
for _ in range(image_nums + video_nums):
|
||||
if self.image_token_id in input_tokens and remain_images > 0:
|
||||
ed_image = input_tokens.index(self.image_token_id, st)
|
||||
else:
|
||||
ed_image = len(input_tokens) + 1
|
||||
if self.video_token_id in input_tokens and remain_videos > 0:
|
||||
ed_video = input_tokens.index(self.video_token_id, st)
|
||||
else:
|
||||
ed_video = len(input_tokens) + 1
|
||||
if ed_image < ed_video:
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
ed = ed_image
|
||||
else:
|
||||
t, h, w = (
|
||||
video_grid_thw[video_index][0],
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
video_index += 1
|
||||
remain_videos -= 1
|
||||
ed = ed_video
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t.item(),
|
||||
h.item() // self.spatial_merge_size,
|
||||
w.item() // self.spatial_merge_size,
|
||||
)
|
||||
text_len = ed - st
|
||||
|
||||
st_idx = (
|
||||
llm_pos_ids_list[-1].max() + 1
|
||||
if len(llm_pos_ids_list) > 0
|
||||
else 0
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
||||
)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = (
|
||||
llm_pos_ids_list[-1].max() + 1
|
||||
if len(llm_pos_ids_list) > 0
|
||||
else 0
|
||||
)
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
|
||||
position_ids.device
|
||||
)
|
||||
mrope_position_deltas.append(
|
||||
llm_positions.max() + 1 - len(total_input_ids[i])
|
||||
)
|
||||
mrope_position_deltas = torch.tensor(
|
||||
mrope_position_deltas, device=input_ids.device
|
||||
).unsqueeze(1)
|
||||
|
||||
# TODO: adjust model to accept 2D position_ids
|
||||
outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
return outputs, None
|
@ -67,6 +67,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||
|
||||
elif config.model_type == "paligemma":
|
||||
return "<image>" * config.text_config.num_image_tokens
|
||||
elif config.model_type == "qwen2_vl":
|
||||
return "<image>"
|
||||
else:
|
||||
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
||||
|
||||
@ -137,6 +139,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
pixel_values: Optional[List[torch.Tensor]]
|
||||
pixel_attention_mask: Optional[List[torch.Tensor]]
|
||||
image_sizes: Optional[List[Tuple[int, int]]]
|
||||
image_grid_thw: Optional[torch.Tensor]
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
@ -145,6 +148,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
batch.image_grid_thw = None
|
||||
return batch
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
@ -153,6 +157,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
batch.image_grid_thw = None
|
||||
return batch
|
||||
|
||||
@classmethod
|
||||
@ -178,6 +183,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||
|
||||
if images:
|
||||
# TODO: REMOVE (this is for debugging purposes)
|
||||
images = images[0][0].resize(
|
||||
(images[0][0].width * 2, images[0][0].height * 2)
|
||||
)
|
||||
image_inputs = processor.image_processor(images, return_tensors="pt")
|
||||
else:
|
||||
image_inputs = None
|
||||
@ -237,10 +246,15 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
|
||||
else:
|
||||
batch.image_sizes = None
|
||||
if "image_grid_thw" in image_inputs:
|
||||
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
|
||||
else:
|
||||
batch.image_grid_thw = None
|
||||
else:
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
batch.image_grid_thw = None
|
||||
return batch
|
||||
|
||||
|
||||
@ -381,8 +395,9 @@ class VlmCausalLM(FlashCausalLM):
|
||||
max_k=batch.max_current_length,
|
||||
)
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
# TODO: remove the unsqueeze(0)
|
||||
input_ids=input_ids.unsqueeze(0),
|
||||
position_ids=position_ids.unsqueeze(0),
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
@ -394,6 +409,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||
pixel_values=batch.pixel_values,
|
||||
pixel_attention_mask=batch.pixel_attention_mask,
|
||||
image_sizes=batch.image_sizes,
|
||||
image_grid_thw=batch.image_grid_thw,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
@ -403,6 +419,8 @@ class VlmCausalLM(FlashCausalLM):
|
||||
batch.pixel_attention_mask = None
|
||||
if batch.image_sizes is not None:
|
||||
batch.image_sizes = None
|
||||
if batch.image_grid_thw is not None:
|
||||
batch.image_grid_thw = None
|
||||
return logits, speculative_logits
|
||||
|
||||
# Copy inputs to the static inputs of the cuda graph
|
||||
|
Loading…
Reference in New Issue
Block a user