Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)

feat: support qwen2.5 vl model

Parent commit: c1cf36c0dc
This commit: 10aa62f87f
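The snapshots and tests below exercise the server's OpenAI-compatible chat endpoint with image inputs. For illustration only (not part of this commit), a request of roughly the following shape produces that traffic; the localhost:8080 address and the use of the requests library are assumptions:

import requests

# Assumption: a TGI server built from this commit is serving
# Qwen/Qwen2.5-VL-3B-Instruct locally on port 8080.
payload = {
    "model": "Qwen/Qwen2.5-VL-3B-Instruct",
    "seed": 42,
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                    },
                },
                {"type": "text", "text": "Describe the image"},
            ],
        }
    ],
}
response = requests.post("http://localhost:8080/v1/chat/completions", json=payload)
print(response.json()["choices"][0]["message"]["content"])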
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The image showcases the Statue of Liberty, a colossal bronze statue located in New York Harbor, a heritage building in the United States. The statue has a majestic presence, with one arm raised towards the sun and the other hitched on her hip. It sits atop a keeper's walkway, observed from the water. Surrounding the statue is a lush green meadow, where picnic spots, walkways, and a visitor desk can be found. In front of the statue, a large marina can accommodate fourteen different kinds of boats. In the backdrop stands the Empire State Building, marking the crowded skyscrapers of New York City.",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1738342753,
  "id": "",
  "model": "Qwen/Qwen2.5-VL-3B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.0.2-dev0-native",
  "usage": {
    "completion_tokens": 128,
    "prompt_tokens": 8736,
    "total_tokens": 8864
  }
}
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The image shows a whimsical scene set in what appears to be a fast-food restaurant. Dominating the foreground is a large, green, inflatable dinosaur with realistic textures, giving it a Jurassic Park-like appearance. The dinosaur is wearing a red Adult Swim logo hat, adding a humorous touch to its appearance.\n\nSurrounding the dinosaur are various food items typically found in a fast-food restaurant, including French fries in a plastic cup, a hamburger on a plate, and a beverage in another cup. The hamburger is detailed with lettuce, tomato, and other typical fast-food ingredients.\n\nAccompanying the dinosaur is a realistic-looking owl perched on the table, which adds to the surreal and playful atmosphere of the scene. The background features the interior of the restaurant with neon signs and other typical decor elements, enhancing the overall theme of a fun and fantastical fast-food experience.\n\nOverall, the image is a playful and imaginative blend of a standard fast-food setting with an unexpected and amusing twist provided by the dinosaur and owl characters.",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1738343775,
  "id": "",
  "model": "Qwen/Qwen2.5-VL-3B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.0.2-dev0-native",
  "usage": {
    "completion_tokens": 206,
    "prompt_tokens": 5375,
    "total_tokens": 5581
  }
}
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting.",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1738342872,
  "id": "",
  "model": "Qwen/Qwen2.5-VL-3B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "3.0.2-dev0-native",
  "usage": {
    "completion_tokens": 121,
    "prompt_tokens": 1363,
    "total_tokens": 1484
  }
}
@@ -0,0 +1,20 @@
{
  "choices": [
    {
      "delta": {
        "content": "",
        "role": "assistant",
        "tool_calls": null
      },
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null
    }
  ],
  "created": 1738343559,
  "id": "",
  "model": "Qwen/Qwen2.5-VL-3B-Instruct",
  "object": "chat.completion.chunk",
  "system_fingerprint": "3.0.2-dev0-native",
  "usage": null
}
New file: integration-tests/models/test_flash_qwen2_5_vl.py (122 lines)
@@ -0,0 +1,122 @@
import pytest


@pytest.fixture(scope="module")
def flash_qwen2_5_vl_handle(launcher):
    with launcher("Qwen/Qwen2.5-VL-3B-Instruct") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_qwen2_5(flash_qwen2_5_vl_handle):
    await flash_qwen2_5_vl_handle.health(300)
    return flash_qwen2_5_vl_handle.client


@pytest.mark.private
async def test_flash_qwen2_5_vl_simple(flash_qwen2_5, response_snapshot):
    response = await flash_qwen2_5.chat(
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                        },
                    },
                    {"type": "text", "text": "Describe the image"},
                ],
            },
        ],
    )

    assert (
        response.choices[0].message.content
        == "The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting."
    )

    assert response == response_snapshot


@pytest.mark.private
async def test_flash_qwen2_5_vl_simple_streaming(flash_qwen2_5, response_snapshot):
    responses = await flash_qwen2_5.chat(
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                        },
                    },
                    {"type": "text", "text": "Describe the image"},
                ],
            },
        ],
        stream=True,
    )

    count = 0
    generated = ""
    last_response = None
    async for response in responses:
        count += 1
        generated += response.choices[0].delta.content
        last_response = response

    assert (
        generated
        == "The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting."
    )
    assert count == 121
    assert last_response == response_snapshot


@pytest.mark.private
async def test_flash_qwen2_5_vl_bay(flash_qwen2_5, response_snapshot):
    response = await flash_qwen2_5.chat(
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
                        },
                    },
                    {"type": "text", "text": "Describe the image"},
                ],
            },
        ],
    )
    assert response == response_snapshot


@pytest.mark.private
async def test_flash_qwen2_5_vl_inpaint(flash_qwen2_5, response_snapshot):
    response = await flash_qwen2_5.chat(
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png"
                        },
                    },
                    {"type": "text", "text": "Describe the image"},
                ],
            },
        ],
    )
    assert response == response_snapshot
@@ -184,10 +184,43 @@ impl Qwen2Vl {
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2_5VlVisionConfig {
    pub(crate) depth: usize,
    pub(crate) hidden_act: String,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_heads: usize,
    pub(crate) in_chans: usize,
    pub(crate) out_hidden_size: usize,
    pub(crate) patch_size: usize,
    pub(crate) spatial_merge_size: usize,
    pub(crate) spatial_patch_size: usize,
    pub(crate) window_size: usize,
    pub(crate) fullatt_block_indexes: Vec<usize>,
    pub(crate) tokens_per_second: usize,
    pub(crate) temporal_patch_size: usize,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2_5Vl {
    pub(crate) vision_config: Qwen2_5VlVisionConfig,
}

impl Qwen2_5Vl {
    pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
        let num_pixels = height * width;
        num_pixels / self.vision_config.patch_size.pow(2)
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub enum Config {
    Qwen2_5Vl(Qwen2_5Vl),
    Qwen2Vl(Qwen2Vl),
    LlavaNext(LlavaNext),
    ClipVisionModel(ClipVisionModel),
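For intuition, get_number_of_features above budgets roughly one image token per patch_size² pixels. A minimal Python sketch of the same arithmetic; the 1000×1000 image and patch_size of 14 are illustrative assumptions, not values taken from this diff:

# Mirrors the Rust helper: pixels divided by the area of one square patch.
def get_number_of_features(height: int, width: int, patch_size: int) -> int:
    num_pixels = height * width
    return num_pixels // patch_size**2

print(get_number_of_features(1000, 1000, 14))  # 5102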
@@ -684,6 +684,10 @@ fn image_tokens(
            "<|vision_start|>{:?}<|vision_end|>",
            "<|image_pad|>".repeat(config.get_number_of_features(height, width))
        ),
        Qwen2_5Vl(config) => format!(
            "<|vision_start|>{:?}<|vision_end|>",
            "<|image_pad|>".repeat(config.get_number_of_features(height, width))
        ),
        _ => unimplemented!("Images tokens are not supported for this model configuration"),
    }
}
@@ -712,7 +716,7 @@ fn prepare_input<T: TokenizerTrait>(
    let (tokenizer_query, input_chunks) = match config {
        Some(
            config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
            | Qwen2Vl(_) | Qwen2_5Vl(_)),
        ) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
@@ -164,6 +164,9 @@ try:
    from text_generation_server.models.custom_modeling.qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_5_vl import (
        Qwen2_5VLForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
@@ -317,6 +320,11 @@ class ModelType(enum.Enum):
        "name": "Qwen 2 VL",
        "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
    }
    QWEN2_5_VL = {
        "type": "qwen2_5_vl",
        "name": "Qwen 2.5 VL",
        "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
    }
    OPT = {
        "type": "opt",
        "name": "Opt",

@@ -1368,6 +1376,19 @@ def get_model(
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
    if model_type == QWEN2_5_VL:
        return VlmCausalLM(
            model_id=model_id,
            model_class=Qwen2_5VLForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            default_dtype=torch.bfloat16,
            kv_cache_dtype=kv_cache_dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
            return MllamaCausalLM(
New file: text_generation_server/models/custom_modeling/qwen2_5_vl.py

@@ -0,0 +1,641 @@
# coding=utf-8
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2.5 VL model."""

from typing import Optional, Tuple, List

import torch
import torch.utils.checkpoint
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex
else:
    import flash_attn_2_cuda

import numpy as np

from transformers.activations import ACT2FN
import torch.nn.functional as F

from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
)
from text_generation_server.layers.attention import (
    Seqlen,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
    Qwen2Model,
)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(
    tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    orig_dtype = tensor.dtype
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    output = output.to(orig_dtype)
    return output


class Qwen2_5VLAttention(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size // weights.process_group.size()
        self.head_dim = config.hidden_size // config.num_heads
        self.num_heads = config.num_heads // weights.process_group.size()

        self.qkv = TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.qkv",
            weights=weights,
            bias=False,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_heads,
        )
        self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)

        self.proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.proj",
            weights=weights,
            bias=True,
        )
        self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)

    def forward(
        self,
        hidden_state: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        # apply the qkv linear layer to the hidden state
        qkv = self.qkv(hidden_state)
        query, key, value = qkv.split(
            [self.embed_dim, self.embed_dim, self.embed_dim], dim=1
        )

        # reshape the query, key, and value tensors
        _shape = (
            hidden_state.shape[0],
            self.num_heads,
            self.embed_dim // self.num_heads,
        )
        query = query.view(*_shape)
        key = key.view(*_shape)
        value = value.view(*_shape)

        # apply rotary positional embeddings
        query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
            0
        )
        key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # calc maximum sequence length for any batch
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        causal = False

        # execute flash attention
        if SYSTEM == "ipex":
            attn_output = torch.empty_like(query)
            ipex.llm.functional.varlen_attention(
                (query.contiguous() if query.device.type == "xpu" else query),
                (key.contiguous() if key.device.type == "xpu" else key),
                (value.contiguous() if value.device.type == "xpu" else value),
                attn_output,
                cu_seqlens,
                cu_seqlens,
                max_seqlen,
                max_seqlen,
                0.0,
                self.softmax_scale,
                False,
                causal,
                False,
                None,
            )
        else:
            attn_output = flash_attn_2_cuda.varlen_fwd(
                query,
                key,
                value,
                None,  # tmp buffer (auto-allocated)
                cu_seqlens,  # cu_seqlens_q
                cu_seqlens,  # cu_seqlens_k
                None,  # max_seqlen_q (auto-computed)
                None,  # max_seqlen_k (auto-computed)
                None,  # block_tables
                None,  # broadcast_mask
                max_seqlen,  # max_seqlen
                max_seqlen,  # max_seqlen
                0.0,  # dropout_p
                self.softmax_scale,
                False,  # zero_tensors
                causal,  # causal attention within each sequence
                -1,  # window_size_left
                -1,  # window_size_right
                0.0,  # softmax_cap
                False,  # deterministic
                None,  # rng_state
            )[0]

        # reshape output to original dimensions
        attn_output = attn_output.reshape(hidden_state.shape[0], -1)
        attn_output = self.proj(attn_output)
        return attn_output


class Qwen2_5VLVisionMLP(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.activation_fn = ACT2FN[config.hidden_act]

        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

        self.up = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True
        )
        self.gate = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True
        )
        self.down = TensorParallelRowLinear.load(
            prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate_states = self.gate(hidden_states)
        up_states = self.up(hidden_states)
        activated_states = self.activation_fn(gate_states) * up_states
        down_states = self.down(activated_states)
        return down_states


class Qwen2_5VLVisionBlock(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.attn = Qwen2_5VLAttention(
            prefix=f"{prefix}.attn",
            config=config,
            weights=weights,
        )
        self.norm1 = FastRMSNorm.load(
            prefix=f"{prefix}.norm1",
            weights=weights,
            eps=1e-6,
        )
        self.norm2 = FastRMSNorm.load(
            prefix=f"{prefix}.norm2",
            weights=weights,
            eps=1e-6,
        )
        self.mlp = Qwen2_5VLVisionMLP(
            prefix=f"{prefix}.mlp",
            config=config,
            weights=weights,
        )

    def forward(
        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
    ) -> torch.Tensor:
        norm1_out, _ = self.norm1(hidden_states)
        attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
        hidden_states = hidden_states + attn_out
        norm2_out, _ = self.norm2(hidden_states)
        mlp_out = self.mlp(norm2_out)
        hidden_states = hidden_states + mlp_out
        return hidden_states


class Qwen2_5VLPatchMerger(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        self.patch_merger_ln_q = FastRMSNorm.load(
            prefix=f"{prefix}.ln_q",
            weights=weights,
            eps=1e-6,
        )
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states) -> torch.Tensor:
        hidden_states, _ = self.patch_merger_ln_q(hidden_states)
        hidden_states = hidden_states.view(-1, self.hidden_size)
        hidden_states = self.fc1(hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class Qwen2_5VisionModel(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()

        self.spatial_merge_size = config.spatial_merge_size
        kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
        self.patch_embedding = nn.Conv3d(
            in_channels=config.in_chans,
            out_channels=config.hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
        )
        head_dim = config.hidden_size // config.num_heads

        theta = 10000.0
        dim = head_dim // 2
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self.blocks = nn.ModuleList(
            [
                Qwen2_5VLVisionBlock(
                    prefix=f"{prefix}.blocks.{i}",
                    config=config,
                    weights=weights,
                )
                for i in range(config.depth)
            ]
        )
        self.merger = Qwen2_5VLPatchMerger(
            prefix=f"{prefix}.merger",
            config=config,
            weights=weights,
        )

        self.temporal_patch_size = config.temporal_patch_size
        self.spatial_patch_size = config.spatial_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size
        self.window_size = config.window_size
        self.patch_size = config.patch_size
        self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
        self.fullatt_block_indexes = config.fullatt_block_indexes

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def get_window_index(self, grid_thw):
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:

        # reshape the input tensor for processing
        shape = (
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.spatial_patch_size,
            self.spatial_patch_size,
        )
        pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
        hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
        # TODO: revisit to see if we can avoid some of these reshapes

        # find the position ids for the input tensor based on the grid_thw
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)

        max_grid_size = grid_thw[:, 1:].max()

        # apply the positional embeddings to the position ids
        seq = torch.arange(
            max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        seq_len = hidden_states.shape[0]
        patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        og_shape = (seq_len, -1)

        hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(
            og_shape
        )
        rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(
            og_shape
        )

        rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)

        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # create a cu_seqlens tensor to be used in the attention mask
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])

        # iteratively apply the blocks to the hidden states
        for layer_num, block in enumerate(self.blocks):
            # NOTE: qwen2_5_vl.py has a concept of full attention blocks
            # that are applied at specific layers.
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            hidden_states = block(
                hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
            )

        # apply the final patch merger to the hidden states
        hidden_states = self.merger(hidden_states)
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]
        return hidden_states


class Qwen2_5VLForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
        # returns rope_scaling.type == "default" for Qwen2_5-VL model at the moment
        config.rope_scaling.update({"rope_type": "mrope"})
        self.hidden_size = config.hidden_size
        self.vision_start_token_id = config.vision_start_token_id
        self.vision_end_token_id = config.vision_end_token_id
        self.image_token_id = config.image_token_id
        self.video_token_id = config.video_token_id
        self.spatial_merge_size = config.vision_config.spatial_merge_size
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.visual = Qwen2_5VisionModel(
            prefix="visual", config=config.vision_config, weights=weights
        )
        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
        if config.tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
            suffix = "lm_head"

        self.lm_head = SpeculativeHead.load(
            config,
            prefix=suffix if not prefix else f"{prefix}.{suffix}",
            weights=weights,
        )
        self.device = weights.device

    def get_position_ids(
        self,
        input_ids: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if image_grid_thw is None:
            # (batch_size, 3)
            return (
                torch.arange(input_ids.shape[0], device=input_ids.device)
                .unsqueeze(1)
                .repeat(1, 3)
            )

        # if an image grid is provided then we need to calculate the position ids
        spatial_merge_size = self.spatial_merge_size
        vision_start_token_id = self.vision_start_token_id
        vision_end_token_id = self.vision_end_token_id

        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]

        # capture vision segments
        starts = torch.where(input_ids == vision_start_token_id)[0]
        ends = torch.where(input_ids == vision_end_token_id)[0]
        # ie. [[ 14, 2181], [2212, 4379]]
        vision_segments = torch.stack((starts, ends), dim=1)
        # capture text lengths as the space between vision segments

        prev_end = torch.cat(  # shift to the left to get the previous end
            [torch.zeros(1, device=ends.device, dtype=dtype), ends[:-1]]
        )  # ie. [0, 2181]

        # text is the space between the end of one vision segment and the start of the next
        text_lengths = vision_segments[:, 0] - prev_end + 1  # ie. [15, 32]

        # calculate the max id from the image width for each segment
        vision_widths_max = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                image_grid_thw[:-1, 2] // spatial_merge_size,
            ]
        )
        total_segment_lengths = vision_widths_max + text_lengths
        total_segment_lengths = total_segment_lengths.cumsum(dim=0)
        text_diff = total_segment_lengths - text_lengths

        # create position ids for each vision segment based on the image grid
        llm_pos_ids_list = []
        for i, _ in enumerate(vision_segments):
            t, h, w = (
                image_grid_thw[i][0],
                image_grid_thw[i][1] // spatial_merge_size,
                image_grid_thw[i][2] // spatial_merge_size,
            )
            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
            w_indices = torch.arange(w, device=device).repeat(t * h)
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)

            # offset by the position of the last vision segment
            im = image_position_ids + total_segment_lengths[i]
            llm_pos_ids_list.append(im)

        # create position ids for each text segment
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_diff[i]
            for i, seq_len in enumerate(text_lengths)
        ]  # ie. [[ 0, 1, ..., 14], [2182, 2183, ..., 2213]]

        # combine by alternating text and vision segments (text, vision, text, vision, ...)
        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]

        # the final segment is the difference between the last vision segment and the end of the input
        max_s = full_llm_pos_ids_list[-1].max() + 1
        final_text_len = input_ids_len - ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)

        # concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
        return position_ids

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor],
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        # Unused in this model
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        inputs_embeds = self.embed_tokens(input_ids)

        # apply the visual model to the pixel values if they are provided
        if pixel_values is not None and len(pixel_values) > 0:
            pixel_values = pixel_values.to(inputs_embeds.dtype)
            if pixel_values is not None:
                image_embeds = self.visual(
                    pixel_values, grid_thw=image_grid_thw
                ).squeeze(0)
                inputs_embeds[input_ids == self.image_token_id] = image_embeds

        hidden_states = self.text_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            true_max_s=max_s,
            prefill_cache_indices=prefill_cache_indices,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
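One detail worth calling out from get_position_ids above: when no image grid is passed, every token gets the same index on all three mrope axes (temporal, height, width). A small self-contained sketch of that text-only fallback; the token ids are made up:

import torch

input_ids = torch.tensor([5, 6, 7, 8])  # illustrative token ids, not from the diff
# Same expression as the text-only branch of get_position_ids.
position_ids = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, 3)
print(position_ids.shape)  # torch.Size([4, 3])
print(position_ids[2])     # tensor([2, 2, 2])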
@@ -123,6 +123,11 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
        num_pads = grid_t * grid_h * grid_w // 4
        padding = "<|image_pad|>" * num_pads
        return f"<|vision_start|>{padding}<|vision_end|>"
    elif config.model_type == "qwen2_5_vl":
        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
        num_pads = grid_t * grid_h * grid_w // 4
        padding = "<|image_pad|>" * num_pads
        return f"<|vision_start|>{padding}<|vision_end|>"
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
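The // 4 in num_pads above reflects the 2x2 spatial patch merge (spatial_merge_size is 2 in the released Qwen2.5-VL configs), so four vision patches collapse into a single <|image_pad|> token. A worked example with assumed grid values (not taken from this diff):

grid_t, grid_h, grid_w = 1, 36, 36  # one temporal frame, a 36 x 36 patch grid (assumed)
num_pads = grid_t * grid_h * grid_w // 4  # 2 x 2 patches merge into one token
padding = "<|image_pad|>" * num_pads
prompt_piece = f"<|vision_start|>{padding}<|vision_end|>"
print(num_pads)  # 324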
@@ -231,7 +236,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                image = Image.open(BytesIO(chunk.image.data))
                # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
                # default warmup image is 20x20
                if (
                    config.model_type == "qwen2_vl"
                    or config.model_type == "qwen2_5_vl"
                ):
                    if image.width <= 20:
                        w = image.width * 2
                        h = image.height * 2
@@ -422,7 +430,10 @@ class VlmCausalLM(FlashCausalLM):
        max_s = batch.max_current_length
        lm_head_indices = batch.prefill_head_indices

        if (
            self.model.config.model_type == "qwen2_vl"
            or self.model.config.model_type == "qwen2_5_vl"
        ):
            if position_ids.dim() == 1 and batch.prefilling:
                position_ids = self.model.get_position_ids(
                    input_ids, batch.image_grid_thw
                )