mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 06:12:07 +00:00
1234 lines
47 KiB
Python
1234 lines
47 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""PyTorch Mllama model."""
|
||
|
|
||
|
from typing import List, Optional, Tuple
|
||
|
|
||
|
import torch
|
||
|
import torch.utils.checkpoint
|
||
|
from torch import nn
|
||
|
import math
|
||
|
|
||
|
from transformers.activations import ACT2FN
|
||
|
import torch.nn.functional as F
|
||
|
from text_generation_server.models.custom_modeling.vlm import (
|
||
|
load_text_model,
|
||
|
)
|
||
|
from text_generation_server.layers.attention import Seqlen
|
||
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||
|
|
||
|
from text_generation_server.layers.layernorm import (
|
||
|
FastRMSNorm,
|
||
|
)
|
||
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||
|
from text_generation_server.layers import (
|
||
|
TensorParallelColumnLinear,
|
||
|
TensorParallelEmbedding,
|
||
|
TensorParallelRowLinear,
|
||
|
SpeculativeHead,
|
||
|
FastLinear,
|
||
|
)
|
||
|
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
    """Feed-forward block of the vision encoder: ``fc2(activation(fc1(x)))``.

    ``fc1`` is loaded as a column-parallel linear layer and ``fc2`` as a
    row-parallel one; both carry a bias.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        # Both projections share every loader argument except the prefix.
        common = dict(weights=weights, config=config, bias=True)
        self.fc1 = TensorParallelColumnLinear.load(prefix=f"{prefix}.fc1", **common)
        self.fc2 = TensorParallelRowLinear.load(prefix=f"{prefix}.fc2", **common)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up with ``fc1``, apply the activation, project back with ``fc2``."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
|
||
|
|
||
|
|
||
|
class MllamaVisionSdpaAttention(nn.Module):
    """Multi-head self-attention for the vision encoder, computed with SDPA.

    Q, K and V come from one fused column-parallel projection (``qkv_proj``);
    the attended output is projected back by the row-parallel ``o_proj``.
    No causal masking is applied; only the optional ``attention_mask`` is used.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()

        self.embed_dim = config.hidden_size
        self.head_dim = config.hidden_size // config.attention_heads
        # The fused projection is column-sharded across the tensor-parallel
        # group, so each rank only holds its share of the heads.
        self.num_heads = config.attention_heads // weights.process_group.size()

        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over the sequence axis.

        Args:
            hidden_state: 3-D ``(batch, seq_len, features)`` input.
            attention_mask: optional mask, passed unchanged to SDPA's ``attn_mask``.

        Returns:
            The ``o_proj`` output of shape ``(batch, seq_len, ...)``; only this
            tensor is returned (no attention weights).
        """
        qkv = self.qkv_proj(hidden_state)
        split_size = self.head_dim * self.num_heads
        # Split on the feature (last) axis; axis 1 is the sequence axis.
        query, key, value = qkv.split([split_size, split_size, split_size], dim=-1)

        batch_size, q_seq_len, _ = query.shape
        _, kv_seq_len, _ = key.shape

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
        key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
        value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_seq_len, -1)

        return self.o_proj(attn_output)
|
||
|
|
||
|
|
||
|
class MllamaVisionEncoderLayer(nn.Module):
    """Pre-norm transformer block of the vision encoder.

    When ``is_gated`` is true, the attention and MLP residual branches are
    scaled by ``tanh(gate_attn)`` and ``tanh(gate_ffn)`` respectively before
    being added back; otherwise they are added unscaled.
    """

    def __init__(self, *, prefix, config, weights, is_gated: bool):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = MllamaVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

        # NOTE(review): `load` is not a stock torch.nn.LayerNorm method;
        # presumably attached by text_generation_server.layers.layernorm — confirm.
        self.input_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
        )
        self.post_attention_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
        )

        if is_gated:
            self.gate_attn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
            )
            # The feed-forward branch has its own gate tensor in the checkpoint.
            self.gate_ffn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
            )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply (gated) attention then (gated) MLP, each with a residual add.

        Returns the updated hidden-state tensor, with the same shape as the input.
        """
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        # The attention module returns only its output tensor (no weights).
        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
        hidden_state = residual + gate_attn * hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
        hidden_state = residual + gate_ffn * hidden_state
        return hidden_state
|
||
|
|
||
|
|
||
|
class MllamaVisionEncoder(nn.Module):
    """Sequential stack of ``num_layers`` MllamaVisionEncoderLayer blocks."""

    def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
        super().__init__()
        self.config = config
        # ModuleList (not a plain list) registers the layers as submodules so
        # .eval()/.to()/named_modules() reach them.
        self.layers = nn.ModuleList(
            [
                MllamaVisionEncoderLayer(
                    prefix=f"{prefix}.layers.{i}",
                    config=config,
                    weights=weights,
                    is_gated=is_gated,
                )
                for i in range(num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ):
        """Run every layer in order.

        Args:
            hidden_states: input to the first layer.
            attention_mask: optional mask forwarded to every layer.
            output_hidden_states: when true, also collect the input and every
                layer's output.

        Returns:
            A ``(hidden_states, all_hidden_states)`` tuple. ``all_hidden_states``
            is a list ``[input, layer_0_out, ..., layer_{n-1}_out]`` when
            ``output_hidden_states`` is true, else ``None``.
        """
        all_hidden_states = [hidden_states] if output_hidden_states else None
        for encoder_layer in self.layers:
            # Each layer returns the updated tensor itself; indexing it with
            # [0] would select the first batch element.
            hidden_states = encoder_layer(hidden_states, attention_mask)
            if all_hidden_states is not None:
                all_hidden_states.append(hidden_states)

        return hidden_states, all_hidden_states
|
||
|
|
||
|
|
||
|
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Add a tanh-gated embedding looked up from the aspect-ratio id.

    The looked-up row is reshaped to ``(-1, max_num_tiles, 1, hidden_size)``
    so that each tile receives one vector, broadcast over its patch axis.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id

        self.embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.embedding", weights=weights
        )
        self.gate = nn.Parameter(
            weights.get_tensor(f"{prefix}.gate"), requires_grad=False
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        """Return ``hidden_state`` plus the gated per-tile embedding."""
        per_tile = self.embedding(aspect_ratio_ids).reshape(
            -1, self.max_num_tiles, 1, self.hidden_size
        )
        # This embedding is always gated.
        gated = per_tile * self.gate.tanh()
        return hidden_state + gated
|
||
|
|
||
|
|
||
|
class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Blend a shared patch-position embedding with a per-aspect-ratio tile embedding.

    With ``g = tanh(gate)``, the output is
    ``hidden + (1 - g) * embedding + g * tile_embedding(aspect_ratio_ids)``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # Patches per tile plus one class-token slot.
        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        # The gate is a trained value stored in the checkpoint; a hard-coded
        # zero would disable the tile embedding (tanh(0) == 0).
        self.gate = nn.Parameter(
            weights.get_tensor(f"{prefix}.gate"), requires_grad=False
        )

        # position embedding
        self.embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
        )
        self.tile_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.tile_embedding", weights=weights
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        """Add the gated position and tile embeddings.

        Args:
            hidden_state: 4-D ``(batch, max_num_tiles, num_patches, hidden_size)``;
                the additions below broadcast/reshape to that shape.
            aspect_ratio_ids: ids looked up in ``tile_embedding``, one per batch row.
        """
        gate = self.gate.tanh()

        # position embeddings, shared across batch and tiles
        gated_position_embedding = (1 - gate) * self.embedding
        hidden_state = hidden_state + gated_position_embedding.view(
            1, 1, self.num_patches, self.hidden_size
        )

        # precomputed tile position embeddings
        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
        batch_size = hidden_state.shape[0]
        tile_position_embedding = tile_position_embedding.reshape(
            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
        )
        hidden_state = hidden_state + gate * tile_position_embedding

        return hidden_state
|
||
|
|
||
|
|
||
|
class MllamaVisionModel(nn.Module):
    """Mllama vision tower.

    Pipeline: tiled patch embedding -> pre-tile embedding -> class token ->
    gated position embedding -> ``transformer`` -> post-tile embedding ->
    ``global_transformer``. The global output is concatenated on the last axis
    with the ``transformer`` hidden states selected by
    ``config.intermediate_layers_indices``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.in_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patches per tile plus one class-token slot.
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = nn.Conv2d(
            in_channels=config.in_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )

        self.class_embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
        )

        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            prefix=f"{prefix}.gated_positional_embedding",
            config=config,
            weights=weights,
        )

        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.pre_tile_positional_embedding",
            config=config,
            weights=weights,
        )
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.post_tile_positional_embedding",
            config=config,
            weights=weights,
        )

        ## layer norms
        self.layernorm_pre = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_pre",
            weights=weights,
            # torch default
            eps=1e-05,
        )
        self.layernorm_post = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_post",
            weights=weights,
            # torch default
            eps=1e-05,
        )

        ## encoders
        # Ungated encoder; its per-layer outputs feed the intermediate features.
        self.transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.transformer",
            config=config,
            weights=weights,
            is_gated=False,
            num_layers=config.num_hidden_layers,
        )
        # Gated encoder applied after the post-tile embedding.
        self.global_transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.global_transformer",
            config=config,
            weights=weights,
            is_gated=True,
            num_layers=config.num_global_layers,
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the class token to every sequence: ``(B, P, D) -> (B, P + 1, D)``."""
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode tiled images.

        Args:
            pixel_values: 6-D ``(batch, num_concurrent_media, num_tiles,
                channels, height, width)``.
            aspect_ratio_ids: reshaped to ``(batch * num_concurrent_media, -1)``.
            attention_mask: optional; reshaped to ``(batch * num_concurrent_media, -1)``
                and expanded by ``_prepare_aspect_ratio_attention_mask``.

        Returns:
            ``(batch, num_concurrent_media, num_tiles, num_patches, F)`` where
            ``F`` is ``dim`` plus ``dim`` for each selected intermediate layer.
        """
        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
            pixel_values.shape
        )
        num_media = batch_size * num_concurrent_media

        pixel_values = pixel_values.reshape(
            num_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(num_media, -1)

        # patch embedding; nn.Module has no `dtype`/`device` attributes, so match
        # the input to the loaded conv weight instead.
        conv_weight = self.patch_embedding.weight
        patch_embeds = self.patch_embedding(
            pixel_values.to(device=conv_weight.device, dtype=conv_weight.dtype)
        )
        hidden_state = patch_embeds.flatten(2).transpose(1, 2)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(num_media, num_tiles, -1, dim)
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # apply cls token
        hidden_state = hidden_state.reshape(num_media * num_tiles, num_patches, dim)
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(num_media, num_tiles, num_patches, dim)
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Pad the patch axis (dim -2) up to a multiple of 8.
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # F.pad pads from the last dim backwards:
        # (last_left, last_right, second_last_left, second_last_right)
        padding = (0, 0, 0, num_padding_patches)
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        # Slice used later to drop the padding again (None keeps everything).
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(num_media, -1)
            attention_mask = _prepare_aspect_ratio_attention_mask(
                aspect_ratio_mask=attention_mask,
                num_patches=self.num_patches,
                target_length=hidden_state.shape[2],
                dtype=hidden_state.dtype,
            )

        hidden_state = hidden_state.view(num_media, -1, dim)
        hidden_state, all_intermediate_hidden_states = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        intermediate_hidden_states = [
            state
            for idx, state in enumerate(all_intermediate_hidden_states)
            if idx in self.intermediate_layers_indices
        ]
        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)

        # apply global encoder
        padded_patches = num_patches + num_padding_patches
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches, dim)
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(num_media, num_tiles * padded_patches, dim)
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )[0]
        hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches, dim)
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            num_media, num_tiles, padded_patches, -1
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
        return hidden_state
|
||
|
|
||
|
|
||
|
class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Queries are projected from the text ``hidden_states``; keys and values are
    projected from ``cross_attention_states``, or read from ``past_key_value``
    when no new cross-attention states are given.
    """

    def __init__(self, *, prefix, config, weights, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.num_heads = self.config.num_attention_heads
        self.num_key_value_heads = self.config.num_key_value_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # Index of this layer's entry in `past_key_value`.
        self.layer_idx = layer_idx

        # Column-parallel projections shard the heads across ranks; head_dim
        # and the query/kv group ratio are computed from the full counts above.
        world_size = weights.process_group.size()
        self.num_heads = self.num_heads // world_size
        self.num_key_value_heads = self.num_key_value_heads // world_size

        # Q projects the text stream while K/V project the cross-attention
        # stream, so they cannot share a single fused projection.
        self.q_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        self.q_norm = FastRMSNorm.load(
            prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.k_norm = FastRMSNorm.load(
            prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[torch.Tensor] = None,
        past_key_value=None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        use_cache: bool = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: if there are neither new cross-attention states nor
                cached key/values to attend to.
        """
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        # NOTE(review): FastRMSNorm is applied to a 4-D (bsz, heads, q_len, head_dim)
        # tensor here; confirm it accepts that input and returns a single tensor.
        query_states = self.q_norm(query_states)

        if cross_attention_states is not None:
            key_states = self.k_proj(cross_attention_states)
            value_states = self.v_proj(cross_attention_states)
            key_states = key_states.view(
                bsz, -1, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            value_states = value_states.view(
                bsz, -1, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            key_states = self.k_norm(key_states)
            if past_key_value is not None:
                # if we have a new image + new tokens, we only computed key_states on that new image
                # we still update the cross key states, past_image, new_image. And use it!
                key_states, value_states = past_key_value.update(
                    key_states,
                    value_states,
                    self.layer_idx,
                    {"cache_position": cache_position},
                )
        elif (
            past_key_value is not None
            and cache_position is not None
            and cache_position[0] != 0
        ):
            # No new image: reuse the cached cross-attention key/values.
            key_states, value_states = (
                past_key_value.key_cache[self.layer_idx],
                past_key_value.value_cache[self.layer_idx],
            )
        else:
            raise ValueError(
                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
            )

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
||
|
|
||
|
|
||
|
# Adapted from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    ``gate_proj`` and ``up_proj`` are loaded as one fused column-parallel
    projection whose output is ``[gate | up]`` along the last axis.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        # The fused projection is column-sharded, so each rank only holds its
        # slice of the intermediate features.
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        # The Mllama text config names its activation `hidden_act`.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP, preserving all leading dimensions of ``x``."""
        gate_up_states = self.gate_up_proj(x)
        # (..., 2 * I) -> (..., 2, I); keeping the leading dims avoids
        # flattening batch and sequence together.
        gate_up_states = gate_up_states.view(*x.shape[:-1], 2, self.intermediate_size)
        return self.down_proj(
            self.act_fn(gate_up_states[..., 0, :]) * gate_up_states[..., 1, :]
        )
|
||
|
|
||
|
|
||
|
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention and feedforward.

    Each residual branch (cross-attention, MLP) is multiplied by ``tanh`` of a
    scalar gate loaded from the checkpoint before being added back.
    """

    def __init__(self, *, prefix, config, weights) -> None:
        super().__init__()

        def load_gate(name):
            # Gates are frozen tensors read straight from the checkpoint.
            return torch.nn.Parameter(
                weights.get_tensor(f"{prefix}.{name}"), requires_grad=False
            )

        self.cross_attn = MllamaTextCrossAttention(
            prefix=f"{prefix}.cross_attn", config=config, weights=weights
        )
        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.cross_attn_attn_gate = load_gate("cross_attn_attn_gate")

        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.cross_attn_mlp_gate = load_gate("cross_attn_mlp_gate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        attention_mask: torch.Tensor,
        full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
        past_key_value=None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run gated cross-attention then a gated MLP.

        ``cross_attention_mask`` is the mask given to the cross-attention;
        ``attention_mask`` is accepted but not used here. When
        ``full_text_row_masked_out_mask`` is given, ``mask[:, 0]`` scales the
        MLP branch.

        Returns:
            A tuple ``(hidden_states,)``, extended with the attention weights
            when ``output_attentions`` and with ``past_key_value`` when ``use_cache``.
        """
        attn_out, attn_weights, past_key_value = self.cross_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=cross_attention_mask,
            cross_attention_states=cross_attention_states,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = hidden_states + self.cross_attn_attn_gate.tanh() * attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        if full_text_row_masked_out_mask is not None:
            mlp_out = full_text_row_masked_out_mask[:, 0] * mlp_out  # type: ignore
        hidden_states = hidden_states + self.cross_attn_mlp_gate.tanh() * mlp_out

        extras = []
        if output_attentions:
            extras.append(attn_weights)
        if use_cache:
            extras.append(past_key_value)
        return (hidden_states, *extras)
|
||
|
|
||
|
|
||
|
class MllamaTextSelfAttention(nn.Module):
    """Self-attention for the text decoder with rotary embeddings and grouped KV heads.

    Q, K and V come from one fused column-parallel projection (``qkv_proj``)
    and are computed with ``torch.nn.functional.scaled_dot_product_attention``.
    """

    def __init__(self, *, prefix, config, weights, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # Index of this layer's entry in `past_key_value`.
        self.layer_idx = layer_idx

        # The fused projection is column-sharded, so each rank holds only its
        # share of query and key/value heads (head_dim uses the full count).
        world_size = weights.process_group.size()
        self.num_heads = self.num_heads // world_size
        self.num_key_value_heads = self.num_key_value_heads // world_size

        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
        output_attentions: bool = False,
        use_cache: bool = False,
        past_key_value=None,
        cache_position=None,
        **kwargs,
    ):
        """Attend with RoPE applied to Q/K.

        Args:
            hidden_states: 3-D ``(bsz, q_len, hidden)`` input.
            attention_mask: optional 4-D mask, sliced to the key length; when
                ``None`` and ``q_len > 1``, SDPA's built-in causal mask is used.
            position_embeddings: ``(cos, sin)`` pair for ``apply_rotary_pos_emb``.

        Returns:
            ``(attn_output, None, past_key_value)``.
        """
        bsz, q_len, _ = hidden_states.size()

        # The fused output is split in the order of the load_multi prefixes
        # (q, k, v), presumably concatenated along the feature axis.
        qkv = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [
                self.num_heads * self.head_dim,
                self.num_key_value_heads * self.head_dim,
                self.num_key_value_heads * self.head_dim,
            ],
            dim=-1,
        )

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT
|
||
|
class MllamaSelfAttentionDecoderLayer(nn.Module):
    """Self-attention decoder layer of the Mllama text model.

    Pre-norm layout: ``input_layernorm -> self_attn`` and
    ``post_attention_layernorm -> mlp``, each added back onto its input as a
    residual connection.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MllamaTextSelfAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )

        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    @staticmethod
    def _normalize(norm, hidden_states):
        """Apply ``norm`` to ``hidden_states`` and return only the normed tensor.

        NOTE(review): TGI's ``FastRMSNorm.forward`` returns a
        ``(normed, residual)`` tuple rather than a bare tensor; passing that
        tuple on to attention/MLP would fail. A norm that returns a plain
        tensor is passed through unchanged.
        """
        normed = norm(hidden_states)
        return normed[0] if isinstance(normed, tuple) else normed

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value=None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Run one self-attention decoder layer.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer, documented as
                `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): mask forwarded
                unchanged to `self.self_attn`.
            position_ids (`torch.LongTensor`, *optional*): forwarded to `self.self_attn`.
            past_key_value (*optional*): cache object forwarded to `self.self_attn`.
            output_attentions (`bool`, *optional*): when true, the attention weights
                returned by `self.self_attn` are appended to the output tuple.
            use_cache (`bool`, *optional*): when true, the key/value state returned
                by `self.self_attn` is appended to the output tuple.
            cache_position (`torch.LongTensor`, *optional*): forwarded to `self.self_attn`.
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                rotary (cos, sin) embeddings forwarded to `self.self_attn`.
            kwargs (`dict`, *optional*): extra keyword arguments (e.g. the
                cross-attention inputs the text model passes to every layer),
                forwarded unchanged to `self.self_attn`.

        Returns:
            `(hidden_states,)`, followed by the attention weights if
            `output_attentions` and the present key/value if `use_cache`.
        """
        residual = hidden_states

        hidden_states = self._normalize(self.input_layernorm, hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self._normalize(self.post_attention_layernorm, hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
|
||
|
|
||
|
|
||
|
class MllamaTextModel(nn.Module):
    """Mllama text decoder: token embeddings, a stack of interleaved
    self-attention / cross-attention decoder layers, and a final RMS norm."""

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        self.cross_attention_layers = config.cross_attention_layers

        layers = []
        for layer_idx in range(config.num_hidden_layers):
            # Layer indices listed in `config.cross_attention_layers` attend to
            # the image states; every other index is a plain self-attention layer.
            if layer_idx in self.cross_attention_layers:
                layer_cls = MllamaCrossAttentionDecoderLayer
            else:
                layer_cls = MllamaSelfAttentionDecoderLayer
            layers.append(
                layer_cls(
                    prefix=f"{prefix}.layers.{layer_idx}",
                    config=config,
                    weights=weights,
                )
            )
        # nn.ModuleList (rather than a plain list) registers the layers as
        # submodules so `.eval()`, `.to()`, `.parameters()` and `state_dict()`
        # reach them; indexing, iteration and `append` behave like a list.
        self.layers = nn.ModuleList(layers)

        # TODO Should we use this slow norm ?
        # self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        # TODO Anything specific ?
        head_size = config.hidden_size // config.num_attention_heads
        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=head_size,
            base=config.rope_theta,
            device=weights.device,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.FloatTensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        full_text_row_masked_out_mask: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
    ):
        """Run the decoder stack and return the final normalized hidden states.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given, otherwise
        ``ValueError`` is raised. ``cache_position`` defaults to the positions
        following ``past_key_values.get_seq_length()`` (or 0 without a cache),
        and ``position_ids`` defaults to ``cache_position.unsqueeze(0)``.

        Cross-attention layers are skipped when there are no
        ``cross_attention_states`` and nothing is cached for that layer.

        ``use_cache`` is accepted for caller compatibility (``MllamaForCausalLM``
        passes it) and is currently unused.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # `_update_causal_mask` is commented out in this module, so no 4D causal
        # mask is built from a 2D padding mask here. An already-4D mask is passed
        # through unchanged; otherwise `None` is forwarded, which makes the SDPA
        # self-attention use `is_causal=True` when q_len > 1.
        if attention_mask is not None and attention_mask.dim() == 4:
            causal_mask = attention_mask
        else:
            causal_mask = None

        # create position embeddings to be shared across the decoder layers
        # NOTE(review): confirm PositionRotaryEmbedding accepts
        # (hidden_states, position_ids) as call arguments.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            if (
                idx in self.cross_attention_layers
                and cross_attention_states is None
                and (
                    past_key_values is None
                    or past_key_values.get_seq_length(idx) == 0
                )
            ):
                continue

            layer_outputs = decoder_layer(
                hidden_states,
                cross_attention_states=cross_attention_states,
                cross_attention_mask=cross_attention_mask,
                attention_mask=causal_mask,
                full_text_row_masked_out_mask=full_text_row_masked_out_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            # Self-attention layers return a tuple `(hidden_states, ...)`; keep
            # only the hidden states so the next layer receives a tensor.
            if isinstance(layer_outputs, tuple):
                hidden_states = layer_outputs[0]
            else:
                hidden_states = layer_outputs

        # NOTE(review): TGI's FastRMSNorm.forward returns `(normed, residual)`;
        # unwrap so callers get a tensor either way.
        normed = self.norm(hidden_states)
        return normed[0] if isinstance(normed, tuple) else normed
|
||
|
|
||
|
# def _update_causal_mask(
|
||
|
# self,
|
||
|
# attention_mask: torch.Tensor,
|
||
|
# input_tensor: torch.Tensor,
|
||
|
# cache_position: torch.Tensor,
|
||
|
# past_key_values,
|
||
|
# ):
|
||
|
# if self.config._attn_implementation == "flash_attention_2":
|
||
|
# if attention_mask is not None and 0.0 in attention_mask:
|
||
|
# return attention_mask
|
||
|
# return None
|
||
|
|
||
|
# # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||
|
# # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||
|
# # to infer the attention mask.
|
||
|
# past_seen_tokens = (
|
||
|
# past_key_values.get_seq_length() if past_key_values is not None else 0
|
||
|
# )
|
||
|
# using_static_cache = isinstance(past_key_values, StaticCache)
|
||
|
|
||
|
# # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||
|
# # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line
|
||
|
# # self.config._attn_implementation == "sdpa" and
|
||
|
# if (
|
||
|
# self.config._attn_implementation == "sdpa"
|
||
|
# and not using_static_cache
|
||
|
# and not output_attentions
|
||
|
# ):
|
||
|
# if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||
|
# attention_mask,
|
||
|
# inputs_embeds=input_tensor,
|
||
|
# past_key_values_length=past_seen_tokens,
|
||
|
# is_training=self.training,
|
||
|
# ):
|
||
|
# return None
|
||
|
|
||
|
# dtype, device = input_tensor.dtype, input_tensor.device
|
||
|
# min_dtype = torch.finfo(dtype).min
|
||
|
# sequence_length = input_tensor.shape[1]
|
||
|
# if using_static_cache:
|
||
|
# target_length = past_key_values.get_max_length()
|
||
|
# else:
|
||
|
# target_length = (
|
||
|
# attention_mask.shape[-1]
|
||
|
# if isinstance(attention_mask, torch.Tensor)
|
||
|
# else past_seen_tokens + sequence_length + 1
|
||
|
# )
|
||
|
|
||
|
# # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||
|
# causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||
|
# attention_mask,
|
||
|
# sequence_length=sequence_length,
|
||
|
# target_length=target_length,
|
||
|
# dtype=dtype,
|
||
|
# device=device,
|
||
|
# min_dtype=min_dtype,
|
||
|
# cache_position=cache_position,
|
||
|
# batch_size=input_tensor.shape[0],
|
||
|
# )
|
||
|
|
||
|
# if (
|
||
|
# self.config._attn_implementation == "sdpa"
|
||
|
# and attention_mask is not None
|
||
|
# and attention_mask.device.type == "cuda"
|
||
|
# and not output_attentions
|
||
|
# ):
|
||
|
# # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||
|
# # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||
|
# # Details: https://github.com/pytorch/pytorch/issues/110213
|
||
|
# causal_mask = AttentionMaskConverter._unmask_unattended(
|
||
|
# causal_mask, min_dtype
|
||
|
# )
|
||
|
|
||
|
# return causal_mask
|
||
|
|
||
|
|
||
|
class MllamaForCausalLM(nn.Module):
    """Mllama text decoder (``MllamaTextModel``) topped with a ``SpeculativeHead``."""

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.model = MllamaTextModel(
            prefix=f"{prefix}.model", config=config, weights=weights
        )
        self.lm_head = SpeculativeHead.load(
            prefix=f"{prefix}.lm_head",
            config=config,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.LongTensor] = None,
        cross_attention_mask: Optional[torch.LongTensor] = None,
        full_text_row_masked_out_mask: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ):
        """Run the text decoder and project its output through ``self.lm_head``.

        Returns:
            The first element returned by ``self.lm_head`` (the logits); the
            second element (speculative logits) is discarded.

        ``labels``, ``use_cache`` and ``num_logits_to_keep`` are accepted for
        signature compatibility but are not used; in particular ``use_cache``
        is deliberately not forwarded to ``self.model``.
        """
        # The text model returns the final normalized hidden states.
        hidden_states = self.model(
            input_ids=input_ids,
            cross_attention_states=cross_attention_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
        )
        logits, _speculative_logits = self.lm_head(hidden_states)
        return logits

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        **kwargs,
    ):
        """Build the keyword arguments for one generation step.

        Slices ``input_ids`` to the not-yet-cached tokens, derives
        ``position_ids`` from ``attention_mask`` when they are not given, and
        returns a dict with ``input_ids``/``inputs_embeds``, ``position_ids``,
        ``cache_position``, ``past_key_values``, ``use_cache``,
        ``attention_mask`` and, when given, ``num_logits_to_keep``.
        """
        is_static_cache = False
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            # `StaticCache` is not among this module's imports; import it only
            # when a cache is actually present.
            from transformers.cache_utils import StaticCache

            is_static_cache = isinstance(past_key_values, StaticCache)
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif (
                input_ids.shape[1] != cache_position.shape[0]
            ):  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {
                "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
                "inputs_embeds": None,
            }

        if (
            is_static_cache
            and attention_mask is not None
            and attention_mask.ndim == 2
        ):
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            # NOTE(review): confirm `SpeculativeHead` exposes `.weight`.
            dtype = self.lm_head.weight.dtype
            min_dtype = torch.finfo(dtype).min

            # NOTE(review): `_prepare_4d_causal_attention_mask_with_cache_position`
            # must be defined elsewhere in this module — confirm.
            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_length(),
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
|
||
|
|
||
|
|
||
|
class MllamaForConditionalGeneration(nn.Module):
|
||
|
def __init__(self, prefix, config, weights):
|
||
|
super().__init__()
|
||
|
config.vision_config.quantize = None
|
||
|
config.vision_config.speculator = config.speculator
|
||
|
config.text_config.quantize = config.quantize
|
||
|
config.text_config.speculator = config.speculator
|
||
|
self.vision_model = MllamaVisionModel(
|
||
|
prefix="vision_model", config=config.vision_config, weights=weights
|
||
|
)
|
||
|
self.language_model = MllamaForCausalLM(
|
||
|
prefix="language_model", config=config.text_config, weights=weights
|
||
|
)
|
||
|
self.multi_modal_projector = FastLinear.load(
|
||
|
prefix="multi_modal_projector", config=config, weights=weights, bias=True
|
||
|
)
|
||
|
self.config = config
|