Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 23:12:07 +00:00

commit f95aa42660 (parent d5b78ba16f)

multi-modality initial PR

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@@ -19,18 +19,10 @@ from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.mllama import (
    MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
    LlavaNextForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
    PhiMoEConfig,
)

# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
@@ -58,8 +50,8 @@ if ATTENTION == "paged":

try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
    from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
@@ -101,12 +93,12 @@ try:
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.mllama import (
        MllamaForConditionalGeneration,
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.flash_mllama import (
        FlashMllamaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.llava_next import (
        LlavaNextForConditionalGeneration,
    from text_generation_server.models.custom_modeling.flash_llava_next import (
        FlashLlavaNextForConditionalGeneration,
    )

    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
@@ -751,7 +743,7 @@ def get_model(
            trust_remote_code=trust_remote_code,
        )
    elif model_type == QWEN2_VL:
        return VlmCausalLM(
        return FlashVlmCausalLM(
            model_id=model_id,
            model_class=Qwen2VLForConditionalGeneration,
            revision=revision,
@@ -764,7 +756,7 @@ def get_model(
            lora_adapter_ids=lora_adapter_ids,
        )
    elif model_type == QWEN2_5_VL:
        return VlmCausalLM(
        return FlashVlmCausalLM(
            model_id=model_id,
            model_class=Qwen2_5VLForConditionalGeneration,
            revision=revision,
@@ -779,10 +771,10 @@ def get_model(
            processor_class=Qwen2_5_VLProcessor,
        )
    elif model_type == MLLAMA:
        return MllamaCausalLM(
        return FlashMllamaCausalLM(
            model_id=model_id,
            model_class=MllamaForConditionalGeneration,
            batch_class=MllamaCausalLMBatch,
            model_class=FlashMllamaForConditionalGeneration,
            batch_class=FlashMllamaCausalLMBatch,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
@@ -792,7 +784,7 @@ def get_model(
            lora_adapter_ids=lora_adapter_ids,
        )
    elif model_type == IDEFICS2:
        return VlmCausalLM(
        return FlashVlmCausalLM(
            model_id=model_id,
            model_class=Idefics2ForConditionalGeneration,
            revision=revision,
@@ -807,7 +799,7 @@ def get_model(
            processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
        )
    elif model_type == IDEFICS3:
        return VlmCausalLM(
        return FlashVlmCausalLM(
            model_id=model_id,
            model_class=Idefics3ForConditionalGeneration,
            revision=revision,
@@ -822,7 +814,7 @@ def get_model(
            processor_kwargs={"size": {"longest_edge": 1456}},
        )
    elif model_type == PALIGEMMA:
        return VlmCausalLM(
        return FlashVlmCausalLM(
            model_id=model_id,
            model_class=PaliGemmaForConditionalGeneration,
            revision=revision,
@@ -837,8 +829,8 @@ def get_model(
            batch_class=PaliGemmaBatch,
        )
    elif model_type == LLAVA_NEXT:
        return VlmCausalLM(
            model_class=LlavaNextForConditionalGeneration,
        return FlashVlmCausalLM(
            model_class=FlashLlavaNextForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
@@ -847,6 +839,15 @@ def get_model(
            kv_cache_dtype=kv_cache_dtype,
            trust_remote_code=trust_remote_code,
        )

from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.mllama import (
    MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
    LlavaNextForConditionalGeneration,
)

adapt_transformers_to_gaudi()
if SDP_ON_BF16 == 1:
    torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
@@ -503,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module):
        # Skip first and last layers
        for layer_id in range(1, config.num_hidden_layers - 1):
            if layer_id in self.cross_attention_layers:
                from text_generation_server.models.custom_modeling.mllama import (
                from text_generation_server.models.custom_modeling.flash_mllama import (
                    FlashLlamaCrossLayer,
                )
@@ -0,0 +1,290 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Llava-NeXT model."""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.models.custom_modeling.vlm import (
|
||||
load_text_model,
|
||||
load_vision_model,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (height, width).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (height, width).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
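# Illustrative walk-through (hypothetical values): with grid_pinpoints = [[336, 672],
# [672, 336], [672, 672]] and patch_size = 336, an image whose best-fit resolution from
# `select_best_resolution` is (672, 672) yields a grid of (672 // 336, 672 // 336) == (2, 2)
# patches.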
|
||||
|
||||
|
||||
def unpad_image(tensor, original_size):
|
||||
"""
|
||||
Unpads a PyTorch tensor of a padded and resized image.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`):
|
||||
The image tensor, assumed to be of shape (num_channels, height, width).
|
||||
original_size (`tuple`):
|
||||
The original size of the image (height, width).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The unpadded image tensor.
|
||||
"""
|
||||
original_height, original_width = original_size
|
||||
current_height, current_width = tensor.shape[1:]
|
||||
|
||||
original_aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
scale_factor = current_width / original_width
|
||||
new_height = int(original_height * scale_factor)
|
||||
padding = (current_height - new_height) // 2
|
||||
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
||||
else:
|
||||
scale_factor = current_height / original_height
|
||||
new_width = int(original_width * scale_factor)
|
||||
padding = (current_width - new_width) // 2
|
||||
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
||||
|
||||
return unpadded_tensor
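# Example (hypothetical shapes): for a feature map of shape (3, 336, 336) coming from an
# original image of size (672, 336) (height, width), the original aspect ratio 0.5 is smaller
# than the current 1.0, so the width was padded: scale_factor = 336 / 672 = 0.5,
# new_width = 168, padding = (336 - 168) // 2 = 84, and the returned tensor has shape
# (3, 336, 168).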
|
||||
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
|
||||
class LlavaNextMultiModalProjector(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashLlavaNextForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
vision_config = config.vision_config
|
||||
# Instead of selecting hidden_states[-2] from a full vision tower pass, build the tower
# with only the first (n - 2 + 1) layers and skip pooling.
|
||||
if config.vision_feature_layer < 0:
|
||||
vision_config.num_hidden_layers += config.vision_feature_layer + 1
|
||||
else:
|
||||
vision_config.num_hidden_layers = config.vision_feature_layer + 1
|
||||
self.vision_tower = load_vision_model(
|
||||
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||
config=config.vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||
prefix="multi_modal_projector", config=config, weights=weights
|
||||
)
|
||||
|
||||
self.image_newline = weights.get_tensor("image_newline")
|
||||
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.config = config
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.speculator = config.speculator
|
||||
self.text_model = load_text_model(
|
||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
)
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_features: torch.Tensor,
|
||||
):
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
mask = input_ids == self.config.image_token_index
|
||||
# Let's pray we have enabled enough slots !
|
||||
try:
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
# Unused for this model
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||
# 1. Extract the input embeddings
|
||||
|
||||
# 2. Merge text and images
|
||||
num_images, num_patches, channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.view(
|
||||
num_images * num_patches, channels, height, width
|
||||
)
|
||||
image_features = self.vision_tower(pixel_values)
|
||||
|
||||
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
|
||||
# Already done within the clip model
|
||||
selected_image_feature = image_features.last_hidden_state
|
||||
|
||||
if self.config.vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif self.config.vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
|
||||
)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
# split up image_features for each of the individual images
|
||||
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
|
||||
# if we assume each image has 5 image features (base image + 4 patches)
|
||||
split_sizes = [num_patches] * num_images
|
||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
height = width = (
|
||||
self.config.vision_config.image_size
|
||||
// self.config.vision_config.patch_size
|
||||
)
|
||||
|
||||
new_image_features = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
raise ValueError(
|
||||
"The number of patches is not consistent with the image size."
|
||||
)
|
||||
|
||||
# Dimensions are intentionally swapped to be bug-compatible with
|
||||
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
image_feature = image_feature.view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
self.image_newline[:, None, None].expand(
|
||||
*image_feature.shape[:-1], 1
|
||||
),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat(
|
||||
(base_image_feature, image_feature), dim=0
|
||||
)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
image_feature = torch.cat(
|
||||
(image_feature, self.image_newline[None]), dim=0
|
||||
)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = torch.stack(new_image_features, dim=0)
|
||||
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_features
|
||||
)
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
@@ -0,0 +1,996 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Mllama model."""
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
import torch.nn.functional as F
|
||||
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
FastLinear,
|
||||
)
|
||||
from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
)
|
||||
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||
|
||||
|
||||
def _prepare_aspect_ratio_attention_mask(
|
||||
aspect_ratio_mask: torch.Tensor,
|
||||
num_patches: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
# Expand aspect ratio mask to target_length
|
||||
batch_size, max_num_tiles = aspect_ratio_mask.shape
|
||||
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
|
||||
attention_mask = attention_mask.repeat(1, 1, target_length, 1)
|
||||
|
||||
# Mask padding patches
|
||||
pad_patches = target_length - num_patches
|
||||
attention_mask[:, :, -pad_patches:] = 0
|
||||
|
||||
# Invert the mask (0 -> 1, 1 -> 0)
|
||||
attention_mask = 1 - attention_mask
|
||||
|
||||
# Reshape to 2D and create 4D attention mask
|
||||
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
|
||||
attention_mask = attention_mask.reshape(
|
||||
batch_size, max_num_tiles * target_length, 1
|
||||
)
|
||||
attention_mask = (
|
||||
attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
|
||||
)
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
return attention_mask
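# Shape sketch (hypothetical sizes): for aspect_ratio_mask of shape (batch_size, 4) with
# max_num_tiles = 4, num_patches = 1025 and target_length = 1032 (1025 padded up to a
# multiple of 8), the expanded mask is (batch_size, 4, 1032, 1); the returned tensor is the
# outer product of the inverted mask scaled by torch.finfo(dtype).min, with shape
# (batch_size, 1, 4 * 1032, 4 * 1032).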
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
min_dtype: float,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to place the 4D attention mask on.
|
||||
min_dtype (`float`):
|
||||
The minimum value representable with the dtype `dtype`.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length),
|
||||
fill_value=min_dtype,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(
|
||||
target_length, device=device
|
||||
) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = (
|
||||
causal_mask.clone()
|
||||
) # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = (
|
||||
causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[
|
||||
:, :, :, :mask_length
|
||||
].masked_fill(padding_mask, min_dtype)
|
||||
|
||||
return causal_mask
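# Worked example (hypothetical values): with sequence_length = 3, target_length = 5,
# cache_position = torch.tensor([0, 1, 2]), batch_size = 1 and no 2D attention_mask, the
# result has shape (1, 1, 3, 5) and entry [0, 0, i, j] is 0 when j <= i (position j is
# visible to query i) and min_dtype when j > i (future positions are masked).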
|
||||
|
||||
|
||||
def _prepare_cross_attention_mask(
|
||||
cross_attention_mask: torch.Tensor,
|
||||
num_vision_tokens: int,
|
||||
dtype: str,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# reshape so it can be used by attn module
|
||||
batch_size, text_total_length, *_ = cross_attention_mask.shape
|
||||
cross_attention_mask = cross_attention_mask.repeat_interleave(
|
||||
num_vision_tokens, dim=3
|
||||
)
|
||||
cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
|
||||
cross_attention_mask = cross_attention_mask.unsqueeze(1)
|
||||
|
||||
# invert the mask
|
||||
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
|
||||
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
|
||||
inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
# Apply full-row bias: this builds a 4D tensor of shape [B, H, S1, 1] whose value is 0 if a
# full row of the cross-attention mask's last dimension contains only negative-infinity
# values, and 1 otherwise.
|
||||
negative_inf_value = torch.finfo(dtype).min
|
||||
full_text_row_masked_out_mask = (
|
||||
(cross_attention_mask != negative_inf_value)
|
||||
.any(dim=-1)
|
||||
.type_as(cross_attention_mask)[..., None]
|
||||
)
|
||||
cross_attention_mask *= full_text_row_masked_out_mask
|
||||
|
||||
return cross_attention_mask, full_text_row_masked_out_mask
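# Shape sketch (hypothetical sizes): for cross_attention_mask of shape (1, 10, 1, 4)
# (batch of 1, 10 text tokens, 1 image, 4 tiles) and num_vision_tokens = 1032, the first
# returned tensor is an additive mask of shape (1, 1, 10, 4 * 1032) and
# full_text_row_masked_out_mask has shape (1, 1, 10, 1), zeroing text rows that attend to
# no vision token at all.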
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
|
||||
class MllamaVisionMLP(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MllamaVisionSdpaAttention(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
self.head_dim = config.hidden_size // config.attention_heads
|
||||
self.num_heads = config.attention_heads // weights.process_group.size()
|
||||
|
||||
self.qkv_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qkv = self.qkv_proj(hidden_state)
|
||||
query, key, value = qkv.split(
|
||||
[
|
||||
self.head_dim * self.num_heads,
|
||||
self.head_dim * self.num_heads,
|
||||
self.head_dim * self.num_heads,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
batch_size, q_seq_len, _ = query.shape
|
||||
_, kv_seq_len, _ = key.shape
|
||||
|
||||
query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
|
||||
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
|
||||
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
|
||||
|
||||
output = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class MllamaVisionEncoderLayer(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights, is_gated: bool):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_attention_heads = config.attention_heads
|
||||
self.is_gated = is_gated
|
||||
self.intermediate_size = config.intermediate_size
|
||||
|
||||
self.self_attn = MllamaVisionSdpaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.mlp = MllamaVisionMLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights
|
||||
)
|
||||
|
||||
self.input_layernorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
|
||||
)
|
||||
|
||||
# there used to be an if else here, no code path
|
||||
if is_gated:
|
||||
self.gate_attn = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
|
||||
)
|
||||
self.gate_ffn = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Self Attention
|
||||
residual = hidden_state
|
||||
hidden_state = self.input_layernorm(hidden_state)
|
||||
hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
|
||||
gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
|
||||
hidden_state = residual + gate_attn * hidden_state
|
||||
|
||||
# Feed forward
|
||||
residual = hidden_state
|
||||
hidden_state = self.post_attention_layernorm(hidden_state)
|
||||
hidden_state = self.mlp(hidden_state)
|
||||
gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
|
||||
hidden_state = residual + gate_ffn * hidden_state
|
||||
return hidden_state
|
||||
|
||||
|
||||
class MllamaVisionEncoder(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = [
|
||||
MllamaVisionEncoderLayer(
|
||||
prefix=f"{prefix}.layers.{i}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
is_gated=is_gated,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
encoder_states = [hidden_states]
|
||||
for encoder_layer in self.layers:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs
|
||||
encoder_states.append(hidden_states)
|
||||
|
||||
return hidden_states, encoder_states
|
||||
|
||||
|
||||
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.max_num_tiles = config.max_num_tiles
|
||||
self.hidden_size = config.hidden_size
|
||||
self.max_aspect_ratio_id = config.max_aspect_ratio_id
|
||||
|
||||
self.embedding = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.embedding", weights=weights
|
||||
)
|
||||
self.gate = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
embeddings = self.embedding(aspect_ratio_ids)
|
||||
embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
|
||||
|
||||
# Always gated.
|
||||
embeddings = embeddings * self.gate.tanh()
|
||||
|
||||
hidden_state = hidden_state + embeddings
|
||||
return hidden_state
|
||||
|
||||
|
||||
class MllamaPrecomputedPositionEmbedding(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.max_num_tiles = config.max_num_tiles
|
||||
self.max_aspect_ratio_id = config.max_aspect_ratio_id
|
||||
self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
|
||||
self.hidden_size = config.hidden_size
|
||||
self.scale = config.hidden_size**-0.5
|
||||
|
||||
self.gate = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
|
||||
)
|
||||
|
||||
# position embedding
|
||||
embedding = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
|
||||
)
|
||||
self.gated_position_embedding = (1 - self.gate.tanh()) * embedding
|
||||
self.tile_embedding = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.tile_embedding", weights=weights
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
# position embeddings
|
||||
hidden_state = hidden_state + self.gated_position_embedding.view(
|
||||
1, 1, self.num_patches, self.hidden_size
|
||||
)
|
||||
|
||||
# precomputed tile position embeddings
|
||||
tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
|
||||
batch_size = hidden_state.shape[0]
|
||||
tile_position_embedding = tile_position_embedding.reshape(
|
||||
batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
|
||||
)
|
||||
gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
|
||||
hidden_state = hidden_state + gated_tile_position_embedding
|
||||
|
||||
return hidden_state
|
||||
|
||||
|
||||
class MllamaVisionModel(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
self.max_num_tiles = config.max_num_tiles
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_channels = config.num_channels
|
||||
self.intermediate_layers_indices = config.intermediate_layers_indices
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
|
||||
self.scale = config.hidden_size**-0.5
|
||||
self.dtype = weights.dtype
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.hidden_size,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
padding="valid",
|
||||
bias=False,
|
||||
)
|
||||
self.patch_embedding.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||
)
|
||||
|
||||
self.class_embedding = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
|
||||
)
|
||||
|
||||
self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
|
||||
prefix=f"{prefix}.gated_positional_embedding",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
|
||||
prefix=f"{prefix}.pre_tile_positional_embedding",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
|
||||
prefix=f"{prefix}.post_tile_positional_embedding",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
## layer norms
|
||||
self.layernorm_pre = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layernorm_pre",
|
||||
weights=weights,
|
||||
# torch default
|
||||
eps=1e-05,
|
||||
)
|
||||
self.layernorm_post = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layernorm_post",
|
||||
weights=weights,
|
||||
# torch default
|
||||
eps=1e-05,
|
||||
)
|
||||
|
||||
## encoders
|
||||
self.transformer = MllamaVisionEncoder(
|
||||
prefix=f"{prefix}.transformer",
|
||||
config=config,
|
||||
weights=weights,
|
||||
is_gated=False,
|
||||
num_layers=config.num_hidden_layers,
|
||||
)
|
||||
self.global_transformer = MllamaVisionEncoder(
|
||||
prefix=f"{prefix}.global_transformer",
|
||||
config=config,
|
||||
weights=weights,
|
||||
is_gated=True,
|
||||
num_layers=config.num_global_layers,
|
||||
)
|
||||
|
||||
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, _, hidden_size = hidden_state.shape
|
||||
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
|
||||
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
|
||||
return hidden_state
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
aspect_ratio_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
(
|
||||
batch_size,
|
||||
num_concurrent_media,
|
||||
num_tiles,
|
||||
num_channels,
|
||||
height,
|
||||
width,
|
||||
) = pixel_values.shape
|
||||
|
||||
pixel_values = pixel_values.reshape(
|
||||
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
|
||||
)
|
||||
aspect_ratio_ids = aspect_ratio_ids.reshape(
|
||||
batch_size * num_concurrent_media, -1
|
||||
)
|
||||
|
||||
# patch embedding
|
||||
patch_embeds = self.patch_embedding(pixel_values)
|
||||
hidden_state = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
# tile embeddings
|
||||
_, num_patches, dim = hidden_state.shape
|
||||
hidden_state = hidden_state.reshape(
|
||||
batch_size * num_concurrent_media, num_tiles, -1, dim
|
||||
)
|
||||
hidden_state = self.pre_tile_positional_embedding(
|
||||
hidden_state, aspect_ratio_ids
|
||||
)
|
||||
|
||||
# apply cls token
|
||||
hidden_state = hidden_state.reshape(
|
||||
batch_size * num_concurrent_media * num_tiles, num_patches, dim
|
||||
)
|
||||
hidden_state = self.apply_class_embedding(hidden_state)
|
||||
num_patches += 1
|
||||
|
||||
# apply position embeddings
|
||||
hidden_state = hidden_state.reshape(
|
||||
batch_size * num_concurrent_media, num_tiles, num_patches, dim
|
||||
)
|
||||
hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
|
||||
|
||||
# apply encoder
|
||||
hidden_state = self.layernorm_pre(hidden_state)
|
||||
|
||||
# Compute the number of tokens to pad
|
||||
num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
|
||||
# Compute padding tuple for pad function
|
||||
padding = (
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
num_padding_patches,
|
||||
) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
|
||||
# Pad the tensor
|
||||
hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
|
||||
slice_index = -num_padding_patches if num_padding_patches > 0 else None
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.reshape(
|
||||
batch_size * num_concurrent_media, -1
|
||||
)
|
||||
attention_mask = _prepare_aspect_ratio_attention_mask(
|
||||
aspect_ratio_mask=attention_mask,
|
||||
num_patches=self.num_patches,
|
||||
target_length=hidden_state.shape[2],
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
|
||||
hidden_state, all_intermediate_hidden_states = self.transformer(
|
||||
hidden_state,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
intermediate_hidden_states = [
|
||||
hidden_state
|
||||
for idx, hidden_state in enumerate(all_intermediate_hidden_states)
|
||||
if idx in self.intermediate_layers_indices
|
||||
]
|
||||
intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)
|
||||
|
||||
# apply global encoder
|
||||
hidden_state = self.layernorm_post(hidden_state)
|
||||
hidden_state = hidden_state.reshape(
|
||||
batch_size * num_concurrent_media,
|
||||
num_tiles,
|
||||
num_patches + num_padding_patches,
|
||||
dim,
|
||||
)
|
||||
hidden_state = self.post_tile_positional_embedding(
|
||||
hidden_state, aspect_ratio_ids
|
||||
)
|
||||
hidden_state = hidden_state.reshape(
|
||||
batch_size * num_concurrent_media,
|
||||
num_tiles * (num_patches + num_padding_patches),
|
||||
dim,
|
||||
)
|
||||
hidden_state, _ = self.global_transformer(
|
||||
hidden_state, attention_mask=attention_mask
|
||||
)
|
||||
hidden_state = hidden_state.reshape(
|
||||
batch_size * num_concurrent_media,
|
||||
num_tiles,
|
||||
num_patches + num_padding_patches,
|
||||
dim,
|
||||
)
|
||||
hidden_state = hidden_state[:, :, :slice_index]
|
||||
|
||||
# adding intermediate layer outputs
|
||||
hidden_state = hidden_state.reshape(
|
||||
batch_size, num_concurrent_media, num_tiles, num_patches, dim
|
||||
)
|
||||
intermediate_hidden_states = intermediate_hidden_states.reshape(
|
||||
batch_size * num_concurrent_media,
|
||||
num_tiles,
|
||||
num_patches + num_padding_patches,
|
||||
-1,
|
||||
)
|
||||
intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
|
||||
intermediate_hidden_states = intermediate_hidden_states.reshape(
|
||||
batch_size, num_concurrent_media, num_tiles, num_patches, -1
|
||||
)
|
||||
hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class MllamaTextCrossAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, *, prefix, config, weights, layer_idx):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_heads = self.config.num_attention_heads
|
||||
self.num_key_value_heads = self.config.num_key_value_heads
|
||||
self.dropout = config.dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = config.hidden_size // self.num_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
self.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.q_proj = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.q_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.k_proj = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.k_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.v_proj = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.v_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.q_norm = MllamaTextRMSNorm.load(
|
||||
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.k_norm = MllamaTextRMSNorm.load(
|
||||
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
# past_key_value=None,
|
||||
# attention_mask: Optional[torch.Tensor] = None,
|
||||
# cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
# hidden_states = hidden_states.unsqueeze(0)
|
||||
# bsz, q_len, _ = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states)
|
||||
query_states = query_states.view(-1, self.num_heads, self.head_size)
|
||||
query_states = self.q_norm(query_states)
|
||||
|
||||
(
|
||||
cross_attention_states,
|
||||
cu_seqlen_q,
|
||||
cu_seqlen_k,
|
||||
max_q,
|
||||
max_k,
|
||||
indices,
|
||||
) = cross_attention_states
|
||||
|
||||
key_states = self.k_proj(cross_attention_states)
|
||||
value_states = self.v_proj(cross_attention_states)
|
||||
key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
|
||||
value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
|
||||
key_states = self.k_norm(key_states)
|
||||
|
||||
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
|
||||
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
|
||||
|
||||
causal = False
|
||||
# logger.info(
|
||||
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
|
||||
# )
|
||||
# execute sdpa
|
||||
query_states = query_states.unsqueeze(0).transpose(1, 2)
|
||||
key_states = key_states.unsqueeze(0).transpose(1, 2)
|
||||
value_states = value_states.unsqueeze(0).transpose(1, 2)
|
||||
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
|
||||
attn_output = fsdpa_op(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=causal,
|
||||
scale=None,
|
||||
softmax_mode="None",
|
||||
recompute_mode=None,
|
||||
valid_sequence_lengths=None,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
|
||||
attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
|
||||
class MllamaTextMLP(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
shape = x.shape
|
||||
gate_up_states = self.gate_up_proj(x)
|
||||
gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
|
||||
result = self.down_proj(
|
||||
self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class FlashLlamaCrossLayer(torch.nn.Module):
|
||||
"""Cross-attention transformer block with tanh-gated attention and feedforward."""
|
||||
|
||||
def __init__(self, *, prefix, config, weights, index) -> None:
|
||||
layer_idx = index
|
||||
super().__init__()
|
||||
self.cross_attn = MllamaTextCrossAttention(
|
||||
prefix=f"{prefix}.cross_attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
self.input_layernorm = MllamaTextRMSNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.cross_attn_attn_gate = torch.nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
|
||||
)
|
||||
|
||||
self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
self.post_attention_layernorm = MllamaTextRMSNorm.load(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.cross_attn_mlp_gate = torch.nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
|
||||
)
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
adapter_data,
|
||||
cross_attention_states, # [ IB, ...]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if cross_attention_states is None:
|
||||
return hidden_states, residual
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
|
||||
indices = cross_attention_states[-1]
|
||||
out_hidden_states = hidden_states[:]
|
||||
if len(indices) > 0:
|
||||
assert max(indices) < hidden_states.shape[0]
|
||||
hidden_states = hidden_states[indices]
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
hidden_states = self.cross_attn(
|
||||
hidden_states=hidden_states,
|
||||
# attention_mask=cross_attention_mask,
|
||||
cross_attention_states=cross_attention_states,
|
||||
)
|
||||
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
|
||||
|
||||
out_hidden_states[indices] = hidden_states
|
||||
hidden_states = out_hidden_states
|
||||
|
||||
return hidden_states, None
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
|
||||
class MllamaTextRMSNorm(nn.Module):
|
||||
def __init__(self, weight, eps):
|
||||
super().__init__()
|
||||
self.weight = weight
|
||||
self.variance_epsilon = eps
|
||||
|
||||
@classmethod
|
||||
def load(cls, *, prefix, weights, eps):
|
||||
weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
|
||||
)
|
||||
return cls(weight=weight, eps=eps)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class FlashMllamaForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = None
|
||||
config.vision_config.speculator = config.speculator
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.speculator = config.speculator
|
||||
config.text_config._attn_implementation = "sdpa"
|
||||
self.hidden_size = config.text_config.hidden_size
|
||||
self.vision_model = MllamaVisionModel(
|
||||
prefix="vision_model", config=config.vision_config, weights=weights
|
||||
)
|
||||
self.multi_modal_projector = FastLinear.load(
|
||||
prefix="multi_modal_projector", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.text_model = FlashLlamaForCausalLM(
|
||||
prefix="language_model", config=config.text_config, weights=weights
|
||||
)
|
||||
self.config = config
|
||||
self.dtype = weights.dtype
|
||||
self.device = weights.device
|
||||
|
||||
def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
|
||||
if aspect_ratio_ids is None:
|
||||
raise ValueError(
|
||||
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
||||
)
|
||||
# logger.info(f"PIxel values {pixel_values.shape}")
|
||||
batch_size = pixel_values.shape[0]
|
||||
vision_states = self.vision_model(
|
||||
pixel_values, aspect_ratio_ids, aspect_ratio_mask
|
||||
)
|
||||
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
|
||||
-1, vision_states.shape[-2], self.hidden_size
|
||||
)
|
||||
_, _, h = cross_attention_states.shape
|
||||
cross_attention_states = cross_attention_states.view(batch_size, -1, h)
|
||||
# logger.info(f"cross {cross_attention_states.shape}")
|
||||
return cross_attention_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor],
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
# XXX: Putting these as optional so that the cuda warmup calls can go through.
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
image_indices=None,
|
||||
):
|
||||
if cross_attention_states is not None:
|
||||
seqlen_q = len(image_indices)
|
||||
n_images = cross_attention_states.shape[0]
|
||||
seqlen_k = cross_attention_states.shape[1]
|
||||
device = cross_attention_states.device
|
||||
if cu_seqlen_prefill is not None:
|
||||
offset = 0
|
||||
cu_q = []
|
||||
indices = []
|
||||
for index in image_indices:
|
||||
cu_q.append(offset)
|
||||
length = seqlen.input_lengths[index].item()
|
||||
assert index < seqlen.cu_seqlen_q.shape[0]
|
||||
input_ids_offset = seqlen.cu_seqlen_q[index]
|
||||
indices.extend(range(input_ids_offset, input_ids_offset + length))
|
||||
offset += length
|
||||
cu_q.append(offset)
|
||||
cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
|
||||
|
||||
assert max(indices) < input_ids.shape[0]
|
||||
|
||||
cu_seqlen_k = (
|
||||
torch.arange(
|
||||
n_images + 1,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
* seqlen_k
|
||||
)
|
||||
max_q = cu_seqlen_q[-1].item()
|
||||
max_k = seqlen_k
|
||||
else:
|
||||
cu_seqlen_q = torch.arange(
|
||||
seqlen_q + 1, device=device, dtype=torch.int32
|
||||
)
|
||||
seqlen_k = cross_attention_states.shape[1]
|
||||
n_images = cross_attention_states.shape[0]
|
||||
cu_seqlen_k = (
|
||||
torch.arange(
|
||||
n_images + 1,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
* seqlen_k
|
||||
)
|
||||
max_q = seqlen_q
|
||||
max_k = seqlen_k
|
||||
indices = image_indices[:]
|
||||
|
||||
cross_attention_states = (
|
||||
cross_attention_states,
|
||||
cu_seqlen_q,
|
||||
cu_seqlen_k,
|
||||
max_q,
|
||||
max_k,
|
||||
indices,
|
||||
)
|
||||
|
||||
outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
adapter_data=adapter_data,
|
||||
cross_attention_states=cross_attention_states,
|
||||
)
|
||||
|
||||
return outputs
|
@@ -0,0 +1,480 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
from opentelemetry import trace
|
||||
from typing import Iterable, Optional, Tuple, List, Type, Dict
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.flash_causal_lm import (
|
||||
FlashCausalLMBatch,
|
||||
FlashCausalLM,
|
||||
)
|
||||
from text_generation_server.models.globals import PREFIX_CACHING
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.log import log_master
|
||||
from transformers import AutoProcessor
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
|
||||
IDEFICS2_IMAGE_TOKEN = "<image>"
|
||||
|
||||
IDEFICS3_IMAGE_TOKEN = "<image>"
|
||||
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
|
||||
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
|
||||
|
||||
|
||||
# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
|
||||
def _prompt_split_image(
|
||||
*,
|
||||
image_seq_len: int,
|
||||
image_rows: int,
|
||||
image_cols: int,
|
||||
fake_token_around_image: str,
|
||||
image_token: str,
|
||||
global_img_token: str,
|
||||
):
|
||||
"""Prompt with expanded image tokens for when the image is split into patches."""
|
||||
text_split_images = ""
|
||||
for n_h in range(image_rows):
|
||||
for n_w in range(image_cols):
|
||||
text_split_images += (
|
||||
f"{fake_token_around_image}"
|
||||
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
|
||||
+ f"{image_token}" * image_seq_len
|
||||
)
|
||||
text_split_images += "\n"
|
||||
|
||||
text_split_images += (
|
||||
f"\n{fake_token_around_image}"
|
||||
+ f"{global_img_token}"
|
||||
+ f"{image_token}" * image_seq_len
|
||||
+ f"{fake_token_around_image}"
|
||||
)
|
||||
return text_split_images
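# Example (hypothetical arguments): with image_seq_len = 1, image_rows = image_cols = 2,
# fake_token_around_image = "<fake_token_around_image>", image_token = "<image>" and
# global_img_token = "<global-img>", the prompt expands to
# "<fake_token_around_image><row_1_col_1><image><fake_token_around_image><row_1_col_2><image>\n"
# "<fake_token_around_image><row_2_col_1><image><fake_token_around_image><row_2_col_2><image>\n"
# "\n<fake_token_around_image><global-img><image><fake_token_around_image>"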
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (height, width).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (height, width).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def image_text_replacement(processor, image_input, config, image_id: int) -> str:
    if config.model_type == "idefics2":
        image_seq_len = 64
        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
        if processor.image_processor.do_image_splitting:
            image_str *= 5
        return image_str
    if config.model_type == "idefics3":
        # TODO: implement this in a more general way
        n_rows = image_input["rows"][0][image_id]
        n_cols = image_input["cols"][0][image_id]
        image_seq_len = int(
            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
            / (config.scale_factor**2)
        )
        image_str = _prompt_split_image(
            image_seq_len=image_seq_len,
            image_rows=n_rows,
            image_cols=n_cols,
            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
            image_token=IDEFICS3_IMAGE_TOKEN,
            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
        )
        return image_str
    elif config.model_type == "llava_next":
        height, width = image_input["image_sizes"][image_id]
        num_features = get_number_of_features(height, width, config)

        log_master(
            logger.info,
            f"Found {num_features} features in image of resolution {height}x{width}",
        )
        return "<image>" * num_features

    elif config.model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens
    elif config.model_type == "qwen2_vl":
        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
        num_pads = grid_t * grid_h * grid_w // 4
        padding = "<|image_pad|>" * num_pads
        return f"<|vision_start|>{padding}<|vision_end|>"
    elif config.model_type == "qwen2_5_vl":
        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
        num_pads = grid_t * grid_h * grid_w // 4
        padding = "<|image_pad|>" * num_pads
        return f"<|vision_start|>{padding}<|vision_end|>"
    elif config.model_type == "gemma3":
        # TODO: get correct number of features via reviewing the Gemma3 architecture
        # and calculating the number of image tokens
        num_pads = 256
        padding = "<image_soft_token>" * num_pads
        return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")

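# Illustrative example (not part of the upstream file): for qwen2_vl the grid
# product is divided by 4 because the vision tower merges patches 2x2 before
# they reach the language model, so image_grid_thw = [1, 36, 36] yields
# 1 * 36 * 36 // 4 = 324 "<|image_pad|>" tokens between the vision markers.
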
def image_text_replacement_fixup(config, text: str) -> str:
    if config.model_type == "idefics2":
        return text.replace(
            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
        )
    return text

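# Illustrative example (not part of the upstream file): two adjacent images in
# an idefics2 prompt would otherwise produce back-to-back
# "<fake_token_around_image>" markers, since each per-image string both starts
# and ends with one; the fixup above collapses each such pair into a single marker.
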
def get_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = original_width / original_height
    current_aspect_ratio: float = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height = current_height - (2 * padding)
    else:
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width = current_width - (2 * padding)

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)

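# Worked example (illustrative, not part of the upstream file): for a 300x500
# (height x width) image with npatches=24 and a 1x2 patch grid, the padded
# grid is 24x48; the image aspect ratio (500 / 300 ~= 1.67) is narrower than
# the grid's (2.0), so new_width = 500 * 24 // 300 = 40, padding = 4, and the
# function returns (24 * 40, 24) = (960, 24).
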
def get_number_of_features(height: int, width: int, config) -> int:
    # From config
    # Hardcoded for CLIP for now
    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
    image_grid_pinpoints = config.image_grid_pinpoints
    image_size = config.vision_config.image_size
    patch_size = config.vision_config.patch_size

    assert image_size % patch_size == 0

    npatches = image_size // patch_size

    # Dimensions are intentionally swapped to be bug-compatible with
    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        [height, width],
        image_grid_pinpoints,
        image_size,
    )
    unpadded_features, newline_features = get_unpadded_features(
        height, width, npatches, num_patch_height, num_patch_width
    )
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features

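# Worked example (illustrative, not part of the upstream file): with a
# CLIP-style config (image_size=336, patch_size=14, so npatches=24) and a
# 640x640 input mapped to the (672, 672) pinpoint, the grid is 2x2, the
# unpadded features are 48 * 48 = 2304 plus 48 newline features, and the
# base patch adds 24 ** 2 = 576, for 2928 image tokens in total.
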
class FlashVlmCausalLMBatch(FlashCausalLMBatch):
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]
    image_grid_thw: Optional[torch.Tensor]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        return batch

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
    ):
        # Process images first. We need all of them so that the processor
        # can make the image splits the same size. And we need the final
        # sizes to insert the correct number of image tokens.
        images = []
        for r in requests:
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    pass
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    # qwen2_vl expects images to be greater than 20 pixels;
                    # this matters for warmup, since the default warmup image is 20x20.
                    if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
                        if image.width <= 20:
                            w = image.width * 2
                            h = image.height * 2
                            image = image.resize((w, h))

                    if config.model_type == "llava_next":
                        images.append(image)
                    elif config.model_type == "gemma3":
                        images.append(image)
                    else:
                        images.append([image])
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

        if images:
            kwargs = {}
            if (
                hasattr(processor, "image_processor_class")
                and processor.image_processor_class == "Idefics3ImageProcessor"
            ):
                kwargs["return_row_col_info"] = True

            image_inputs = processor.image_processor(
                images, return_tensors="pt", **kwargs
            )
        else:
            image_inputs = None

        batch_tokenized_inputs = []
        max_length = 0
        image_id = 0
        for r in requests:
            full_text = ""
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    full_text += chunk.text
                elif chunk_type == "image":
                    full_text += image_text_replacement(
                        processor, image_inputs, config, image_id
                    )
                    image_id += 1

            full_text = image_text_replacement_fixup(config, full_text)
            input_ids = tokenizer(
                full_text,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=r.add_special_tokens,
            )["input_ids"]
            max_length = max(max_length, len(input_ids))
            batch_tokenized_inputs.append(input_ids)

        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashVlmCausalLMBatch":
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            if "pixel_attention_mask" in image_inputs:
                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                    device=device
                )
            else:
                batch.pixel_attention_mask = None
            if "image_sizes" in image_inputs:
                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
            else:
                batch.image_sizes = None
            if "image_grid_thw" in image_inputs:
                batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
            else:
                batch.image_grid_thw = None
        else:
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
            batch.image_grid_thw = None
        return batch

class FlashVlmCausalLM(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        *,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        batch_class=FlashVlmCausalLMBatch,
        revision,
        trust_remote_code: bool,
        **kwargs,
    ):
        if PREFIX_CACHING:
            raise NotImplementedError("VLMs do not work with prefix caching yet")
        if processor_kwargs is None:
            processor_kwargs = {}
        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )
        self.batch_class = batch_class
        super().__init__(
            model_id=model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            # FIXME: VLMs do not work with context chunking yet
            support_chunking=False,
            **kwargs,
        )

    @property
    def batch_type(self) -> Type[FlashVlmCausalLMBatch]:
        return self.batch_class

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)

    def forward(
        self,
        batch: FlashVlmCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            cache_lengths_tensor = (
                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            cache_lengths_tensor = batch.cache_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

        if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
            if position_ids.dim() == 1 and batch.prefilling:
                position_ids = self.model.get_position_ids(
                    input_ids, batch.image_grid_thw
                )
                batch.position_ids = position_ids

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        seqlen = Seqlen(
            input_lengths=input_lengths,
            cache_lengths=cache_lengths_tensor,
            cu_seqlen_q=cu_seqlen_prefill,
        )
        logits, speculative_logits = self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            prefill_cache_indices=batch.prefill_cache_indices,
            lm_head_indices=lm_head_indices,
            pixel_values=batch.pixel_values,
            pixel_attention_mask=batch.pixel_attention_mask,
            image_sizes=batch.image_sizes,
            image_grid_thw=batch.image_grid_thw,
        )
        if batch.prefill_cache_indices is not None:
            batch.prefill_cache_indices = None
        if batch.pixel_values is not None:
            batch.pixel_values = None
        if batch.pixel_attention_mask is not None:
            batch.pixel_attention_mask = None
        if batch.image_sizes is not None:
            batch.image_sizes = None
        if batch.image_grid_thw is not None:
            batch.image_grid_thw = None
        return logits, speculative_logits
@ -1,15 +1,21 @@
from io import BytesIO
from PIL import Image
import torch

import numpy as np

from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request

from io import BytesIO
from PIL import Image
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
    PreTrainedTokenizerBase,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM

from text_generation_server.models.flash_vlm_causal_lm import (
    FlashVlmCausalLMBatch,
    FlashVlmCausalLM,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.layers.attention import Seqlen

@ -18,7 +24,7 @@ tracer = trace.get_tracer(__name__)


@dataclass
class MllamaCausalLMBatch(VlmCausalLMBatch):
class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
    image_indices: List[int] = 42
    aspect_ratio_ids: Optional[torch.Tensor] = None
    aspect_ratio_mask: Optional[torch.Tensor] = None
@ -154,7 +160,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
    ) -> "FlashVlmCausalLMBatch":
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
@ -163,6 +169,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
        batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
            max=config.text_config.vocab_size - 1
        )
        if isinstance(batch.input_ids, list):
            if len(batch) > 1:
                input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
            else:
                input_ids = batch.input_ids[0]
            batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)

        batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)

        if image_inputs is not None:
@ -183,10 +196,10 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
        return batch


class MllamaCausalLM(VlmCausalLM):
class FlashMllamaCausalLM(FlashVlmCausalLM):
    def forward(
        self,
        batch: VlmCausalLMBatch,
        batch: FlashMllamaCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
@ -198,7 +211,7 @@ class MllamaCausalLM(VlmCausalLM):
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids
@ -217,8 +230,8 @@ class MllamaCausalLM(VlmCausalLM):
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            prefix_lens_tensor = (
                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
            cache_lengths_tensor = (
                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
            ).reshape(-1)

            # Add Copy the block tables for all members
@ -240,8 +253,8 @@ class MllamaCausalLM(VlmCausalLM):
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            prefix_lens_tensor = batch.prefix_lens_tensor
            max_s = batch.max_seqlen
            cache_lengths_tensor = batch.cache_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
@ -250,14 +263,10 @@ class MllamaCausalLM(VlmCausalLM):
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        input_lengths = input_lengths + prefix_lens_tensor
        max_k = (input_lengths + prefix_lens_tensor).max().item()
        seqlen = Seqlen(
            input_lengths=input_lengths,
            prefix_lengths=prefix_lens_tensor,
            cache_lengths=cache_lengths_tensor,
            cu_seqlen_q=cu_seqlen_prefill,
            max_q=max_s,
            max_k=max_k,
        )

        if batch.pixel_values is not None:

@ -4,8 +4,8 @@ import torch
import torch.distributed
from opentelemetry import trace
from typing import Iterable
from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLMBatch,
from text_generation_server.models.flash_vlm_causal_lm import (
    FlashVlmCausalLMBatch,
    image_text_replacement,
)

@ -14,7 +14,7 @@ from text_generation_server.pb.generate_pb2 import Request
tracer = trace.get_tracer(__name__)


class PaliGemmaBatch(VlmCausalLMBatch):
class PaliGemmaBatch(FlashVlmCausalLMBatch):
    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[Request], tokenizer, processor, config

@ -25,15 +25,21 @@ from text_generation_server.utils.tokens import make_tokenizer_optional

try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
    from text_generation_server.models.vlm_causal_lm import (
        VlmCausalLMBatch,
    )
    from text_generation_server.models.flash_vlm_causal_lm import (
        FlashVlmCausalLMBatch,
    )
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

    VLM_BATCH_TYPES = {
        PaliGemmaBatch,
        VlmCausalLMBatch,
        FlashVlmCausalLMBatch,
        IdeficsCausalLMBatch,
        FlashMllamaCausalLMBatch,
    }
except (ImportError, NotImplementedError):
    # These imports can fail on CPU/Non flash.