multi-modality initial PR

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi A 2025-03-19 23:27:27 -07:00
parent d5b78ba16f
commit f95aa42660
8 changed files with 1829 additions and 47 deletions

View File

@ -19,18 +19,10 @@ from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import ( from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig, PhiMoEConfig,
) )
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
from text_generation_server.utils.adapter import ( from text_generation_server.utils.adapter import (
AdapterParameters, AdapterParameters,
build_layer_weight_lookup, build_layer_weight_lookup,
@ -58,8 +50,8 @@ if ATTENTION == "paged":
try: try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.vlm_causal_lm import VlmCausalLM from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
FlashDeepseekV2ForCausalLM, FlashDeepseekV2ForCausalLM,
DeepseekV2Config, DeepseekV2Config,
@ -101,12 +93,12 @@ try:
FlashPhiForCausalLM, FlashPhiForCausalLM,
) )
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
from text_generation_server.models.custom_modeling.mllama import ( from text_generation_server.models.custom_modeling.flash_mllama import (
MllamaForConditionalGeneration, FlashMllamaForConditionalGeneration,
) )
from text_generation_server.models.custom_modeling.llava_next import ( from text_generation_server.models.custom_modeling.flash_llava_next import (
LlavaNextForConditionalGeneration, FlashLlavaNextForConditionalGeneration,
) )
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
@ -751,7 +743,7 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif model_type == QWEN2_VL: elif model_type == QWEN2_VL:
return VlmCausalLM( return FlashVlmCausalLM(
model_id=model_id, model_id=model_id,
model_class=Qwen2VLForConditionalGeneration, model_class=Qwen2VLForConditionalGeneration,
revision=revision, revision=revision,
@ -764,7 +756,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
elif model_type == QWEN2_5_VL: elif model_type == QWEN2_5_VL:
return VlmCausalLM( return FlashVlmCausalLM(
model_id=model_id, model_id=model_id,
model_class=Qwen2_5VLForConditionalGeneration, model_class=Qwen2_5VLForConditionalGeneration,
revision=revision, revision=revision,
@ -779,10 +771,10 @@ def get_model(
processor_class=Qwen2_5_VLProcessor, processor_class=Qwen2_5_VLProcessor,
) )
elif model_type == MLLAMA: elif model_type == MLLAMA:
return MllamaCausalLM( return FlashMllamaCausalLM(
model_id=model_id, model_id=model_id,
model_class=MllamaForConditionalGeneration, model_class=FlashMllamaForConditionalGeneration,
batch_class=MllamaCausalLMBatch, batch_class=FlashMllamaCausalLMBatch,
revision=revision, revision=revision,
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
@ -792,7 +784,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
elif model_type == IDEFICS2: elif model_type == IDEFICS2:
return VlmCausalLM( return FlashVlmCausalLM(
model_id=model_id, model_id=model_id,
model_class=Idefics2ForConditionalGeneration, model_class=Idefics2ForConditionalGeneration,
revision=revision, revision=revision,
@ -807,7 +799,7 @@ def get_model(
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
) )
elif model_type == IDEFICS3: elif model_type == IDEFICS3:
return VlmCausalLM( return FlashVlmCausalLM(
model_id=model_id, model_id=model_id,
model_class=Idefics3ForConditionalGeneration, model_class=Idefics3ForConditionalGeneration,
revision=revision, revision=revision,
@ -822,7 +814,7 @@ def get_model(
processor_kwargs={"size": {"longest_edge": 1456}}, processor_kwargs={"size": {"longest_edge": 1456}},
) )
elif model_type == PALIGEMMA: elif model_type == PALIGEMMA:
return VlmCausalLM( return FlashVlmCausalLM(
model_id=model_id, model_id=model_id,
model_class=PaliGemmaForConditionalGeneration, model_class=PaliGemmaForConditionalGeneration,
revision=revision, revision=revision,
@ -837,8 +829,8 @@ def get_model(
batch_class=PaliGemmaBatch, batch_class=PaliGemmaBatch,
) )
elif model_type == LLAVA_NEXT: elif model_type == LLAVA_NEXT:
return VlmCausalLM( return FlashVlmCausalLM(
model_class=LlavaNextForConditionalGeneration, model_class=FlashLlavaNextForConditionalGeneration,
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
quantize=quantize, quantize=quantize,
@ -847,6 +839,15 @@ def get_model(
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
adapt_transformers_to_gaudi() adapt_transformers_to_gaudi()
if SDP_ON_BF16 == 1: if SDP_ON_BF16 == 1:
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)

View File

@ -503,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module):
# Skip first and last layers # Skip first and last layers
for layer_id in range(1, config.num_hidden_layers - 1): for layer_id in range(1, config.num_hidden_layers - 1):
if layer_id in self.cross_attention_layers: if layer_id in self.cross_attention_layers:
from text_generation_server.models.custom_modeling.mllama import ( from text_generation_server.models.custom_modeling.flash_mllama import (
FlashLlamaCrossLayer, FlashLlamaCrossLayer,
) )

View File

@ -0,0 +1,290 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava-NeXT model."""
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (height, width).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (`torch.Tensor`):
The image tensor, assumed to be of shape (num_channels, height, width).
original_size (`tuple`):
The original size of the image (height, width).
Returns:
`torch.Tensor`: The unpadded image tensor.
"""
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.linear_1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class FlashLlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
vision_config = config.vision_config
# Instead of selecting in hidden_states[-2].
# Instead compute only the n -2 + 1 layers and don't pool
if config.vision_feature_layer < 0:
vision_config.num_hidden_layers += config.vision_feature_layer + 1
else:
vision_config.num_hidden_layers = config.vision_feature_layer + 1
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.multi_modal_projector = LlavaNextMultiModalProjector(
prefix="multi_modal_projector", config=config, weights=weights
)
self.image_newline = weights.get_tensor("image_newline")
self.vocab_size = config.text_config.vocab_size
self.config = config
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots !
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
)
image_features = self.vision_tower(pixel_values)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -0,0 +1,996 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
import torch.nn.functional as F
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
FastLinear,
)
from text_generation_server.layers.attention import (
Seqlen,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
def _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask: torch.Tensor,
num_patches: int,
target_length: int,
dtype: torch.dtype,
) -> torch.Tensor:
# Expand aspect ratio mask to target_length
batch_size, max_num_tiles = aspect_ratio_mask.shape
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
attention_mask = attention_mask.repeat(1, 1, target_length, 1)
# Mask padding patches
pad_patches = target_length - num_patches
attention_mask[:, :, -pad_patches:] = 0
# Invert the mask (0 -> 1, 1 -> 0)
attention_mask = 1 - attention_mask
# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
attention_mask = attention_mask.reshape(
batch_size, max_num_tiles * target_length, 1
)
attention_mask = (
attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
)
attention_mask = attention_mask.unsqueeze(1)
return attention_mask
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = (
causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
return causal_mask
def _prepare_cross_attention_mask(
cross_attention_mask: torch.Tensor,
num_vision_tokens: int,
dtype: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
# reshape so it can be used by attn module
batch_size, text_total_length, *_ = cross_attention_mask.shape
cross_attention_mask = cross_attention_mask.repeat_interleave(
num_vision_tokens, dim=3
)
cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
cross_attention_mask = cross_attention_mask.unsqueeze(1)
# invert the mask
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
)
# apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
# last dimension contains negative infinity values, otherwise it's 1
negative_inf_value = torch.finfo(dtype).min
full_text_row_masked_out_mask = (
(cross_attention_mask != negative_inf_value)
.any(dim=-1)
.type_as(cross_attention_mask)[..., None]
)
cross_attention_mask *= full_text_row_masked_out_mask
return cross_attention_mask, full_text_row_masked_out_mask
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class MllamaVisionSdpaAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.head_dim = config.hidden_size // config.attention_heads
self.num_heads = config.attention_heads // weights.process_group.size()
self.qkv_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_state)
query, key, value = qkv.split(
[
self.head_dim * self.num_heads,
self.head_dim * self.num_heads,
self.head_dim * self.num_heads,
],
dim=2,
)
batch_size, q_seq_len, _ = query.shape
_, kv_seq_len, _ = key.shape
query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
output = self.o_proj(attn_output)
return output
class MllamaVisionEncoderLayer(nn.Module):
def __init__(self, *, prefix, config, weights, is_gated: bool):
super().__init__()
self.hidden_size = config.hidden_size
self.num_attention_heads = config.attention_heads
self.is_gated = is_gated
self.intermediate_size = config.intermediate_size
self.self_attn = MllamaVisionSdpaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = MllamaVisionMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights
)
self.input_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
)
# there used to be an if else here, no code path
if is_gated:
self.gate_attn = nn.Parameter(
weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
)
self.gate_ffn = nn.Parameter(
weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
# Self Attention
residual = hidden_state
hidden_state = self.input_layernorm(hidden_state)
hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
hidden_state = residual + gate_attn * hidden_state
# Feed forward
residual = hidden_state
hidden_state = self.post_attention_layernorm(hidden_state)
hidden_state = self.mlp(hidden_state)
gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
hidden_state = residual + gate_ffn * hidden_state
return hidden_state
class MllamaVisionEncoder(nn.Module):
def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
super().__init__()
self.config = config
self.layers = [
MllamaVisionEncoderLayer(
prefix=f"{prefix}.layers.{i}",
config=config,
weights=weights,
is_gated=is_gated,
)
for i in range(num_layers)
]
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
encoder_states = [hidden_states]
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
)
hidden_states = layer_outputs
encoder_states.append(hidden_states)
return hidden_states, encoder_states
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.embedding = TensorParallelEmbedding(
prefix=f"{prefix}.embedding", weights=weights
)
self.gate = nn.Parameter(
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
)
def forward(
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
) -> torch.Tensor:
embeddings = self.embedding(aspect_ratio_ids)
embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
# Always gated.
embeddings = embeddings * self.gate.tanh()
hidden_state = hidden_state + embeddings
return hidden_state
class MllamaPrecomputedPositionEmbedding(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
self.hidden_size = config.hidden_size
self.scale = config.hidden_size**-0.5
self.gate = nn.Parameter(
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
)
# position embedding
embedding = nn.Parameter(
weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
)
self.gated_position_embedding = (1 - self.gate.tanh()) * embedding
self.tile_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.tile_embedding", weights=weights
)
def forward(
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
) -> torch.Tensor:
# position embeddings
hidden_state = hidden_state + self.gated_position_embedding.view(
1, 1, self.num_patches, self.hidden_size
)
# precomputed tile position embeddings
tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
batch_size = hidden_state.shape[0]
tile_position_embedding = tile_position_embedding.reshape(
batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
)
gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
hidden_state = hidden_state + gated_tile_position_embedding
return hidden_state
class MllamaVisionModel(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.image_size = config.image_size
self.patch_size = config.patch_size
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.num_channels = config.num_channels
self.intermediate_layers_indices = config.intermediate_layers_indices
self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
self.scale = config.hidden_size**-0.5
self.dtype = weights.dtype
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.hidden_size,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
bias=False,
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.class_embedding = nn.Parameter(
weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
)
self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
prefix=f"{prefix}.gated_positional_embedding",
config=config,
weights=weights,
)
self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
prefix=f"{prefix}.pre_tile_positional_embedding",
config=config,
weights=weights,
)
self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
prefix=f"{prefix}.post_tile_positional_embedding",
config=config,
weights=weights,
)
## layer norms
self.layernorm_pre = nn.LayerNorm.load(
prefix=f"{prefix}.layernorm_pre",
weights=weights,
# torch default
eps=1e-05,
)
self.layernorm_post = nn.LayerNorm.load(
prefix=f"{prefix}.layernorm_post",
weights=weights,
# torch default
eps=1e-05,
)
## encoders
self.transformer = MllamaVisionEncoder(
prefix=f"{prefix}.transformer",
config=config,
weights=weights,
is_gated=False,
num_layers=config.num_hidden_layers,
)
self.global_transformer = MllamaVisionEncoder(
prefix=f"{prefix}.global_transformer",
config=config,
weights=weights,
is_gated=True,
num_layers=config.num_global_layers,
)
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def forward(
self,
pixel_values: torch.Tensor,
aspect_ratio_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
(
batch_size,
num_concurrent_media,
num_tiles,
num_channels,
height,
width,
) = pixel_values.shape
pixel_values = pixel_values.reshape(
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
)
aspect_ratio_ids = aspect_ratio_ids.reshape(
batch_size * num_concurrent_media, -1
)
# patch embedding
patch_embeds = self.patch_embedding(pixel_values)
hidden_state = patch_embeds.flatten(2).transpose(1, 2)
# tile embeddings
_, num_patches, dim = hidden_state.shape
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media, num_tiles, -1, dim
)
hidden_state = self.pre_tile_positional_embedding(
hidden_state, aspect_ratio_ids
)
# apply cls token
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media * num_tiles, num_patches, dim
)
hidden_state = self.apply_class_embedding(hidden_state)
num_patches += 1
# apply position embeddings
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media, num_tiles, num_patches, dim
)
hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
# apply encoder
hidden_state = self.layernorm_pre(hidden_state)
# Compute the number of tokens to pad
num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
# Compute padding tuple for pad function
padding = (
0,
0,
0,
num_padding_patches,
) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
# Pad the tensor
hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
slice_index = -num_padding_patches if num_padding_patches > 0 else None
if attention_mask is not None:
attention_mask = attention_mask.reshape(
batch_size * num_concurrent_media, -1
)
attention_mask = _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask=attention_mask,
num_patches=self.num_patches,
target_length=hidden_state.shape[2],
dtype=self.dtype,
)
hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
hidden_state, all_intermediate_hidden_states = self.transformer(
hidden_state,
attention_mask=attention_mask,
)
intermediate_hidden_states = [
hidden_state
for idx, hidden_state in enumerate(all_intermediate_hidden_states)
if idx in self.intermediate_layers_indices
]
intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)
# apply global encoder
hidden_state = self.layernorm_post(hidden_state)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim,
)
hidden_state = self.post_tile_positional_embedding(
hidden_state, aspect_ratio_ids
)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles * (num_patches + num_padding_patches),
dim,
)
hidden_state, _ = self.global_transformer(
hidden_state, attention_mask=attention_mask
)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim,
)
hidden_state = hidden_state[:, :, :slice_index]
# adding intermediate layer outputs
hidden_state = hidden_state.reshape(
batch_size, num_concurrent_media, num_tiles, num_patches, dim
)
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
-1,
)
intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size, num_concurrent_media, num_tiles, num_patches, -1
)
hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
return hidden_state
class MllamaTextCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, *, prefix, config, weights, layer_idx):
super().__init__()
self.config = config
self.num_heads = self.config.num_attention_heads
self.num_key_value_heads = self.config.num_key_value_heads
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.head_size = config.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.layer_idx = layer_idx
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
self.num_key_value_heads // weights.process_group.size()
)
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=False,
)
self.k_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.k_proj",
weights=weights,
bias=False,
)
self.v_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.v_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.q_norm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
)
self.k_norm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
)
self.softmax_scale = self.head_size**-0.5
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
# past_key_value=None,
# attention_mask: Optional[torch.Tensor] = None,
# cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# hidden_states = hidden_states.unsqueeze(0)
# bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(-1, self.num_heads, self.head_size)
query_states = self.q_norm(query_states)
(
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
indices,
) = cross_attention_states
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
key_states = self.k_norm(key_states)
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
causal = False
# logger.info(
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
# )
# execute sdpa
query_states = query_states.unsqueeze(0).transpose(1, 2)
key_states = key_states.unsqueeze(0).transpose(1, 2)
value_states = value_states.unsqueeze(0).transpose(1, 2)
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attn_output = fsdpa_op(
query_states,
key_states,
value_states,
attn_mask=None,
dropout_p=0.0,
is_causal=causal,
scale=None,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=None,
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
return attn_output
# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
shape = x.shape
gate_up_states = self.gate_up_proj(x)
gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
result = self.down_proj(
self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
)
return result
class FlashLlamaCrossLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention and feedforward."""
def __init__(self, *, prefix, config, weights, index) -> None:
layer_idx = index
super().__init__()
self.cross_attn = MllamaTextCrossAttention(
prefix=f"{prefix}.cross_attn",
config=config,
weights=weights,
layer_idx=layer_idx,
)
self.input_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.cross_attn_attn_gate = torch.nn.Parameter(
weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
)
self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.post_attention_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.cross_attn_mlp_gate = torch.nn.Parameter(
weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
)
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
cross_attention_states, # [ IB, ...]
) -> Tuple[torch.Tensor, torch.Tensor]:
if cross_attention_states is None:
return hidden_states, residual
if residual is not None:
hidden_states += residual
indices = cross_attention_states[-1]
out_hidden_states = hidden_states[:]
if len(indices) > 0:
assert max(indices) < hidden_states.shape[0]
hidden_states = hidden_states[indices]
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.cross_attn(
hidden_states=hidden_states,
# attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
)
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
out_hidden_states[indices] = hidden_states
hidden_states = out_hidden_states
return hidden_states, None
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class MllamaTextRMSNorm(nn.Module):
def __init__(self, weight, eps):
super().__init__()
self.weight = weight
self.variance_epsilon = eps
@classmethod
def load(cls, *, prefix, weights, eps):
weight = nn.Parameter(
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
)
return cls(weight=weight, eps=eps)
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class FlashMllamaForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
config.text_config._attn_implementation = "sdpa"
self.hidden_size = config.text_config.hidden_size
self.vision_model = MllamaVisionModel(
prefix="vision_model", config=config.vision_config, weights=weights
)
self.multi_modal_projector = FastLinear.load(
prefix="multi_modal_projector", config=config, weights=weights, bias=True
)
self.text_model = FlashLlamaForCausalLM(
prefix="language_model", config=config.text_config, weights=weights
)
self.config = config
self.dtype = weights.dtype
self.device = weights.device
def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
if aspect_ratio_ids is None:
raise ValueError(
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# logger.info(f"PIxel values {pixel_values.shape}")
batch_size = pixel_values.shape[0]
vision_states = self.vision_model(
pixel_values, aspect_ratio_ids, aspect_ratio_mask
)
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
-1, vision_states.shape[-2], self.hidden_size
)
_, _, h = cross_attention_states.shape
cross_attention_states = cross_attention_states.view(batch_size, -1, h)
# logger.info(f"cross {cross_attention_states.shape}")
return cross_attention_states
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
# XXX: Putting these as optional so that the cuda warmup calls can go through.
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
if cross_attention_states is not None:
seqlen_q = len(image_indices)
n_images = cross_attention_states.shape[0]
seqlen_k = cross_attention_states.shape[1]
device = cross_attention_states.device
if cu_seqlen_prefill is not None:
offset = 0
cu_q = []
indices = []
for index in image_indices:
cu_q.append(offset)
length = seqlen.input_lengths[index].item()
assert index < seqlen.cu_seqlen_q.shape[0]
input_ids_offset = seqlen.cu_seqlen_q[index]
indices.extend(range(input_ids_offset, input_ids_offset + length))
offset += length
cu_q.append(offset)
cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
assert max(indices) < input_ids.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
max_q = cu_seqlen_q[-1].item()
max_k = seqlen_k
else:
cu_seqlen_q = torch.arange(
seqlen_q + 1, device=device, dtype=torch.int32
)
seqlen_k = cross_attention_states.shape[1]
n_images = cross_attention_states.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
max_q = seqlen_q
max_k = seqlen_k
indices = image_indices[:]
cross_attention_states = (
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
indices,
)
outputs = self.text_model(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
cross_attention_states=cross_attention_states,
)
return outputs

View File

@ -0,0 +1,480 @@
import torch
from PIL import Image
from io import BytesIO
from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
)
from text_generation_server.models.globals import PREFIX_CACHING
from loguru import logger
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen
tracer = trace.get_tracer(__name__)
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
IDEFICS2_IMAGE_TOKEN = "<image>"
IDEFICS3_IMAGE_TOKEN = "<image>"
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
def _prompt_split_image(
*,
image_seq_len: int,
image_rows: int,
image_cols: int,
fake_token_around_image: str,
image_token: str,
global_img_token: str,
):
"""Prompt with expanded image tokens for when the image is split into patches."""
text_split_images = ""
for n_h in range(image_rows):
for n_w in range(image_cols):
text_split_images += (
f"{fake_token_around_image}"
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
+ f"{image_token}" * image_seq_len
)
text_split_images += "\n"
text_split_images += (
f"\n{fake_token_around_image}"
+ f"{global_img_token}"
+ f"{image_token}" * image_seq_len
+ f"{fake_token_around_image}"
)
return text_split_images
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
def image_text_replacement(processor, image_input, config, image_id: int) -> str:
if config.model_type == "idefics2":
image_seq_len = 64
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
if processor.image_processor.do_image_splitting:
image_str *= 5
return image_str
if config.model_type == "idefics3":
# TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id]
n_cols = image_input["cols"][0][image_id]
image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
/ (config.scale_factor**2)
)
image_str = _prompt_split_image(
image_seq_len=image_seq_len,
image_rows=n_rows,
image_cols=n_cols,
fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
image_token=IDEFICS3_IMAGE_TOKEN,
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
)
return image_str
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config)
log_master(
logger.info,
f"Found {num_features} features in image of resolution {height}x{width}",
)
return "<image>" * num_features
elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens
elif config.model_type == "qwen2_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
elif config.model_type == "qwen2_5_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
elif config.model_type == "gemma3":
# TODO: get correct number of features via reviewing the Gemma3 architecture
# and calculating the number of image tokens
num_pads = 256
padding = "<image_soft_token>" * num_pads
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def image_text_replacement_fixup(config, text: str) -> str:
if config.model_type == "idefics2":
return text.replace(
f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
)
return text
def get_unpadded_features(
original_height: int,
original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
aspect_ratio: float = original_width / original_height
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (original_height * current_width) // original_width
padding = (current_height - new_height) // 2
current_height = current_height - (2 * padding)
else:
new_width = (original_width * current_height) // original_height
padding = (current_width - new_width) // 2
current_width = current_width - (2 * padding)
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
def get_number_of_features(height: int, width: int, config) -> int:
# From config
# Hardcoded for CLIP for now
# image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
image_grid_pinpoints = config.image_grid_pinpoints
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
assert image_size % patch_size == 0
npatches = image_size // patch_size
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
[height, width],
image_grid_pinpoints,
image_size,
)
unpadded_features, newline_features = get_unpadded_features(
height, width, npatches, num_patch_height, num_patch_width
)
# The base patch covers the entire image
base_features = npatches**2
return unpadded_features + newline_features + base_features
class FlashVlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
image_grid_thw: Optional[torch.Tensor]
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
# Process images first. We need all of them so that the processor
# can make the image splits the same size. And we need the final
# sizes to insert correct number of image tokens.
images = []
for r in requests:
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
pass
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
# default warmup image is 20x20
if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if image.width <= 20:
w = image.width * 2
h = image.height * 2
image = image.resize((w, h))
if config.model_type == "llava_next":
images.append(image)
elif config.model_type == "gemma3":
images.append(image)
else:
images.append([image])
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
if images:
kwargs = {}
if (
hasattr(processor, "image_processor_class")
and processor.image_processor_class == "Idefics3ImageProcessor"
):
kwargs["return_row_col_info"] = True
image_inputs = processor.image_processor(
images, return_tensors="pt", **kwargs
)
else:
image_inputs = None
batch_tokenized_inputs = []
max_length = 0
image_id = 0
for r in requests:
full_text = ""
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += chunk.text
elif chunk_type == "image":
full_text += image_text_replacement(
processor, image_inputs, config, image_id
)
image_id += 1
full_text = image_text_replacement_fixup(config, full_text)
input_ids = tokenizer(
full_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"]
max_length = max(max_length, len(input_ids))
batch_tokenized_inputs.append(input_ids)
return batch_tokenized_inputs, image_inputs
@classmethod
def from_pb_processor(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor,
config,
dtype: torch.dtype,
device: torch.device,
) -> "FlashVlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
if "image_grid_thw" in image_inputs:
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else:
batch.image_grid_thw = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
class FlashVlmCausalLM(FlashCausalLM):
def __init__(
self,
model_id: str,
*,
processor_class=AutoProcessor,
processor_kwargs=None,
batch_class=FlashVlmCausalLMBatch,
revision,
trust_remote_code: bool,
**kwargs,
):
if PREFIX_CACHING:
raise NotImplementedError("Vlm do not work with prefix caching yet")
if processor_kwargs is None:
processor_kwargs = {}
self.processor = processor_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
**processor_kwargs,
)
self.batch_class = batch_class
super().__init__(
model_id=model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# FIXME: VLM do not work with context chunking yet
support_chunking=False,
**kwargs,
)
@property
def batch_type(self) -> Type[FlashVlmCausalLMBatch]:
return self.batch_class
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
def forward(
self,
batch: FlashVlmCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
cache_lengths_tensor = (
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Add Copy the block tables for all members
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids(
input_ids, batch.image_grid_thw
)
batch.position_ids = position_ids
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
return logits, speculative_logits

View File

@ -1,15 +1,21 @@
from io import BytesIO
from PIL import Image
import torch import torch
import numpy as np
from typing import Iterable, Optional, Tuple, List, Dict from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request from text_generation_server.pb.generate_pb2 import Request
from io import BytesIO
from PIL import Image
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import ( from transformers import (
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
FlashVlmCausalLM,
)
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
@ -18,7 +24,7 @@ tracer = trace.get_tracer(__name__)
@dataclass @dataclass
class MllamaCausalLMBatch(VlmCausalLMBatch): class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
image_indices: List[int] = 42 image_indices: List[int] = 42
aspect_ratio_ids: Optional[torch.Tensor] = None aspect_ratio_ids: Optional[torch.Tensor] = None
aspect_ratio_mask: Optional[torch.Tensor] = None aspect_ratio_mask: Optional[torch.Tensor] = None
@ -154,7 +160,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
config, config,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "VlmCausalLMBatch": ) -> "FlashVlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config pb.requests, tokenizer, processor, config
) )
@ -163,6 +169,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
max=config.text_config.vocab_size - 1 max=config.text_config.vocab_size - 1
) )
if isinstance(batch.input_ids, list):
if len(batch) > 1:
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else:
input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
if image_inputs is not None: if image_inputs is not None:
@ -183,10 +196,10 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
return batch return batch
class MllamaCausalLM(VlmCausalLM): class FlashMllamaCausalLM(FlashVlmCausalLM):
def forward( def forward(
self, self,
batch: VlmCausalLMBatch, batch: FlashMllamaCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None, adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward # Model Forward
@ -198,7 +211,7 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids speculative_ids = batch.speculative_ids
@ -217,8 +230,8 @@ class MllamaCausalLM(VlmCausalLM):
input_lengths = ( input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1) ).view(-1)
prefix_lens_tensor = ( cache_lengths_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1) ).reshape(-1)
# Add Copy the block tables for all members # Add Copy the block tables for all members
@ -240,8 +253,8 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None: if cu_seqlen_prefill is None and self.max_past() is not None:
@ -250,14 +263,10 @@ class MllamaCausalLM(VlmCausalLM):
# This makes sure the max_s for the decode pass is correct. # This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s) max_s = min(self.max_past(), max_s)
input_lengths = input_lengths + prefix_lens_tensor
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen( seqlen = Seqlen(
input_lengths=input_lengths, input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor, cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill, cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
) )
if batch.pixel_values is not None: if batch.pixel_values is not None:

View File

@ -4,8 +4,8 @@ import torch
import torch.distributed import torch.distributed
from opentelemetry import trace from opentelemetry import trace
from typing import Iterable from typing import Iterable
from text_generation_server.models.vlm_causal_lm import ( from text_generation_server.models.flash_vlm_causal_lm import (
VlmCausalLMBatch, FlashVlmCausalLMBatch,
image_text_replacement, image_text_replacement,
) )
@ -14,7 +14,7 @@ from text_generation_server.pb.generate_pb2 import Request
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
class PaliGemmaBatch(VlmCausalLMBatch): class PaliGemmaBatch(FlashVlmCausalLMBatch):
@classmethod @classmethod
def batch_tokenized_inputs( def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config cls, requests: Iterable[Request], tokenizer, processor, config

View File

@ -25,15 +25,21 @@ from text_generation_server.utils.tokens import make_tokenizer_optional
try: try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
from text_generation_server.models.vlm_causal_lm import ( from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch, VlmCausalLMBatch,
) )
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
)
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
VLM_BATCH_TYPES = { VLM_BATCH_TYPES = {
PaliGemmaBatch, PaliGemmaBatch,
VlmCausalLMBatch, VlmCausalLMBatch,
FlashVlmCausalLMBatch,
IdeficsCausalLMBatch, IdeficsCausalLMBatch,
FlashMllamaCausalLMBatch,
} }
except (ImportError, NotImplementedError): except (ImportError, NotImplementedError):
# These imports can fail on CPU/Non flash. # These imports can fail on CPU/Non flash.