From d96eef2a026f90da9c51dd855fd5c7053c58ab45 Mon Sep 17 00:00:00 2001 From: David Holtz Date: Thu, 24 Oct 2024 15:36:53 +0000 Subject: [PATCH] feat: add support for qwen2 vl model --- .../text_generation_server/layers/rotary.py | 28 + .../layers/tensor_parallel.py | 2 +- .../text_generation_server/models/__init__.py | 20 + .../custom_modeling/flash_qwen2_modeling.py | 43 +- .../models/custom_modeling/qwen2_vl.py | 544 ++++++++++++++++++ .../models/vlm_causal_lm.py | 22 +- 6 files changed, 653 insertions(+), 6 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/qwen2_vl.py diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index a2076bb2..6e2ba228 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -89,6 +89,8 @@ class PositionRotaryEmbedding(nn.Module): if rope_type == "linear": pass + elif rope_type == "default": + pass elif rope_type == "dynamic": scaling_factor = rope_scaling["factor"] return DynamicPositionRotaryEmbedding( @@ -275,6 +277,32 @@ class PositionRotaryEmbedding(nn.Module): # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. return cos.unsqueeze(1), sin.unsqueeze(1) + def get_cos_sin_hack( + self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype + ): + # TODO: avoid always computing, use the cache and update it if necessary + inv_freq_expanded = ( + self.inv_freq[None, None, :, None] + .float() + .expand(3, position_ids.shape[1], -1, 1) + ) + + position_ids_expanded = position_ids[ + :, :, None, : + ].float() # shape (3, bs, 1, positions) + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose( + 2, 3 + ) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype) + sin = emb.sin().to(dtype) + + # Update cached values + self._cos_cached = cos + self._sin_cached = sin + + return cos, sin + class SuRotaryEmbedding(PositionRotaryEmbedding): def __init__( diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 13f12ef1..febf28da 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -144,7 +144,7 @@ class TensorParallelColumnLinear(SuperLayer): num_key_value_heads=num_key_value_heads, ) if bias: - raise NotImplementedError("packed_qkv only implemented for baichuan") + bias = weights.get_tensor(f"{prefix}.bias") else: bias = None linear = get_linear(weight, bias) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 99e3d343..6c633521 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -146,6 +146,9 @@ try: from text_generation_server.models.custom_modeling.idefics2 import ( Idefics2ForConditionalGeneration, ) + from text_generation_server.models.custom_modeling.qwen2_vl import ( + Qwen2VLForConditionalGeneration, + ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") @@ -275,6 +278,11 @@ class ModelType(enum.Enum): "name": "Qwen 2", "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f", } + QWEN2_VL = { + "type": "qwen2_vl", + "name": "Qwen 2 VL", + "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d", + } OPT = { "type": "opt", "name": "Opt", @@ -1193,6 +1201,18 @@ def get_model( ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + if model_type == QWEN2_VL: + return VlmCausalLM( + model_id=model_id, + model_class=Qwen2VLForConditionalGeneration, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) if model_type == MLLAMA: if FLASH_ATTENTION: return MllamaCausalLM( diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index ab2a177d..e9be22b1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -49,6 +49,13 @@ def _load_gqa(config, prefix: str, weights): ) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + class Qwen2Attention(torch.nn.Module): def __init__( self, @@ -61,6 +68,7 @@ class Qwen2Attention(torch.nn.Module): config.sliding_window if config.sliding_window is not None else -1 ) self.num_heads = config.num_attention_heads + self.mrope_section = config.rope_scaling.get("mrope_section", None) self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads @@ -122,7 +130,28 @@ class Qwen2Attention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + # TODO: correctly handle the multimodal case + if False: + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + else: + # multimodal rotary + unsqueeze_dim = 1 + mrope_section = self.mrope_section * 2 + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], + dim=-1, + ).unsqueeze(unsqueeze_dim) + + _query = query.transpose(0, 1).unsqueeze(0) + _key = torch.select(kv, dim=1, index=0).transpose(0, 1).unsqueeze(0) + q_embed = (_query * cos) + (rotate_half(_query) * sin) + k_embed = (_key * cos) + (rotate_half(_key) * sin) + query = q_embed.squeeze(0).transpose(0, 1) + kv[:, 0] = k_embed.squeeze(0).transpose(0, 1) if prefill_cache_indices is not None: kv_to_cache = kv[prefill_cache_indices] @@ -306,12 +335,20 @@ class Qwen2Model(torch.nn.Module): max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + inputs_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + + # if inputs_embeds are supplied from an external model (vision model) then avoid embedding input_ids + if inputs_embeds is not None: + hidden_states = inputs_embeds.squeeze(0) + else: + hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + # TODO: fix how N-D position_ids are handled + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin_hack( position_ids, true_max_s, hidden_states.dtype ) diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py new file mode 100644 index 00000000..28217d7d --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -0,0 +1,544 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2 VL model.""" + +from typing import Optional, Tuple, List + +import torch +import torch.utils.checkpoint +from torch import nn +from text_generation_server.utils.import_utils import SYSTEM + +if SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex +else: + import flash_attn_2_cuda + +from transformers.activations import ACT2FN +import torch.nn.functional as F + +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) +from text_generation_server.layers import ( + TensorParallelColumnLinear, + FastLinear, +) +from text_generation_server.layers.attention import ( + Seqlen, +) +from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2Model, +) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2VLSdpaAttention(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.embed_dim = config.embed_dim + self.head_dim = config.hidden_size // config.num_heads + self.num_heads = config.num_heads // weights.process_group.size() + + self.qkv = TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.qkv", + weights=weights, + bias=True, + num_heads=self.num_heads, + num_key_value_heads=self.num_heads, + ) + + self.proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.proj", + weights=weights, + bias=True, + ) + + def forward( + self, + hidden_state: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # apply the qkv linear layer to the hidden state + qkv = self.qkv(hidden_state) + query, key, value = qkv.split( + [self.embed_dim, self.embed_dim, self.embed_dim], dim=1 + ) + + # reshape the query, key, and value tensors + _shape = ( + hidden_state.shape[0], + self.num_heads, + self.embed_dim // self.num_heads, + ) + query = query.view(*_shape) + key = key.view(*_shape) + value = value.view(*_shape) + + # apply rotary positional embeddings + query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze( + 0 + ) + key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0) + # TODO: make use of existing RotatoryPositionEmbedding class + + # create the attention mask + attention_mask = torch.zeros( + [1, hidden_state.shape[0], hidden_state.shape[0]], + device=hidden_state.device, + dtype=torch.bool, + ) + # TODO: avoid creating the mask in the forward pass, instead define the largest possible mask and slice it + + # apply the cu_seqlens to the attention mask + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = True + + # transpose for the attention mechanism (batch, seqlen, hidden_dim) -> (seqlen, batch, hidden_dim) + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + # apply attention + attn_output = F.scaled_dot_product_attention( + query, key, value, attention_mask, dropout_p=0.0 + ) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(hidden_state.shape[0], -1) + # TODO: prefer flash attention + + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2VLVisionMLP(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VLVisionBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.attn = Qwen2VLSdpaAttention( + prefix=f"{prefix}.attn", + config=config, + weights=weights, + ) + self.norm1 = FastLayerNorm.load( + prefix=f"{prefix}.norm1", + weights=weights, + eps=1e-6, + ) + self.norm2 = FastLayerNorm.load( + prefix=f"{prefix}.norm2", + weights=weights, + eps=1e-6, + ) + self.mlp = Qwen2VLVisionMLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + ) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states_post_norm1, res = self.norm1(hidden_states) + hidden_states = hidden_states + self.attn( + hidden_states_post_norm1, cu_seqlens, rotary_pos_emb + ) + hidden_states_post_norm2, res = self.norm2(hidden_states) + hidden_states = hidden_states + self.mlp(hidden_states_post_norm2) + return hidden_states + + +class Qwen2VLPatchMerger(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + context_dim = 2560 + spatial_merge_size: int = 2 + self.hidden_size = 5120 # context_dim * (spatial_merge_size**2) + self.patch_merger_ln_q = FastLayerNorm.load( + prefix=f"{prefix}.ln_q", + weights=weights, + eps=1e-6, + ) + self.fc1 = TensorParallelColumnLinear.load( + prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( + prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True + ) + + def forward(self, hidden_states, grid_thw) -> torch.Tensor: + hidden_states, _ = self.patch_merger_ln_q(hidden_states) + hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = self.fc1(hidden_states) + hidden_states = F.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Qwen2VisionModel(nn.Module): + def __init__(self, *, prefix, config, weights): + super().__init__() + self.spatial_merge_size = config.spatial_merge_size + kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size] + self.patch_embedding = nn.Conv3d( + in_channels=config.in_chans, + out_channels=config.embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False + ) + head_dim = config.embed_dim // config.num_heads + # TODO: replace with static positional embeddings once implemented + theta = 10000.0 + dim = head_dim // 2 + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + self.blocks = nn.ModuleList( + [ + Qwen2VLVisionBlock( + prefix=f"{prefix}.blocks.{i}", + config=config, + weights=weights, + ) + for i in range(config.depth) + ] + ) + self.merger = Qwen2VLPatchMerger( + prefix=f"{prefix}.merger", + config=config, + weights=weights, + ) + + self.temporal_patch_size = config.temporal_patch_size + self.spatial_patch_size = config.spatial_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.embed_dim + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + grid_thw: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # reshape the input tensor for processing + shape = ( + -1, + self.in_channels, + self.temporal_patch_size, + self.spatial_patch_size, + self.spatial_patch_size, + ) + pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype) + hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim) + # TODO: revisit to see if we can avoid some of these reshapes + + # find the position ids for the input tensor based on the grid_thw + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + + # apply the positional embeddings to the position ids + seq = torch.arange( + max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype) + + # create a cu_seqlens tensor to be used in the attention mask + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # iterately apply the blocks to the hidden states + for block in self.blocks: + hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb) + + # apply the final patch merger to the hidden states + hidden_states = self.merger(hidden_states, grid_thw) + return hidden_states + + +class Qwen2VLForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = None + config.vision_config.speculator = config.speculator + self.hidden_size = config.hidden_size + self.vision_start_token_id = config.vision_start_token_id + self.image_token_id = config.image_token_id + self.video_token_id = config.video_token_id + self.spatial_merge_size = config.vision_config.spatial_merge_size + + self.visual = Qwen2VisionModel( + prefix="visual", config=config.vision_config, weights=weights + ) + self.text_model = Qwen2Model(prefix=None, config=config, weights=weights) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor], + pixel_values: torch.FloatTensor = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + pixel_attention_mask=None, + image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + image_indices=None, + ): + + # make an attention_mask that is the same size as the input_ids + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + + inputs_embeds = self.text_model.embed_tokens(input_ids) + + # apply the visual model to the pixel values if they are provided + if pixel_values is not None and len(pixel_values) > 0: + if pixel_values is not None: + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + image_mask = ( + (input_ids == self.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + # input embeddings are masked with image embeddings + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + # handle the position_ids taking the multimodal inputs into account + mrope_position_deltas = [] + if image_grid_thw is not None or video_grid_thw is not None: + total_input_ids = input_ids + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + + for i, input_ids in enumerate(total_input_ids): + if attention_mask is not None: + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere( + input_ids == self.vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + # determine the number of images and videos in the input + image_nums = (vision_tokens == self.image_token_id).sum() + video_nums = (vision_tokens == self.video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + # process each input based on it's token type and grid size + for _ in range(image_nums + video_nums): + if self.image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(self.image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if self.video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(self.video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // self.spatial_merge_size, + w.item() // self.spatial_merge_size, + ) + text_len = ed - st + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + + # TODO: adjust model to accept 2D position_ids + outputs = self.text_model( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + true_max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + ) + + return outputs, None diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 4bbddcfb..1b8e7f88 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -67,6 +67,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str elif config.model_type == "paligemma": return "" * config.text_config.num_image_tokens + elif config.model_type == "qwen2_vl": + return "" else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -137,6 +139,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch): pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] image_sizes: Optional[List[Tuple[int, int]]] + image_grid_thw: Optional[torch.Tensor] @classmethod @tracer.start_as_current_span("concatenate") @@ -145,6 +148,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch): batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None + batch.image_grid_thw = None return batch @tracer.start_as_current_span("filter") @@ -153,6 +157,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch): batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None + batch.image_grid_thw = None return batch @classmethod @@ -178,6 +183,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch): raise RuntimeError(f"Invalid chunk type {chunk_type}") if images: + # TODO: REMOVE (this is for debugging purposes) + images = images[0][0].resize( + (images[0][0].width * 2, images[0][0].height * 2) + ) image_inputs = processor.image_processor(images, return_tensors="pt") else: image_inputs = None @@ -237,10 +246,15 @@ class VlmCausalLMBatch(FlashCausalLMBatch): batch.image_sizes = image_inputs["image_sizes"].to(device=device) else: batch.image_sizes = None + if "image_grid_thw" in image_inputs: + batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device) + else: + batch.image_grid_thw = None else: batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None + batch.image_grid_thw = None return batch @@ -381,8 +395,9 @@ class VlmCausalLM(FlashCausalLM): max_k=batch.max_current_length, ) logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, + # TODO: remove the unsqueeze(0) + input_ids=input_ids.unsqueeze(0), + position_ids=position_ids.unsqueeze(0), cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, block_tables=block_tables, @@ -394,6 +409,7 @@ class VlmCausalLM(FlashCausalLM): pixel_values=batch.pixel_values, pixel_attention_mask=batch.pixel_attention_mask, image_sizes=batch.image_sizes, + image_grid_thw=batch.image_grid_thw, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None @@ -403,6 +419,8 @@ class VlmCausalLM(FlashCausalLM): batch.pixel_attention_mask = None if batch.image_sizes is not None: batch.image_sizes = None + if batch.image_grid_thw is not None: + batch.image_grid_thw = None return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph