feat: support qwen2.5 vl model

drbh 2025-01-31 12:36:03 -05:00
parent c1cf36c0dc
commit 10aa62f87f
10 changed files with 933 additions and 3 deletions

View File

@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The image showcases the Statue of Liberty, a colossal bronze statue located in New York Harbor, a heritage building in the United States. The statue has a majestic presence, with one arm raised towards the sun and the other hitched on her hip. It sits atop a keeper's walkway, observed from the water. Surrounding the statue is a lush green meadow, where picnic spots, walkways, and a visitor desk can be found. In front of the statue, a large marina can accommodate fourteen different kinds of boats. In the backdrop stands the Empire State Building, marking the crowded skyscrapers of New York City.",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1738342753,
"id": "",
"model": "Qwen/Qwen2.5-VL-3B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.0.2-dev0-native",
"usage": {
"completion_tokens": 128,
"prompt_tokens": 8736,
"total_tokens": 8864
}
}

View File

@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The image shows a whimsical scene set in what appears to be a fast-food restaurant. Dominating the foreground is a large, green, inflatable dinosaur with realistic textures, giving it a Jurassic Park-like appearance. The dinosaur is wearing a red Adult Swim logo hat, adding a humorous touch to its appearance.\n\nSurrounding the dinosaur are various food items typically found in a fast-food restaurant, including French fries in a plastic cup, a hamburger on a plate, and a beverage in another cup. The hamburger is detailed with lettuce, tomato, and other typical fast-food ingredients.\n\nAccompanying the dinosaur is a realistic-looking owl perched on the table, which adds to the surreal and playful atmosphere of the scene. The background features the interior of the restaurant with neon signs and other typical decor elements, enhancing the overall theme of a fun and fantastical fast-food experience.\n\nOverall, the image is a playful and imaginative blend of a standard fast-food setting with an unexpected and amusing twist provided by the dinosaur and owl characters.",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1738343775,
"id": "",
"model": "Qwen/Qwen2.5-VL-3B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.0.2-dev0-native",
"usage": {
"completion_tokens": 206,
"prompt_tokens": 5375,
"total_tokens": 5581
}
}

View File

@@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting.",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1738342872,
"id": "",
"model": "Qwen/Qwen2.5-VL-3B-Instruct",
"object": "chat.completion",
"system_fingerprint": "3.0.2-dev0-native",
"usage": {
"completion_tokens": 121,
"prompt_tokens": 1363,
"total_tokens": 1484
}
}

View File

@@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": "",
"role": "assistant",
"tool_calls": null
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1738343559,
"id": "",
"model": "Qwen/Qwen2.5-VL-3B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "3.0.2-dev0-native",
"usage": null
}
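
The chunk above is the final delta of a streamed chat completion. For reference (not part of the diff), a stream like this can be consumed from TGI's OpenAI-compatible route roughly as sketched below; the host/port and the image URL are placeholder assumptions.

import json
import requests

# Minimal sketch: read the SSE stream that ends with a chunk like the one above.
# Assumes a TGI server is already running and reachable at this (hypothetical) address.
payload = {
    "model": "Qwen/Qwen2.5-VL-3B-Instruct",
    "stream": True,
    "seed": 42,
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"}},
                {"type": "text", "text": "Describe the image"},
            ],
        }
    ],
}

with requests.post("http://localhost:8080/v1/chat/completions", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue
        data = line[len(b"data:"):].strip()
        if data == b"[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"].get("content") or ""
        print(delta, end="", flush=True)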

View File

@@ -0,0 +1,122 @@
import pytest
@pytest.fixture(scope="module")
def flash_qwen2_5_vl_handle(launcher):
with launcher("Qwen/Qwen2.5-VL-3B-Instruct") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_qwen2_5(flash_qwen2_5_vl_handle):
await flash_qwen2_5_vl_handle.health(300)
return flash_qwen2_5_vl_handle.client
@pytest.mark.private
async def test_flash_qwen2_5_vl_simple(flash_qwen2_5, response_snapshot):
response = await flash_qwen2_5.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
},
},
{"type": "text", "text": "Describe the image"},
],
},
],
)
assert (
response.choices[0].message.content
== "The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting."
)
assert response == response_snapshot
@pytest.mark.private
async def test_flash_qwen2_5_vl_simple_streaming(flash_qwen2_5, response_snapshot):
responses = await flash_qwen2_5.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
},
},
{"type": "text", "text": "Describe the image"},
],
},
],
stream=True,
)
count = 0
generated = ""
last_response = None
async for response in responses:
count += 1
generated += response.choices[0].delta.content
last_response = response
assert (
generated
== "The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting."
)
assert count == 121
assert last_response == response_snapshot
@pytest.mark.private
async def test_flash_qwen2_5_vl_bay(flash_qwen2_5, response_snapshot):
response = await flash_qwen2_5.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
},
},
{"type": "text", "text": "Describe the image"},
],
},
],
)
assert response == response_snapshot
@pytest.mark.private
async def test_flash_qwen2_5_vl_inpaint(flash_qwen2_5, response_snapshot):
response = await flash_qwen2_5.chat(
seed=42,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png"
},
},
{"type": "text", "text": "Describe the image"},
],
},
],
)
assert response == response_snapshot
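
For comparison (not part of the test file), an equivalent request can be issued with huggingface_hub against a running server; the endpoint URL below is a placeholder assumption.

from huggingface_hub import InferenceClient

# Hypothetical local endpoint; point this at wherever text-generation-inference is listening.
client = InferenceClient("http://localhost:8080")

response = client.chat_completion(
    seed=42,
    max_tokens=128,
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                    },
                },
                {"type": "text", "text": "Describe the image"},
            ],
        }
    ],
)
print(response.choices[0].message.content)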

View File

@@ -184,10 +184,43 @@ impl Qwen2Vl {
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2_5VlVisionConfig {
pub(crate) depth: usize,
pub(crate) hidden_act: String,
pub(crate) hidden_size: usize,
pub(crate) intermediate_size: usize,
pub(crate) num_heads: usize,
pub(crate) in_chans: usize,
pub(crate) out_hidden_size: usize,
pub(crate) patch_size: usize,
pub(crate) spatial_merge_size: usize,
pub(crate) spatial_patch_size: usize,
pub(crate) window_size: usize,
pub(crate) fullatt_block_indexes: Vec<usize>,
pub(crate) tokens_per_second: usize,
pub(crate) temporal_patch_size: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2_5Vl {
pub(crate) vision_config: Qwen2_5VlVisionConfig,
}
impl Qwen2_5Vl {
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
let num_pixels = height * width;
num_pixels / self.vision_config.patch_size.pow(2)
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub enum Config {
Qwen2_5Vl(Qwen2_5Vl),
Qwen2Vl(Qwen2Vl),
LlavaNext(LlavaNext),
ClipVisionModel(ClipVisionModel),
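
The router uses get_number_of_features above to budget image tokens: the pixel count divided by patch_size squared. A rough Python illustration (the patch size of 14 comes from the published Qwen2.5-VL vision config and is an assumption here, not something defined in this hunk):

# Sketch of the estimate implemented by Qwen2_5Vl::get_number_of_features.
def number_of_features(height: int, width: int, patch_size: int = 14) -> int:
    # patch_size=14 is assumed from the public Qwen2.5-VL vision config
    return (height * width) // (patch_size ** 2)

print(number_of_features(700, 448))  # 700 * 448 // 196 = 1600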

View File

@@ -684,6 +684,10 @@ fn image_tokens(
"<|vision_start|>{:?}<|vision_end|>",
"<|image_pad|>".repeat(config.get_number_of_features(height, width))
),
Qwen2_5Vl(config) => format!(
"<|vision_start|>{:?}<|vision_end|>",
"<|image_pad|>".repeat(config.get_number_of_features(height, width))
),
_ => unimplemented!("Images tokens are not supported for this model configuration"),
}
}
@@ -712,7 +716,7 @@ fn prepare_input<T: TokenizerTrait>(
let (tokenizer_query, input_chunks) = match config {
Some(
config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
| Qwen2Vl(_) | Qwen2_5Vl(_)),
) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());

View File

@@ -164,6 +164,9 @@ try:
from text_generation_server.models.custom_modeling.qwen2_vl import (
Qwen2VLForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_5_vl import (
Qwen2_5VLForConditionalGeneration,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
@@ -317,6 +320,11 @@ class ModelType(enum.Enum):
"name": "Qwen 2 VL",
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
}
QWEN2_5_VL = {
"type": "qwen2_5_vl",
"name": "Qwen 2.5 VL",
"url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
}
OPT = {
"type": "opt",
"name": "Opt",
@@ -1368,6 +1376,19 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
if model_type == QWEN2_5_VL:
return VlmCausalLM(
model_id=model_id,
model_class=Qwen2_5VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
if model_type == MLLAMA:
if FLASH_ATTENTION:
return MllamaCausalLM(
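
The new get_model branch is selected purely on model_type. As a quick sanity check (assuming a transformers release that already ships the Qwen2.5-VL configuration), the type string can be inspected like this:

from transformers import AutoConfig

# Requires a transformers version that knows about Qwen2.5-VL.
config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
print(config.model_type)  # expected: "qwen2_5_vl"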

View File

@@ -0,0 +1,641 @@
# coding=utf-8
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2.5 VL model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
else:
import flash_attn_2_cuda
import numpy as np
from transformers.activations import ACT2FN
import torch.nn.functional as F
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention import (
Seqlen,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
class Qwen2_5VLAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size // weights.process_group.size()
self.head_dim = config.hidden_size // config.num_heads
self.num_heads = config.num_heads // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv",
weights=weights,
bias=False,
num_heads=self.num_heads,
num_key_value_heads=self.num_heads,
)
self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
self.proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.proj",
weights=weights,
bias=True,
)
self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
def forward(
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
qkv = self.qkv(hidden_state)
query, key, value = qkv.split(
[self.embed_dim, self.embed_dim, self.embed_dim], dim=1
)
# reshape the query, key, and value tensors
_shape = (
hidden_state.shape[0],
self.num_heads,
self.embed_dim // self.num_heads,
)
query = query.view(*_shape)
key = key.view(*_shape)
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
# make q/k/v contiguous as expected by the attention kernels
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
causal = False
# execute flash attention
if SYSTEM == "ipex":
attn_output = torch.empty_like(query)
ipex.llm.functional.varlen_attention(
(query.contiguous() if query.device.type == "xpu" else query),
(key.contiguous() if key.device.type == "xpu" else key),
(value.contiguous() if value.device.type == "xpu" else value),
attn_output,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
else:
attn_output = flash_attn_2_cuda.varlen_fwd(
query,
key,
value,
None, # tmp buffer (auto-allocated)
cu_seqlens, # cu_seqlens_q
cu_seqlens, # cu_seqlens_k
None, # max_seqlen_q (auto-computed)
None, # max_seqlen_k (auto-computed)
None, # block_tables
None, # broadcast_mask
max_seqlen, # max_seqlen
max_seqlen, # max_seqlen
0.0, # dropout_p
self.softmax_scale,
False, # zero_tensors
causal, # causal attention within each sequence
-1, # window_size_left
-1, # window_size_right
0.0, # softmax_cap
False, # deterministic
None, # rng_state
)[0]
# reshape output to original dimensions
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
attn_output = self.proj(attn_output)
return attn_output
class Qwen2_5VLVisionMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.activation_fn = ACT2FN[config.hidden_act]
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
self.up = TensorParallelColumnLinear.load(
prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True
)
self.gate = TensorParallelColumnLinear.load(
prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True
)
self.down = TensorParallelRowLinear.load(
prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate_states = self.gate(hidden_states)
up_states = self.up(hidden_states)
activated_states = self.activation_fn(gate_states) * up_states
down_states = self.down(activated_states)
return down_states
class Qwen2_5VLVisionBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.attn = Qwen2_5VLAttention(
prefix=f"{prefix}.attn",
config=config,
weights=weights,
)
self.norm1 = FastRMSNorm.load(
prefix=f"{prefix}.norm1",
weights=weights,
eps=1e-6,
)
self.norm2 = FastRMSNorm.load(
prefix=f"{prefix}.norm2",
weights=weights,
eps=1e-6,
)
self.mlp = Qwen2_5VLVisionMLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
)
def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
norm1_out, _ = self.norm1(hidden_states)
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
hidden_states = hidden_states + attn_out
norm2_out, _ = self.norm2(hidden_states)
mlp_out = self.mlp(norm2_out)
hidden_states = hidden_states + mlp_out
return hidden_states
class Qwen2_5VLPatchMerger(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
self.patch_merger_ln_q = FastRMSNorm.load(
prefix=f"{prefix}.ln_q",
weights=weights,
eps=1e-6,
)
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
hidden_states = hidden_states.view(-1, self.hidden_size)
hidden_states = self.fc1(hidden_states)
hidden_states = F.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Qwen2_5VisionModel(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.spatial_merge_size = config.spatial_merge_size
kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
self.patch_embedding = nn.Conv3d(
in_channels=config.in_chans,
out_channels=config.hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
)
head_dim = config.hidden_size // config.num_heads
theta = 10000.0
dim = head_dim // 2
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.blocks = nn.ModuleList(
[
Qwen2_5VLVisionBlock(
prefix=f"{prefix}.blocks.{i}",
config=config,
weights=weights,
)
for i in range(config.depth)
]
)
self.merger = Qwen2_5VLPatchMerger(
prefix=f"{prefix}.merger",
config=config,
weights=weights,
)
self.temporal_patch_size = config.temporal_patch_size
self.spatial_patch_size = config.spatial_patch_size
self.in_channels = config.in_channels
self.embed_dim = config.hidden_size
self.window_size = config.window_size
self.patch_size = config.patch_size
self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
self.fullatt_block_indexes = config.fullatt_block_indexes
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def get_window_index(self, grid_thw):
window_index: list = []
cu_window_seqlens: list = [0]
window_index_id = 0
vit_merger_window_size = (
self.window_size // self.spatial_merge_size // self.patch_size
)
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = (
grid_h // self.spatial_merge_size,
grid_w // self.spatial_merge_size,
)
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
grid_t, llm_grid_h, llm_grid_w
)
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
index_padded = index_padded.reshape(
grid_t,
num_windows_h,
vit_merger_window_size,
num_windows_w,
vit_merger_window_size,
)
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
grid_t,
num_windows_h * num_windows_w,
vit_merger_window_size,
vit_merger_window_size,
)
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
index_padded = index_padded.reshape(-1)
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = (
seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
)
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
def forward(
self,
pixel_values: torch.Tensor,
grid_thw: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
# reshape the input tensor for processing
shape = (
-1,
self.in_channels,
self.temporal_patch_size,
self.spatial_patch_size,
self.spatial_patch_size,
)
pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
# TODO: revisit to see if we can avoid some of these reshapes
# find the position ids for the input tensor based on the grid_thw
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
# apply the positional embeddings to the position ids
seq = torch.arange(
max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
seq_len = hidden_states.shape[0]
patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
og_shape = (seq_len, -1)
hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(
og_shape
)
rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(
og_shape
)
rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=hidden_states.device,
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
# create a cu_seqlens tensor to be used in the attention mask
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
# iteratively apply the blocks to the hidden states
for layer_num, block in enumerate(self.blocks):
# NOTE: qwen2_5_vl.py has a concept of full attention blocks
# that are applied at specific layers.
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = block(
hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
)
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
return hidden_states
class Qwen2_5VLForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
# set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
# returns rope_scaling.type == "default" for the Qwen2.5-VL model at the moment
config.rope_scaling.update({"rope_type": "mrope"})
self.hidden_size = config.hidden_size
self.vision_start_token_id = config.vision_start_token_id
self.vision_end_token_id = config.vision_end_token_id
self.image_token_id = config.image_token_id
self.video_token_id = config.video_token_id
self.spatial_merge_size = config.vision_config.spatial_merge_size
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.visual = Qwen2_5VisionModel(
prefix="visual", config=config.vision_config, weights=weights
)
self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
suffix = "lm_head"
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)
self.device = weights.device
def get_position_ids(
self,
input_ids: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if image_grid_thw is None:
# (batch_size, 3)
return (
torch.arange(input_ids.shape[0], device=input_ids.device)
.unsqueeze(1)
.repeat(1, 3)
)
# if an image grid is provided then we need to calculate the position ids
spatial_merge_size = self.spatial_merge_size
vision_start_token_id = self.vision_start_token_id
vision_end_token_id = self.vision_end_token_id
device = input_ids.device
dtype = input_ids.dtype
input_ids_len = input_ids.shape[0]
# capture vision segments
starts = torch.where(input_ids == vision_start_token_id)[0]
ends = torch.where(input_ids == vision_end_token_id)[0]
# ie. [[ 14, 2181], [2212, 4379]]
vision_segments = torch.stack((starts, ends), dim=1)
# capture text lengths as the space between vision segments
prev_end = torch.cat( # shift to the left to get the previous end
[torch.zeros(1, device=ends.device, dtype=dtype), ends[:-1]]
) # ie. [0, 2181]
# text is the space between the end of one vision segment and the start of the next
text_lengths = vision_segments[:, 0] - prev_end + 1 # ie. [15, 32]
# calculate the max id from the image width for each segment
vision_widths_max = torch.cat(
[
torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
image_grid_thw[:-1, 2] // spatial_merge_size,
]
)
total_segment_lengths = vision_widths_max + text_lengths
total_segment_lengths = total_segment_lengths.cumsum(dim=0)
text_diff = total_segment_lengths - text_lengths
# create position ids for each vision segment based on the image grid
llm_pos_ids_list = []
for i, _ in enumerate(vision_segments):
t, h, w = (
image_grid_thw[i][0],
image_grid_thw[i][1] // spatial_merge_size,
image_grid_thw[i][2] // spatial_merge_size,
)
t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
w_indices = torch.arange(w, device=device).repeat(t * h)
image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
# offset by the position of the last vision segment
im = image_position_ids + total_segment_lengths[i]
llm_pos_ids_list.append(im)
# create position ids for each text segment
text_ranges = [
torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
+ text_diff[i]
for i, seq_len in enumerate(text_lengths)
] # ie. [[ 0, 1, ..., 14], [2182, 2183, ..., 2213]]
# combine by alternating text and vision segments (text, vision, text, vision, ...)
full_llm_pos_ids_list = [
item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
]
# the final segment is the difference between the last vision segment and the end of the input
max_s = full_llm_pos_ids_list[-1].max() + 1
final_text_len = input_ids_len - ends[-1]
if final_text_len > 0:
m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
full_llm_pos_ids_list.append(m + max_s)
# concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
position_ids = (
torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
)
return position_ids
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
# Unused in this model
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
pixel_values = pixel_values.to(inputs_embeds.dtype)
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
inputs_embeds[input_ids == self.image_token_id] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits
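
To help read get_position_ids, the toy below illustrates the 3D "mrope" layout it produces for one image surrounded by text: text tokens advance all three axes together, while image tokens get (t, h, w) coordinates from the merged grid. This is a simplified sketch of the idea with made-up sizes, not a copy of the exact indexing above.

import torch

def toy_mrope_positions(n_text_before: int, grid_t: int, grid_h: int, grid_w: int, n_text_after: int) -> torch.Tensor:
    pos = []
    # leading text: the same index on every axis
    for i in range(n_text_before):
        pos.append((i, i, i))
    offset = n_text_before
    # image tokens: one entry per cell of the (t, h, w) merged grid, shifted past the text
    for t in range(grid_t):
        for h in range(grid_h):
            for w in range(grid_w):
                pos.append((offset + t, offset + h, offset + w))
    # trailing text resumes after the largest id used so far
    start = max(max(p) for p in pos) + 1
    for i in range(n_text_after):
        pos.append((start + i, start + i, start + i))
    return torch.tensor(pos)  # (seq_len, 3), the same layout get_position_ids returns

print(toy_mrope_positions(n_text_before=4, grid_t=1, grid_h=2, grid_w=2, n_text_after=3))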

View File

@@ -123,6 +123,11 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
elif config.model_type == "qwen2_5_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@@ -231,7 +236,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
image = Image.open(BytesIO(chunk.image.data))
# qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
# default warmup image is 20x20
if (
config.model_type == "qwen2_vl"
or config.model_type == "qwen2_5_vl"
):
if image.width <= 20:
w = image.width * 2
h = image.height * 2
@@ -422,7 +430,10 @@ class VlmCausalLM(FlashCausalLM):
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if (
self.model.config.model_type == "qwen2_vl"
or self.model.config.model_type == "qwen2_5_vl"
):
if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids(
input_ids, batch.image_grid_thw