feat: add support for qwen2 vl model

This commit is contained in:
David Holtz 2024-10-24 15:36:53 +00:00 committed by drbh
parent 98330df65e
commit d96eef2a02
6 changed files with 653 additions and 6 deletions

View File

@ -89,6 +89,8 @@ class PositionRotaryEmbedding(nn.Module):
if rope_type == "linear":
pass
elif rope_type == "default":
pass
elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding(
@ -275,6 +277,32 @@ class PositionRotaryEmbedding(nn.Module):
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
return cos.unsqueeze(1), sin.unsqueeze(1)
def get_cos_sin_hack(
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
):
# TODO: avoid always computing, use the cache and update it if necessary
inv_freq_expanded = (
self.inv_freq[None, None, :, None]
.float()
.expand(3, position_ids.shape[1], -1, 1)
)
position_ids_expanded = position_ids[
:, :, None, :
].float() # shape (3, bs, 1, positions)
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
2, 3
)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype)
sin = emb.sin().to(dtype)
# Update cached values
self._cos_cached = cos
self._sin_cached = sin
return cos, sin
class SuRotaryEmbedding(PositionRotaryEmbedding):
def __init__(

View File

@ -144,7 +144,7 @@ class TensorParallelColumnLinear(SuperLayer):
num_key_value_heads=num_key_value_heads,
)
if bias:
raise NotImplementedError("packed_qkv only implemented for baichuan")
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
linear = get_linear(weight, bias)

View File

@ -146,6 +146,9 @@ try:
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_vl import (
Qwen2VLForConditionalGeneration,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
@ -275,6 +278,11 @@ class ModelType(enum.Enum):
"name": "Qwen 2",
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
}
QWEN2_VL = {
"type": "qwen2_vl",
"name": "Qwen 2 VL",
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
}
OPT = {
"type": "opt",
"name": "Opt",
@ -1193,6 +1201,18 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == QWEN2_VL:
return VlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
if model_type == MLLAMA:
if FLASH_ATTENTION:
return MllamaCausalLM(

View File

@ -49,6 +49,13 @@ def _load_gqa(config, prefix: str, weights):
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
class Qwen2Attention(torch.nn.Module):
def __init__(
self,
@ -61,6 +68,7 @@ class Qwen2Attention(torch.nn.Module):
config.sliding_window if config.sliding_window is not None else -1
)
self.num_heads = config.num_attention_heads
self.mrope_section = config.rope_scaling.get("mrope_section", None)
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
@ -122,7 +130,28 @@ class Qwen2Attention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
# TODO: correctly handle the multimodal case
if False:
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
else:
# multimodal rotary
unsqueeze_dim = 1
mrope_section = self.mrope_section * 2
cos = torch.cat(
[m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
dim=-1,
).unsqueeze(unsqueeze_dim)
sin = torch.cat(
[m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))],
dim=-1,
).unsqueeze(unsqueeze_dim)
_query = query.transpose(0, 1).unsqueeze(0)
_key = torch.select(kv, dim=1, index=0).transpose(0, 1).unsqueeze(0)
q_embed = (_query * cos) + (rotate_half(_query) * sin)
k_embed = (_key * cos) + (rotate_half(_key) * sin)
query = q_embed.squeeze(0).transpose(0, 1)
kv[:, 0] = k_embed.squeeze(0).transpose(0, 1)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
@ -306,12 +335,20 @@ class Qwen2Model(torch.nn.Module):
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# if inputs_embeds are supplied from an external model (vision model) then avoid embedding input_ids
if inputs_embeds is not None:
hidden_states = inputs_embeds.squeeze(0)
else:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
# TODO: fix how N-D position_ids are handled
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin_hack(
position_ids, true_max_s, hidden_states.dtype
)

View File

@ -0,0 +1,544 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2 VL model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
else:
import flash_attn_2_cuda
from transformers.activations import ACT2FN
import torch.nn.functional as F
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.layers import (
TensorParallelColumnLinear,
FastLinear,
)
from text_generation_server.layers.attention import (
Seqlen,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
class Qwen2VLSdpaAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.embed_dim = config.embed_dim
self.head_dim = config.hidden_size // config.num_heads
self.num_heads = config.num_heads // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv",
weights=weights,
bias=True,
num_heads=self.num_heads,
num_key_value_heads=self.num_heads,
)
self.proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.proj",
weights=weights,
bias=True,
)
def forward(
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
qkv = self.qkv(hidden_state)
query, key, value = qkv.split(
[self.embed_dim, self.embed_dim, self.embed_dim], dim=1
)
# reshape the query, key, and value tensors
_shape = (
hidden_state.shape[0],
self.num_heads,
self.embed_dim // self.num_heads,
)
query = query.view(*_shape)
key = key.view(*_shape)
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
# TODO: make use of existing RotatoryPositionEmbedding class
# create the attention mask
attention_mask = torch.zeros(
[1, hidden_state.shape[0], hidden_state.shape[0]],
device=hidden_state.device,
dtype=torch.bool,
)
# TODO: avoid creating the mask in the forward pass, instead define the largest possible mask and slice it
# apply the cu_seqlens to the attention mask
for i in range(1, len(cu_seqlens)):
attention_mask[
...,
cu_seqlens[i - 1] : cu_seqlens[i],
cu_seqlens[i - 1] : cu_seqlens[i],
] = True
# transpose for the attention mechanism (batch, seqlen, hidden_dim) -> (seqlen, batch, hidden_dim)
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
# apply attention
attn_output = F.scaled_dot_product_attention(
query, key, value, attention_mask, dropout_p=0.0
)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
# TODO: prefer flash attention
attn_output = self.proj(attn_output)
return attn_output
class Qwen2VLVisionMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Qwen2VLVisionBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.attn = Qwen2VLSdpaAttention(
prefix=f"{prefix}.attn",
config=config,
weights=weights,
)
self.norm1 = FastLayerNorm.load(
prefix=f"{prefix}.norm1",
weights=weights,
eps=1e-6,
)
self.norm2 = FastLayerNorm.load(
prefix=f"{prefix}.norm2",
weights=weights,
eps=1e-6,
)
self.mlp = Qwen2VLVisionMLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
)
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
hidden_states_post_norm1, res = self.norm1(hidden_states)
hidden_states = hidden_states + self.attn(
hidden_states_post_norm1, cu_seqlens, rotary_pos_emb
)
hidden_states_post_norm2, res = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(hidden_states_post_norm2)
return hidden_states
class Qwen2VLPatchMerger(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
context_dim = 2560
spatial_merge_size: int = 2
self.hidden_size = 5120 # context_dim * (spatial_merge_size**2)
self.patch_merger_ln_q = FastLayerNorm.load(
prefix=f"{prefix}.ln_q",
weights=weights,
eps=1e-6,
)
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states, grid_thw) -> torch.Tensor:
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
hidden_states = hidden_states.view(-1, self.hidden_size)
hidden_states = self.fc1(hidden_states)
hidden_states = F.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Qwen2VisionModel(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.spatial_merge_size = config.spatial_merge_size
kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
self.patch_embedding = nn.Conv3d(
in_channels=config.in_chans,
out_channels=config.embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
)
head_dim = config.embed_dim // config.num_heads
# TODO: replace with static positional embeddings once implemented
theta = 10000.0
dim = head_dim // 2
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.blocks = nn.ModuleList(
[
Qwen2VLVisionBlock(
prefix=f"{prefix}.blocks.{i}",
config=config,
weights=weights,
)
for i in range(config.depth)
]
)
self.merger = Qwen2VLPatchMerger(
prefix=f"{prefix}.merger",
config=config,
weights=weights,
)
self.temporal_patch_size = config.temporal_patch_size
self.spatial_patch_size = config.spatial_patch_size
self.in_channels = config.in_channels
self.embed_dim = config.embed_dim
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def forward(
self,
pixel_values: torch.Tensor,
aspect_ratio_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
grid_thw: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
# reshape the input tensor for processing
shape = (
-1,
self.in_channels,
self.temporal_patch_size,
self.spatial_patch_size,
self.spatial_patch_size,
)
pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
# TODO: revisit to see if we can avoid some of these reshapes
# find the position ids for the input tensor based on the grid_thw
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
# apply the positional embeddings to the position ids
seq = torch.arange(
max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
# create a cu_seqlens tensor to be used in the attention mask
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
# iterately apply the blocks to the hidden states
for block in self.blocks:
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb)
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states, grid_thw)
return hidden_states
class Qwen2VLForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
self.hidden_size = config.hidden_size
self.vision_start_token_id = config.vision_start_token_id
self.image_token_id = config.image_token_id
self.video_token_id = config.video_token_id
self.spatial_merge_size = config.vision_config.spatial_merge_size
self.visual = Qwen2VisionModel(
prefix="visual", config=config.vision_config, weights=weights
)
self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
# make an attention_mask that is the same size as the input_ids
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
inputs_embeds = self.text_model.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
image_mask = (
(input_ids == self.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(
inputs_embeds.device, inputs_embeds.dtype
)
# input embeddings are masked with image embeddings
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
# handle the position_ids taking the multimodal inputs into account
mrope_position_deltas = []
if image_grid_thw is not None or video_grid_thw is not None:
total_input_ids = input_ids
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_index, video_index = 0, 0
for i, input_ids in enumerate(total_input_ids):
if attention_mask is not None:
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(
input_ids == self.vision_start_token_id
).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
# determine the number of images and videos in the input
image_nums = (vision_tokens == self.image_token_id).sum()
video_nums = (vision_tokens == self.video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
# process each input based on it's token type and grid size
for _ in range(image_nums + video_nums):
if self.image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(self.image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if self.video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(self.video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // self.spatial_merge_size,
w.item() // self.spatial_merge_size,
)
text_len = ed - st
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
position_ids.device
)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(total_input_ids[i])
)
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
# TODO: adjust model to accept 2D position_ids
outputs = self.text_model(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
)
return outputs, None

View File

@ -67,6 +67,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens
elif config.model_type == "qwen2_vl":
return "<image>"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@ -137,6 +139,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
image_grid_thw: Optional[torch.Tensor]
@classmethod
@tracer.start_as_current_span("concatenate")
@ -145,6 +148,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
@tracer.start_as_current_span("filter")
@ -153,6 +157,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
@classmethod
@ -178,6 +183,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
raise RuntimeError(f"Invalid chunk type {chunk_type}")
if images:
# TODO: REMOVE (this is for debugging purposes)
images = images[0][0].resize(
(images[0][0].width * 2, images[0][0].height * 2)
)
image_inputs = processor.image_processor(images, return_tensors="pt")
else:
image_inputs = None
@ -237,10 +246,15 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
if "image_grid_thw" in image_inputs:
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else:
batch.image_grid_thw = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
@ -381,8 +395,9 @@ class VlmCausalLM(FlashCausalLM):
max_k=batch.max_current_length,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
# TODO: remove the unsqueeze(0)
input_ids=input_ids.unsqueeze(0),
position_ids=position_ids.unsqueeze(0),
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
@ -394,6 +409,7 @@ class VlmCausalLM(FlashCausalLM):
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
@ -403,6 +419,8 @@ class VlmCausalLM(FlashCausalLM):
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph