mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
# What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. 
<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Joshua Rosenkranz <joshua.rosenkranz@gmail.com>
830 lines
31 KiB
Python
830 lines
31 KiB
Python
# coding=utf-8
|
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" PyTorch Idefics2 model."""
|
|
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
import math
|
|
|
|
from transformers.activations import ACT2FN
|
|
from transformers.image_processing_utils import select_best_resolution
|
|
from text_generation_server.models.custom_modeling.vlm import (
|
|
load_text_model,
|
|
load_vision_model,
|
|
)
|
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
|
|
|
from text_generation_server.layers import (
|
|
TensorParallelColumnLinear,
|
|
TensorParallelEmbedding,
|
|
TensorParallelRowLinear,
|
|
)
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along the head axis.

    Produces the same values as `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`:
    an input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Add a singleton axis right after the head axis, broadcast it to `n_rep`,
        # then fold it back into the head axis.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    # Nothing to repeat: the input tensor itself is returned.
    return hidden_states
|
|
|
|
|
|
class Idefics2VisionEmbeddings(nn.Module):
    """
    Patch and position embeddings of the Idefics2 vision tower.

    Derived from `siglip.modeling_siglip.SiglipVisionEmbeddings`, modified so that images of variable
    resolution can be embedded.

    The changes follow [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304),
    which processes images in their native aspect ratio instead of resizing them all to one fixed size.
    The starting point is the pre-trained SigLIP model (trained on fixed-size square images), adapted by
    training on images of variable resolutions.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Kernel size == stride == patch_size: the convolution sees non-overlapping patches.
        patch_conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        # Replace the freshly initialised parameters with the checkpoint tensors (frozen).
        patch_conv.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )
        patch_conv.bias = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
        )
        self.patch_embedding = patch_conv

        # The position table is addressed on a num_patches_per_side x num_patches_per_side grid.
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )

    def forward(
        self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
    ) -> torch.Tensor:
        """
        Embed `pixel_values` patch by patch and add resolution-aware position embeddings.

        Args:
            pixel_values: 4-D image batch; the last two axes are height and width.
            patch_attention_mask: one 2-D boolean patch grid per image, true where a patch is real.

        Returns:
            Tensor with one embedding per patch: (batch, nb_patches_h * nb_patches_w, channels).
        """
        n_images, _, height, width = pixel_values.shape

        # (batch, channels, h, w) -> (batch, h * w, channels)
        embeddings = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        grid_side = self.num_patches_per_side
        rows = height // self.patch_size
        cols = width // self.patch_size
        # Bucket edges that cut [0, 1) into `grid_side` equal intervals.
        boundaries = torch.arange(1 / grid_side, 1.0, 1 / grid_side)
        position_ids = torch.full(size=(n_images, rows * cols), fill_value=0)

        for image_idx, mask in enumerate(patch_attention_mask):
            # The number of valid rows / columns is read off the first column / first row of the mask.
            valid_rows = mask[:, 0].sum()
            valid_cols = mask[0].sum()

            # Fractional coordinate of every valid row / column, mapped onto the reference grid.
            row_buckets = torch.bucketize(
                torch.arange(0, 1 - 1e-6, 1 / valid_rows), boundaries, right=True
            )
            col_buckets = torch.bucketize(
                torch.arange(0, 1 - 1e-6, 1 / valid_cols), boundaries, right=True
            )

            grid_ids = (row_buckets[:, None] * grid_side + col_buckets).flatten()
            # Only the valid patches receive an id; the others keep position 0.
            position_ids[image_idx][mask.view(-1).cpu()] = grid_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)
        return embeddings + self.position_embedding(position_ids)
|
|
|
|
|
|
class Idefics2VisionAttention(nn.Module):
    """
    Multi-head self-attention of the vision encoder with a fused q/k/v projection.

    After construction `num_heads` and `embed_dim` hold the per-shard values: both are divided by
    `weights.process_group.size()`.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = self.embed_dim // self.num_heads
        if self.embed_dim != self.head_size * self.num_heads:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_size**-0.5
        self.dropout = config.attention_dropout

        # From here on the sizes describe what this shard owns.
        world_size = weights.process_group.size()
        self.num_heads = self.num_heads // world_size
        self.embed_dim = self.embed_dim // world_size

        # q, k and v are loaded as one column-parallel projection and split again in `forward`.
        self.qkv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.{name}" for name in ("q_proj", "k_proj", "v_proj")],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.out_proj = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
        )
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Attend over `hidden_states` of shape (batch, seq_len, hidden).

        Args:
            hidden_states: input sequence.
            attention_mask: optional additive mask of size (batch, 1, seq_len, seq_len).

        Returns:
            The output of `out_proj` applied to the attended values.

        Raises:
            ValueError: if the attention weights, the mask or the attention output have an unexpected size.
        """
        bsz, seq_len, _ = hidden_states.size()
        proj_width = self.head_size * self.num_heads

        def split_heads(tensor):
            # (batch, seq, heads * head_size) -> (batch, heads, seq, head_size)
            return tensor.view(bsz, seq_len, self.num_heads, self.head_size).transpose(1, 2)

        query_states, key_states, value_states = (
            split_heads(part)
            for part in self.qkv(hidden_states).split([proj_width] * 3, dim=2)
        )

        kv_len = key_states.shape[-2]
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

        weights_shape = (bsz, self.num_heads, seq_len, kv_len)
        if attn_weights.size() != weights_shape:
            raise ValueError(
                f"Attention weights should be of size {weights_shape}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            mask_shape = (bsz, 1, seq_len, kv_len)
            if attention_mask.size() != mask_shape:
                raise ValueError(
                    f"Attention mask should be of size {mask_shape}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # The softmax runs in fp32 and is cast back to the dtype of the queries.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
        attn_weights = attn_weights.to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        output_shape = (bsz, self.num_heads, seq_len, self.head_size)
        if attn_output.size() != output_shape:
            raise ValueError(
                f"`attn_output` should be of size {output_shape}, but is"
                f" {attn_output.size()}"
            )

        # (batch, heads, seq, head_size) -> (batch, seq, heads * head_size)
        merged = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, self.embed_dim)
        return self.out_proj(merged)
|
|
|
|
|
|
class Idefics2VisionMLP(nn.Module):
    """Feed-forward block of the vision encoder: `fc1` -> activation -> `fc2`."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        # The activation is looked up by name in the `ACT2FN` table.
        self.activation_fn = ACT2FN[config.hidden_act]
        linear_kwargs = dict(config=config, weights=weights, bias=True)
        self.fc1 = TensorParallelColumnLinear.load(prefix=f"{prefix}.fc1", **linear_kwargs)
        self.fc2 = TensorParallelRowLinear.load(prefix=f"{prefix}.fc2", **linear_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `fc2(activation_fn(fc1(hidden_states)))`."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
|
|
|
|
|
|
class Idefics2EncoderLayer(nn.Module):
    """One pre-norm block of the vision encoder: self-attention, then an MLP, each with a residual."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Idefics2VisionAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        # NOTE(review): `nn.LayerNorm.load` is not a stock PyTorch API; presumably it is attached by
        # the text_generation_server layers package - confirm.
        norm_kwargs = dict(eps=config.layer_norm_eps, weights=weights)
        self.layer_norm1 = nn.LayerNorm.load(prefix=f"{prefix}.layer_norm1", **norm_kwargs)
        self.layer_norm2 = nn.LayerNorm.load(prefix=f"{prefix}.layer_norm2", **norm_kwargs)
        self.mlp = Idefics2VisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run the block on `hidden_states`.

        Args:
            hidden_states: input to the layer.
            attention_mask: mask handed unchanged to the self-attention module.

        Returns:
            The hidden states after both residual sub-layers.
        """
        # Sub-layer 1: x + attention(norm1(x))
        attended = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
        )
        hidden_states = hidden_states + attended

        # Sub-layer 2: x + mlp(norm2(x))
        return hidden_states + self.mlp(self.layer_norm2(hidden_states))
|
|
|
|
|
|
class Idefics2Encoder(nn.Module):
    """Stack of `config.num_hidden_layers` `Idefics2EncoderLayer` blocks applied in order."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        blocks = []
        for layer_id in range(config.num_hidden_layers):
            blocks.append(
                Idefics2EncoderLayer(
                    prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights
                )
            )
        self.layers = nn.ModuleList(blocks)

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Feed `inputs_embeds` through every layer.

        Args:
            inputs_embeds: input of the first layer.
            attention_mask: optional mask, passed as-is to each layer.

        Returns:
            The output of the last layer (or `inputs_embeds` itself when there are no layers).
        """
        current = inputs_embeds
        for block in self.layers:
            current = block(current, attention_mask)
        return current
|
|
|
|
|
|
class Idefics2VisionTransformer(nn.Module):
    """Vision tower: patch/position embeddings, the encoder stack and a final layer norm."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embeddings = Idefics2VisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.encoder = Idefics2Encoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        # NOTE(review): `nn.LayerNorm.load` is not a stock PyTorch API; presumably it is attached by
        # the text_generation_server layers package - confirm.
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
    ):
        """
        Encode a batch of images.

        Args:
            pixel_values: 4-D image batch; axes 2 and 3 are divided by the patch size to get the patch grid.
            patch_attention_mask: optional boolean patch grid per image. When omitted, every patch is
                treated as valid.

        Returns:
            The encoder output after `post_layernorm`.
        """
        n_images = pixel_values.size(0)
        if patch_attention_mask is None:
            patch = self.config.patch_size
            grid_shape = (
                n_images,
                pixel_values.size(2) // patch,
                pixel_values.size(3) // patch,
            )
            patch_attention_mask = torch.ones(grid_shape).to(
                dtype=torch.bool, device=pixel_values.device
            )

        hidden_states = self.embeddings(
            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
        )

        flat_mask = patch_attention_mask.view(n_images, -1)
        # A mask made only of ones is equivalent to attending to the full sequence, so in that case
        # no mask is passed at all; otherwise it is expanded to the 4-D form.
        if torch.any(~flat_mask):
            attention_mask = _prepare_4d_attention_mask(flat_mask, hidden_states.dtype)
        else:
            attention_mask = None

        encoded = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
        )
        return self.post_layernorm(encoded)
|
|
|
|
|
|
class Idefics2MLP(nn.Module):
    """Gated MLP: `down_proj(act(gate) * up)` with `gate` and `up` coming from one fused projection."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        act_name = config.text_config.hidden_act
        if "gelu" in act_name:
            # GELU variants go through torch's functional gelu; only these two names select the
            # tanh approximation.
            approximate = (
                "tanh" if act_name in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=approximate)
        else:
            self.act = ACT2FN[act_name]

        # gate_proj and up_proj are loaded as a single column-parallel projection.
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )

    def forward(self, hidden_states):
        """Apply the gated MLP; all leading dimensions of `hidden_states` are kept in the result."""
        leading_shape = hidden_states.shape[:-1]
        fused = self.gate_up_proj(hidden_states)
        half_width = fused.shape[-1] // 2
        # Rows are flattened; axis 1 separates the gate half (index 0) from the up half (index 1).
        fused = fused.view(-1, 2, half_width)
        gate, up = fused[:, 0], fused[:, 1]
        projected = self.down_proj(self.act(gate) * up)
        return projected.view(*leading_shape, -1)
|
|
|
|
|
|
class Idefics2RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis, scaled by a learned per-channel weight."""

    def __init__(self, prefix, weights, eps):
        """
        Idefics2RMSNorm is equivalent to T5LayerNorm

        The scale is read from `{prefix}.weight` and frozen; `eps` is added to the mean square
        before taking the reciprocal square root.
        """
        super().__init__()
        scale = weights.get_tensor(f"{prefix}.weight")
        self.weight = nn.Parameter(scale, requires_grad=False)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise `hidden_states` in fp32, cast back to the input dtype, then apply the weight."""
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalised = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)
|
|
|
|
|
|
class Idefics2PerceiverAttention(nn.Module):
    """
    Attention of the perceiver resampler: queries come from the latents, keys and values from the
    concatenation of the context and the latents.

    After construction `num_heads` and `num_key_value_heads` hold the per-shard values (both are
    divided by `weights.process_group.size()`); `num_key_value_groups` is computed from the
    unsharded counts.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()

        self.layer_idx = None
        self.hidden_size = config.text_config.hidden_size
        self.num_heads = config.perceiver_config.resampler_n_heads
        self.head_size = config.perceiver_config.resampler_head_dim
        self.num_key_value_heads = config.perceiver_config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # Stored for reference only: `forward` does not apply dropout.
        self.attention_dropout = config.perceiver_config.attention_dropout

        world_size = weights.process_group.size()
        self.num_heads = self.num_heads // world_size
        self.num_key_value_heads = self.num_key_value_heads // world_size

        self.q_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        # k and v are loaded as one fused column-parallel projection and split in `forward`.
        self.kv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
        )

        self.is_causal = False

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Let the latents attend to `[context, latents]`.

        Args:
            latents: tensor of shape (batch, n_latents, hidden); source of the queries.
            context: tensor of shape (batch, context_len, hidden); concatenated in front of the
                latents to form the keys and values.
            attention_mask: optional additive mask of size
                (batch, 1, n_latents, context_len + n_latents).

        Returns:
            A single tensor: the output of `o_proj` applied to the attended values.

        Raises:
            ValueError: if the attention weights, the mask or the attention output have an unexpected size.
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        # Keys/values are computed over the context followed by the latents.
        hidden_states = torch.concat([context, latents], dim=-2)
        query_states = self.q_proj(latents)
        kv_width = self.head_size * self.num_key_value_heads
        key_states, value_states = self.kv(hidden_states).split([kv_width, kv_width], dim=2)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_size
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
        ).transpose(1, 2)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_size)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )

        # (batch, heads, q_len, head_size) -> (batch, q_len, heads * head_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)

        return self.o_proj(attn_output)
|
|
|
|
|
|
class Idefics2PerceiverLayer(nn.Module):
    """One resampler block: normalised latents attend to the normalised context, then go through an MLP."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        text_config = config.text_config
        perceiver_config = config.perceiver_config
        self.hidden_size = text_config.hidden_size
        self.n_latents = perceiver_config.resampler_n_latents
        self.depth = perceiver_config.resampler_depth
        self.rms_norm_eps = text_config.rms_norm_eps

        def rms_norm(name):
            # All three norms share the same epsilon and differ only in their weight prefix.
            return Idefics2RMSNorm(
                prefix=f"{prefix}.{name}",
                weights=weights,
                eps=self.rms_norm_eps,
            )

        self.input_latents_norm = rms_norm("input_latents_norm")
        self.input_context_norm = rms_norm("input_context_norm")
        self.self_attn = Idefics2PerceiverAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.post_attention_layernorm = rms_norm("post_attention_layernorm")
        self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): mask forwarded unchanged to `self_attn`,
                which adds it to the attention scores and expects size
                `(batch, 1, n_latents, context_len + n_latents)`.

        Returns:
            The updated latents, same shape as the input latents.
        """
        # Attention sub-layer with residual; latents and context have separate input norms.
        attended = self.self_attn(
            latents=self.input_latents_norm(latents),
            context=self.input_context_norm(context),
            attention_mask=attention_mask,
        )
        latents = latents + attended

        # MLP sub-layer with residual.
        return latents + self.mlp(self.post_attention_layernorm(latents))
|
|
|
|
|
|
class Idefics2PerceiverResampler(nn.Module):
    """
    Compresses a context sequence into a fixed set of latent vectors by running the latents through
    `resampler_depth` perceiver layers followed by a final RMS norm.
    """

    def __init__(self, prefix, config, weights) -> None:
        super().__init__()
        text_config = config.text_config
        perceiver_config = config.perceiver_config
        self.hidden_size = text_config.hidden_size
        self.hidden_act = perceiver_config.hidden_act
        self.n_latents = perceiver_config.resampler_n_latents
        self.depth = perceiver_config.resampler_depth
        self.rms_norm_eps = text_config.rms_norm_eps

        # The latents come from the checkpoint and are kept as a plain tensor attribute.
        self.latents = weights.get_tensor(f"{prefix}.latents")

        # One perceiver layer per depth level.
        perceiver_layers = []
        for idx in range(self.depth):
            perceiver_layers.append(
                Idefics2PerceiverLayer(
                    prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
                )
            )
        self.layers = nn.ModuleList(perceiver_layers)
        self.norm = Idefics2RMSNorm(
            prefix=f"{prefix}.norm",
            weights=weights,
            eps=config.text_config.rms_norm_eps,
        )

    def forward(
        self,
        context: torch.Tensor,
        attention_mask,
    ) -> torch.Tensor:
        """
        Resample `context` into the latent sequence.

        Args:
            context: tensor whose first axis is the batch; it is passed to every layer as the context.
            attention_mask: 2-D mask over the context positions (first axis is the batch).

        Returns:
            The normalised latents after the last layer.
        """
        # seq embed -> bsz seq embed
        batch = context.shape[0]
        latents = self.latents.unsqueeze(0).expand((batch, *self.latents.size()))

        # Latent positions are always attended to: append ones for them to the context mask.
        ones_for_latents = torch.ones(
            (attention_mask.size(0), latents.size(1)),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        full_mask = torch.cat([attention_mask, ones_for_latents], dim=-1)
        full_mask = _prepare_4d_attention_mask(
            full_mask, latents.dtype, tgt_len=self.n_latents
        )

        compressed_context = latents
        for layer in self.layers:
            compressed_context = layer(
                compressed_context,
                context,
                attention_mask=full_mask,
            )
        return self.norm(compressed_context)
|
|
|
|
|
|
class Idefics2Connector(nn.Module):
    """Bridge between the vision tower and the text model: modality projection, then perceiver resampling."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        shared = dict(config=config, weights=weights)
        self.modality_projection = Idefics2MLP(
            prefix=f"{prefix}.modality_projection", **shared
        )
        self.perceiver_resampler = Idefics2PerceiverResampler(
            prefix=f"{prefix}.perceiver_resampler", **shared
        )

    def forward(self, image_hidden_states, attention_mask):
        """Project `image_hidden_states`, then resample them using `attention_mask` as the context mask."""
        projected = self.modality_projection(image_hidden_states)
        return self.perceiver_resampler(context=projected, attention_mask=attention_mask)
|
|
|
|
|
|
class Idefics2ForConditionalGeneration(nn.Module):
    """
    Idefics2 vision-language model: a text model, the Idefics2 vision tower and the connector that
    turns image features into embeddings placed at the image-token positions of the prompt.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        # The sub-configs inherit the quantization / speculator settings of the top-level config.
        config.vision_config.quantize = config.quantize
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator

        vision_config = config.vision_config
        self.text_model = load_text_model(
            prefix="model" if not prefix else f"{prefix}.model",
            config=config.text_config,
            weights=weights,
            name="text_model",
        )
        self.dtype = weights.dtype
        self.vision_model = Idefics2VisionTransformer(
            prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
            config=vision_config,
            weights=weights,
        )
        self.connector = Idefics2Connector(
            prefix=f"{prefix}.model.connector" if prefix else "model.connector",
            config=config,
            weights=weights,
        )
        self.config = config
        self.image_seq_len = config.perceiver_config.resampler_n_latents
        self.image_token_id = config.image_token_id
        # -1 stands in for a missing pad token id.
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def _merge_input_ids_with_image_features(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
    ):
        """In place merges in vision_embeddings with inputs_embeds.

        Every position where `input_ids` equals `config.image_token_id` is overwritten with one row
        of `image_features` (flattened to 2-D, keeping the last axis). The number of image-token
        positions must therefore equal the number of image feature rows.
        """
        mask = input_ids == self.config.image_token_id
        # Let's pray we have enabled enough slots !
        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: torch.FloatTensor = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        # Unused here
        image_sizes: Optional[torch.Tensor] = None,
    ):
        """
        Embed the tokens, splice in image features when `pixel_values` is given, and run the text model.

        Args:
            input_ids: token ids; positions equal to the image token id receive image features.
            position_ids, cu_seqlen_prefill, kv_cache, block_tables, slots, input_lengths, max_s:
                forwarded unchanged to the text model (`max_s` is also passed as `true_max_s`).
            prefill_cache_indices: accepted for interface compatibility; the text model is always
                called with `prefill_cache_indices=None`.
            lm_head_indices: optional indices selecting which hidden states are fed to the LM head.
            pixel_values: optional 5-D tensor (batch, num_images, channels, height, width). Images
                that are entirely zero are treated as padding and dropped.
            pixel_attention_mask: optional mask indexed like `pixel_values` without the channel axis;
                when omitted, every pixel of every real image is considered valid.
            image_sizes: unused.

        Returns:
            `(logits, speculative_logits)` as produced by the text model's LM head.
        """
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None:
            batch_size, num_images, _, _, _ = pixel_values.shape
            # Cast once for the whole batch (fp16 compatibility) rather than once per sample.
            all_pixel_values = pixel_values.to(dtype=self.dtype)
            all_pixel_mask = pixel_attention_mask
            patch_size = self.config.vision_config.patch_size
            all_states = []
            for i in range(batch_size):
                # (num_images, channels, height, width) for this sample
                sample_pixels = all_pixel_values[i]

                # Remove padding images - padding images are full 0.
                nb_values_per_image = sample_pixels.shape[1:].numel()
                real_images_inds = (sample_pixels == 0.0).sum(
                    dim=(-1, -2, -3)
                ) != nb_values_per_image
                sample_pixels = sample_pixels[real_images_inds].contiguous()

                # Handle the vision attention mask.
                # The decision is taken on the caller-provided mask (`all_pixel_mask`), not on a
                # per-iteration variable: otherwise the mask synthesised for the first sample makes
                # later iterations index into a `None` mask.
                if all_pixel_mask is None:
                    sample_mask = torch.ones(
                        size=(
                            sample_pixels.size(0),
                            sample_pixels.size(2),
                            sample_pixels.size(3),
                        ),
                        dtype=torch.bool,
                        device=sample_pixels.device,
                    )
                else:
                    # Remove padding images from the mask
                    sample_mask = all_pixel_mask[i][real_images_inds].contiguous()

                # A patch is attended to as soon as one of its pixels is valid.
                patches_subgrid = sample_mask.unfold(
                    dimension=1, size=patch_size, step=patch_size
                )
                patches_subgrid = patches_subgrid.unfold(
                    dimension=2, size=patch_size, step=patch_size
                )
                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

                # Get sequence from the vision encoder
                image_hidden_states = self.vision_model(
                    pixel_values=sample_pixels,
                    patch_attention_mask=patch_attention_mask,
                )

                # Modality projection & resampling
                image_hidden_states = self.connector(
                    image_hidden_states,
                    attention_mask=patch_attention_mask.view(sample_pixels.size(0), -1),
                )
                all_states.append(image_hidden_states)
            # The merge below flattens everything but the last axis, so concatenating along the
            # image axis gives the same rows as stacking while also accepting samples that carry a
            # different number of real images.
            image_hidden_states = torch.cat(all_states, dim=0)
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, image_hidden_states
            )

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            true_max_s=max_s,
            prefill_cache_indices=None,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)
        return logits, speculative_logits
|