diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_bay.json b/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_bay.json new file mode 100644 index 00000000..739d5361 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_bay.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "The image showcases the Statue of Liberty, a colossal bronze statue located in New York Harbor, a heritage building in the United States. The statue has a majestic presence, with one arm raised towards the sun and the other hitched on her hip. It sits atop a keeper's walkway, observed from the water. Surrounding the statue is a lush green meadow, where picnic spots, walkways, and a visitor desk can be found. In front of the statue, a large marina can accommodate fourteen different kinds of boats. In the backdrop stands the Empire State Building, marking the crowded skyscrapers of New York City.", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1738342753, + "id": "", + "model": "Qwen/Qwen2.5-VL-3B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.0.2-dev0-native", + "usage": { + "completion_tokens": 128, + "prompt_tokens": 8736, + "total_tokens": 8864 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_inpaint.json b/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_inpaint.json new file mode 100644 index 00000000..00e1b041 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_inpaint.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "The image shows a whimsical scene set in what appears to be a fast-food restaurant. Dominating the foreground is a large, green, inflatable dinosaur with realistic textures, giving it a Jurassic Park-like appearance. The dinosaur is wearing a red Adult Swim logo hat, adding a humorous touch to its appearance.\n\nSurrounding the dinosaur are various food items typically found in a fast-food restaurant, including French fries in a plastic cup, a hamburger on a plate, and a beverage in another cup. The hamburger is detailed with lettuce, tomato, and other typical fast-food ingredients.\n\nAccompanying the dinosaur is a realistic-looking owl perched on the table, which adds to the surreal and playful atmosphere of the scene. 
The background features the interior of the restaurant with neon signs and other typical decor elements, enhancing the overall theme of a fun and fantastical fast-food experience.\n\nOverall, the image is a playful and imaginative blend of a standard fast-food setting with an unexpected and amusing twist provided by the dinosaur and owl characters.", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1738343775, + "id": "", + "model": "Qwen/Qwen2.5-VL-3B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.0.2-dev0-native", + "usage": { + "completion_tokens": 206, + "prompt_tokens": 5375, + "total_tokens": 5581 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_simple.json b/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_simple.json new file mode 100644 index 00000000..c498c36a --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_simple.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting.", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1738342872, + "id": "", + "model": "Qwen/Qwen2.5-VL-3B-Instruct", + "object": "chat.completion", + "system_fingerprint": "3.0.2-dev0-native", + "usage": { + "completion_tokens": 121, + "prompt_tokens": 1363, + "total_tokens": 1484 + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_simple_streaming.json b/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_simple_streaming.json new file mode 100644 index 00000000..0d2718a7 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_simple_streaming.json @@ -0,0 +1,20 @@ +{ + "choices": [ + { + "delta": { + "content": "", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": "stop", + "index": 0, + "logprobs": null + } + ], + "created": 1738343559, + "id": "", + "model": "Qwen/Qwen2.5-VL-3B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "3.0.2-dev0-native", + "usage": null +} diff --git a/integration-tests/models/test_flash_qwen2_5_vl.py b/integration-tests/models/test_flash_qwen2_5_vl.py new file mode 100644 index 00000000..922068f4 --- /dev/null +++ b/integration-tests/models/test_flash_qwen2_5_vl.py @@ -0,0 +1,122 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_qwen2_5_vl_handle(launcher): + with launcher("Qwen/Qwen2.5-VL-3B-Instruct") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_qwen2_5(flash_qwen2_5_vl_handle): + await flash_qwen2_5_vl_handle.health(300) + return flash_qwen2_5_vl_handle.client + + +@pytest.mark.private +async def test_flash_qwen2_5_vl_simple(flash_qwen2_5, 
response_snapshot): + response = await flash_qwen2_5.chat( + seed=42, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + {"type": "text", "text": "Describe the image"}, + ], + }, + ], + ) + + assert ( + response.choices[0].message.content + == "The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting." + ) + + assert response == response_snapshot + + +@pytest.mark.private +async def test_flash_qwen2_5_vl_simple_streaming(flash_qwen2_5, response_snapshot): + responses = await flash_qwen2_5.chat( + seed=42, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + {"type": "text", "text": "Describe the image"}, + ], + }, + ], + stream=True, + ) + + count = 0 + generated = "" + last_response = None + async for response in responses: + count += 1 + generated += response.choices[0].delta.content + last_response = response + + assert ( + generated + == "The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting." 
+    )
+    assert count == 121
+    assert last_response == response_snapshot
+
+
+@pytest.mark.private
+async def test_flash_qwen2_5_vl_bay(flash_qwen2_5, response_snapshot):
+    response = await flash_qwen2_5.chat(
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
+                        },
+                    },
+                    {"type": "text", "text": "Describe the image"},
+                ],
+            },
+        ],
+    )
+    assert response == response_snapshot
+
+
+@pytest.mark.private
+async def test_flash_qwen2_5_vl_inpaint(flash_qwen2_5, response_snapshot):
+    response = await flash_qwen2_5.chat(
+        seed=42,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png"
+                        },
+                    },
+                    {"type": "text", "text": "Describe the image"},
+                ],
+            },
+        ],
+    )
+    assert response == response_snapshot
diff --git a/router/src/config.rs b/router/src/config.rs
index a1ac107a..a0135984 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -184,10 +184,43 @@ impl Qwen2Vl {
     }
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Qwen2_5VlVisionConfig {
+    pub(crate) depth: usize,
+    pub(crate) hidden_act: String,
+    pub(crate) hidden_size: usize,
+    pub(crate) intermediate_size: usize,
+    pub(crate) num_heads: usize,
+    pub(crate) in_chans: usize,
+    pub(crate) out_hidden_size: usize,
+    pub(crate) patch_size: usize,
+    pub(crate) spatial_merge_size: usize,
+    pub(crate) spatial_patch_size: usize,
+    pub(crate) window_size: usize,
+    pub(crate) fullatt_block_indexes: Vec<usize>,
+    pub(crate) tokens_per_second: usize,
+    pub(crate) temporal_patch_size: usize,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Qwen2_5Vl {
+    pub(crate) vision_config: Qwen2_5VlVisionConfig,
+}
+
+impl Qwen2_5Vl {
+    pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
+        let num_pixels = height * width;
+        num_pixels / self.vision_config.patch_size.pow(2)
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub enum Config {
+    Qwen2_5Vl(Qwen2_5Vl),
     Qwen2Vl(Qwen2Vl),
     LlavaNext(LlavaNext),
     ClipVisionModel(ClipVisionModel),
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 7ac05b21..c9a44e4a 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -684,6 +684,10 @@ fn image_tokens(
             "<|vision_start|>{:?}<|vision_end|>",
             "<|image_pad|>".repeat(config.get_number_of_features(height, width))
         ),
+        Qwen2_5Vl(config) => format!(
+            "<|vision_start|>{:?}<|vision_end|>",
+            "<|image_pad|>".repeat(config.get_number_of_features(height, width))
+        ),
         _ => unimplemented!("Images tokens are not supported for this model configuration"),
     }
 }
@@ -712,7 +716,7 @@ fn prepare_input(
     let (tokenizer_query, input_chunks) = match config {
         Some(
             config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
-            | Qwen2Vl(_)),
+            | Qwen2Vl(_) | Qwen2_5Vl(_)),
         ) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index f8150b5e..8cd33660 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@
-164,6 +164,9 @@ try: from text_generation_server.models.custom_modeling.qwen2_vl import ( Qwen2VLForConditionalGeneration, ) + from text_generation_server.models.custom_modeling.qwen2_5_vl import ( + Qwen2_5VLForConditionalGeneration, + ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") @@ -317,6 +320,11 @@ class ModelType(enum.Enum): "name": "Qwen 2 VL", "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d", } + QWEN2_5_VL = { + "type": "qwen2_5_vl", + "name": "Qwen 2.5 VL", + "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e", + } OPT = { "type": "opt", "name": "Opt", @@ -1368,6 +1376,19 @@ def get_model( trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) + if model_type == QWEN2_5_VL: + return VlmCausalLM( + model_id=model_id, + model_class=Qwen2_5VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + default_dtype=torch.bfloat16, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) if model_type == MLLAMA: if FLASH_ATTENTION: return MllamaCausalLM( diff --git a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py new file mode 100644 index 00000000..ad2f6039 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -0,0 +1,641 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Qwen2.5 VL model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn +from text_generation_server.utils.import_utils import SYSTEM + +if SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex +else: + import flash_attn_2_cuda + +import numpy as np + +from transformers.activations import ACT2FN +import torch.nn.functional as F + +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelEmbedding, + SpeculativeHead, +) +from text_generation_server.layers.attention import ( + Seqlen, +) +from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2Model, +) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2_5VLAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.embed_dim = config.hidden_size // weights.process_group.size() + self.head_dim = config.hidden_size // config.num_heads + self.num_heads = config.num_heads // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.qkv", + weights=weights, + bias=False, + num_heads=self.num_heads, + num_key_value_heads=self.num_heads, + ) + self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0) + + self.proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.proj", + weights=weights, + bias=True, + ) + self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads) + + def forward( + self, + hidden_state: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + # apply the qkv linear layer to the hidden state + qkv = self.qkv(hidden_state) + query, key, value = qkv.split( + [self.embed_dim, self.embed_dim, self.embed_dim], dim=1 + ) + + # reshape the query, key, and value tensors + _shape = ( + hidden_state.shape[0], + self.num_heads, + self.embed_dim // self.num_heads, + ) + query = query.view(*_shape) + key = key.view(*_shape) + value = value.view(*_shape) + + # apply rotary positional embeddings + query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( + 0 + ) + key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + + # calc maximum sequence length for any batch + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + causal = False + + # execute flash attention + if SYSTEM == "ipex": + attn_output = torch.empty_like(query) + ipex.llm.functional.varlen_attention( + (query.contiguous() if query.device.type == "xpu" else query), + (key.contiguous() if key.device.type == "xpu" else key), + (value.contiguous() if value.device.type == "xpu" else value), + attn_output, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + 
0.0, + self.softmax_scale, + False, + causal, + False, + None, + ) + else: + attn_output = flash_attn_2_cuda.varlen_fwd( + query, + key, + value, + None, # tmp buffer (auto-allocated) + cu_seqlens, # cu_seqlens_q + cu_seqlens, # cu_seqlens_k + None, # max_seqlen_q (auto-computed) + None, # max_seqlen_k (auto-computed) + None, # block_tables + None, # broadcast_mask + max_seqlen, # max_seqlen + max_seqlen, # max_seqlen + 0.0, # dropout_p + self.softmax_scale, + False, # zero_tensors + causal, # causal attention within each sequence + -1, # window_size_left + -1, # window_size_right + 0.0, # softmax_cap + False, # deterministic + None, # rng_state + )[0] + + # reshape output to original dimensions + attn_output = attn_output.reshape(hidden_state.shape[0], -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5VLVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + self.up = TensorParallelColumnLinear.load( + prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True + ) + self.gate = TensorParallelColumnLinear.load( + prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True + ) + self.down = TensorParallelRowLinear.load( + prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_states = self.gate(hidden_states) + up_states = self.up(hidden_states) + activated_states = self.activation_fn(gate_states) * up_states + down_states = self.down(activated_states) + return down_states + + +class Qwen2_5VLVisionBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.attn = Qwen2_5VLAttention( + prefix=f"{prefix}.attn", + config=config, + weights=weights, + ) + self.norm1 = FastRMSNorm.load( + prefix=f"{prefix}.norm1", + weights=weights, + eps=1e-6, + ) + self.norm2 = FastRMSNorm.load( + prefix=f"{prefix}.norm2", + weights=weights, + eps=1e-6, + ) + self.mlp = Qwen2_5VLVisionMLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + ) + + def forward( + self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen + ) -> torch.Tensor: + norm1_out, _ = self.norm1(hidden_states) + attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen) + hidden_states = hidden_states + attn_out + norm2_out, _ = self.norm2(hidden_states) + mlp_out = self.mlp(norm2_out) + hidden_states = hidden_states + mlp_out + return hidden_states + + +class Qwen2_5VLPatchMerger(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.patch_merger_ln_q = FastRMSNorm.load( + prefix=f"{prefix}.ln_q", + weights=weights, + eps=1e-6, + ) + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.patch_merger_ln_q(hidden_states) + hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = self.fc1(hidden_states) + hidden_states = F.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2_5VisionModel(nn.Module): + def __init__(self, *, 
prefix, config, weights): + super().__init__() + + self.spatial_merge_size = config.spatial_merge_size + kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size] + self.patch_embedding = nn.Conv3d( + in_channels=config.in_chans, + out_channels=config.hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False + ) + head_dim = config.hidden_size // config.num_heads + + theta = 10000.0 + dim = head_dim // 2 + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self.blocks = nn.ModuleList( + [ + Qwen2_5VLVisionBlock( + prefix=f"{prefix}.blocks.{i}", + config=config, + weights=weights, + ) + for i in range(config.depth) + ] + ) + self.merger = Qwen2_5VLPatchMerger( + prefix=f"{prefix}.merger", + config=config, + weights=weights, + ) + + self.temporal_patch_size = config.temporal_patch_size + self.spatial_patch_size = config.spatial_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + self.window_size = config.window_size + self.patch_size = config.patch_size + self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size + self.fullatt_block_indexes = config.fullatt_block_indexes + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w + ) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, + pixel_values: torch.Tensor, + grid_thw: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + + # reshape the input tensor for processing + shape = ( + -1, + self.in_channels, + 
self.temporal_patch_size, + self.spatial_patch_size, + self.spatial_patch_size, + ) + pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype) + hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim) + # TODO: revisit to see if we can avoid some of these reshapes + + # find the position ids for the input tensor based on the grid_thw + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + + max_grid_size = grid_thw[:, 1:].max() + + # apply the positional embeddings to the position ids + seq = torch.arange( + max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + seq_len = hidden_states.shape[0] + patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + og_shape = (seq_len, -1) + + hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view( + og_shape + ) + rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view( + og_shape + ) + + rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device) + + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + # create a cu_seqlens tensor to be used in the attention mask + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]) + + # iterately apply the blocks to the hidden states + for layer_num, block in enumerate(self.blocks): + # NOTE: qwen2_5_vl.py has a concept of full attention blocks + # that are applied at specific layers. 
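+            # Blocks listed in config.fullatt_block_indexes attend over the full
+            # per-image sequence (cu_seqlens); every other block attends only
+            # within its local window (cu_window_seqlens).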
+ if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + + hidden_states = block( + hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen + ) + + # apply the final patch merger to the hidden states + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + return hidden_states + + +class Qwen2_5VLForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly + # returns rope_scaling.type == "default" for Qwen2_5-VL model at the moment + config.rope_scaling.update({"rope_type": "mrope"}) + self.hidden_size = config.hidden_size + self.vision_start_token_id = config.vision_start_token_id + self.vision_end_token_id = config.vision_end_token_id + self.image_token_id = config.image_token_id + self.video_token_id = config.video_token_id + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.visual = Qwen2_5VisionModel( + prefix="visual", config=config.vision_config, weights=weights + ) + self.text_model = Qwen2Model(prefix=None, config=config, weights=weights) + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) + self.device = weights.device + + def get_position_ids( + self, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if image_grid_thw is None: + # (batch_size, 3) + return ( + torch.arange(input_ids.shape[0], device=input_ids.device) + .unsqueeze(1) + .repeat(1, 3) + ) + + # if image grid provided than we need to calculate the position ids + spatial_merge_size = self.spatial_merge_size + vision_start_token_id = self.vision_start_token_id + vision_end_token_id = self.vision_end_token_id + + device = input_ids.device + dtype = input_ids.dtype + input_ids_len = input_ids.shape[0] + + # capture vision segments + starts = torch.where(input_ids == vision_start_token_id)[0] + ends = torch.where(input_ids == vision_end_token_id)[0] + # ie. [[ 14, 2181], [2212, 4379]] + vision_segments = torch.stack((starts, ends), dim=1) + # capture text lengths as the space between vision segments + + prev_end = torch.cat( # shift to the left to get the previous end + [torch.zeros(1, device=ends.device, dtype=dtype), ends[:-1]] + ) # ie. [0, 2181] + + # text is the space between the end of one vision segment and the start of the next + text_lengths = vision_segments[:, 0] - prev_end + 1 # ie. 
[15, 32] + + # calculate the max id from the image width for each segment + vision_widths_max = torch.cat( + [ + torch.zeros(1, device=image_grid_thw.device, dtype=dtype), + image_grid_thw[:-1, 2] // spatial_merge_size, + ] + ) + total_segment_lengths = vision_widths_max + text_lengths + total_segment_lengths = total_segment_lengths.cumsum(dim=0) + text_diff = total_segment_lengths - text_lengths + + # create position ids for each vision segment based on the image grid + llm_pos_ids_list = [] + for i, _ in enumerate(vision_segments): + t, h, w = ( + image_grid_thw[i][0], + image_grid_thw[i][1] // spatial_merge_size, + image_grid_thw[i][2] // spatial_merge_size, + ) + t_indices = torch.arange(t, device=device).repeat_interleave(h * w) + h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t) + w_indices = torch.arange(w, device=device).repeat(t * h) + image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0) + + # offset by the position of the last vision segment + im = image_position_ids + total_segment_lengths[i] + llm_pos_ids_list.append(im) + + # create position ids for each text segment + text_ranges = [ + torch.arange(seq_len, device=device).view(1, -1).expand(3, -1) + + text_diff[i] + for i, seq_len in enumerate(text_lengths) + ] # ie. [[ 0, 1, ..., 14], [2182, 2183, ..., 2213]] + + # combine by alternating text and vision segments (text, vision, text, vision, ...) + full_llm_pos_ids_list = [ + item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist + ] + + # the final segment is the difference between the last vision segment and the end of the input + max_s = full_llm_pos_ids_list[-1].max() + 1 + final_text_len = input_ids_len - ends[-1] + if final_text_len > 0: + m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1) + full_llm_pos_ids_list.append(m + max_s) + + # concat and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3) + position_ids = ( + torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1) + ) + return position_ids + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor], + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + # Unused in this model + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + inputs_embeds = self.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if pixel_values is not None and len(pixel_values) > 0: + pixel_values = pixel_values.to(inputs_embeds.dtype) + if pixel_values is not None: + image_embeds = self.visual( + pixel_values, grid_thw=image_grid_thw + ).squeeze(0) + inputs_embeds[input_ids == self.image_token_id] = image_embeds + + hidden_states = self.text_model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + ) + if lm_head_indices is 
not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index db78341d..23fdca05 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -123,6 +123,11 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         num_pads = grid_t * grid_h * grid_w // 4
         padding = "<|image_pad|>" * num_pads
         return f"<|vision_start|>{padding}<|vision_end|>"
+    elif config.model_type == "qwen2_5_vl":
+        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
+        num_pads = grid_t * grid_h * grid_w // 4
+        padding = "<|image_pad|>" * num_pads
+        return f"<|vision_start|>{padding}<|vision_end|>"
     else:
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
 
@@ -231,7 +236,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     image = Image.open(BytesIO(chunk.image.data))
                     # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
                     # default warmup image is 20x20
-                    if config.model_type == "qwen2_vl":
+                    if (
+                        config.model_type == "qwen2_vl"
+                        or config.model_type == "qwen2_5_vl"
+                    ):
                         if image.width <= 20:
                             w = image.width * 2
                             h = image.height * 2
@@ -422,7 +430,10 @@ class VlmCausalLM(FlashCausalLM):
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
 
-        if self.model.config.model_type == "qwen2_vl":
+        if (
+            self.model.config.model_type == "qwen2_vl"
+            or self.model.config.model_type == "qwen2_5_vl"
+        ):
             if position_ids.dim() == 1 and batch.prefilling:
                 position_ids = self.model.get_position_ids(
                     input_ids, batch.image_grid_thw
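For reviewers, here is a small worked example of the placeholder expansion performed by the new `qwen2_5_vl` branch of `image_text_replacement` above. It is an illustrative sketch only: the helper name is hypothetical, and the hard-coded `// 4` is assumed to correspond to `spatial_merge_size ** 2` with the merge size of 2 used by the released Qwen2.5-VL checkpoints.

import torch

# Hypothetical stand-alone helper mirroring the qwen2_5_vl branch above:
# each spatial_merge_size x spatial_merge_size group of vision patches is
# merged into one LM token, so the pad count is the grid volume divided by 4.
def qwen2_5_vl_image_placeholder(image_grid_thw: torch.Tensor, image_id: int) -> str:
    grid_t, grid_h, grid_w = image_grid_thw[image_id].tolist()
    num_pads = grid_t * grid_h * grid_w // 4
    return "<|vision_start|>" + "<|image_pad|>" * num_pads + "<|vision_end|>"

# Example: a single frame split into a 64 x 84 patch grid yields
# 64 * 84 // 4 = 1344 "<|image_pad|>" tokens.
placeholder = qwen2_5_vl_image_placeholder(torch.tensor([[1, 64, 84]]), 0)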