text-generation-inference/server/text_generation_server/models/custom_modeling/idefics2.py
Nicolas Patry e3d765645a
MLPSpeculator. (#1865)
# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Joshua Rosenkranz <joshua.rosenkranz@gmail.com>
2024-05-14 12:33:18 +02:00

830 lines
31 KiB
Python

# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Idefics2 model."""
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
import math
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along the head axis.

    Produces the same result as torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep): an input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep` is 1 the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a singleton axis right after the head axis and broadcast it to n_rep copies.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    # Folding the new axis into the head axis places the copies of each head next to each other.
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
class Idefics2VisionEmbeddings(nn.Module):
    """
    Patch and position embeddings for images of variable resolution.

    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings`. The modifications are adapted
    from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304),
    which allows treating images in their native aspect ratio without resizing them to one fixed size.
    Each image's patch grid is mapped onto the `num_patches_per_side` x `num_patches_per_side` position table
    by bucketizing fractional patch coordinates.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patch projection (kernel == stride == patch size) with frozen checkpoint weights.
        conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        conv.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )
        conv.bias = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
        )
        self.patch_embedding = conv

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )

    def forward(
        self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
    ) -> torch.Tensor:
        """
        Embed `pixel_values` (4-D, batch first, height and width last) and add position embeddings.

        `patch_attention_mask` is iterated per image; each entry is a 2-D boolean grid whose first column and
        first row are summed to obtain the number of valid patch rows and columns of that image.
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape
        embeddings = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        grid_h = max_im_h // self.patch_size
        grid_w = max_im_w // self.patch_size

        # Bucket edges that split [0, 1) into `num_patches_per_side` equal cells.
        cell = 1 / self.num_patches_per_side
        boundaries = torch.arange(cell, 1.0, cell)

        # Masked-out patches keep position id 0.
        position_ids = torch.full(size=(batch_size, grid_h * grid_w), fill_value=0)

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            valid_rows = p_attn_mask[:, 0].sum()
            valid_cols = p_attn_mask[0].sum()

            # Fractional coordinate of each valid row/column inside its own image.
            frac_h = torch.arange(0, 1 - 1e-6, 1 / valid_rows)
            frac_w = torch.arange(0, 1 - 1e-6, 1 / valid_cols)

            bucket_h = torch.bucketize(frac_h, boundaries, right=True)
            bucket_w = torch.bucketize(frac_w, boundaries, right=True)

            # Row-major index into the square position table.
            pos_ids = (
                bucket_h[:, None] * self.num_patches_per_side + bucket_w
            ).flatten()
            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)
        return embeddings + self.position_embedding(position_ids)
class Idefics2VisionAttention(nn.Module):
    """
    Multi-head self-attention for the vision encoder.

    Query, key and value projections are loaded as one fused column-parallel layer; head count and embedding
    width are divided by the size of `weights.process_group`.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size, leftover = divmod(self.embed_dim, self.num_heads)
        if leftover != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_size**-0.5
        self.dropout = config.attention_dropout

        # From here on, sizes describe the slice handled by this process.
        world_size = weights.process_group.size()
        self.num_heads = self.num_heads // world_size
        self.embed_dim = self.embed_dim // world_size

        self.qkv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.out_proj = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
        )
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Attend over `hidden_states` (batch, seq_len, hidden).

        `attention_mask`, when given, must have shape (batch, 1, seq_len, seq_len) and is added to the attention
        scores before the softmax. Raises ValueError when an intermediate tensor or the mask has an unexpected shape.
        """
        batch_size, q_len, _ = hidden_states.size()
        proj_width = self.head_size * self.num_heads

        def to_heads(tensor):
            # (batch, seq, heads * head_size) -> (batch, heads, seq, head_size)
            return tensor.view(
                batch_size, q_len, self.num_heads, self.head_size
            ).transpose(1, 2)

        query_states, key_states, value_states = (
            to_heads(chunk)
            for chunk in self.qkv(hidden_states).split([proj_width] * 3, dim=2)
        )

        k_v_seq_len = key_states.shape[-2]
        scores = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

        expected_scores = (batch_size, self.num_heads, q_len, k_v_seq_len)
        if scores.size() != expected_scores:
            raise ValueError(
                f"Attention weights should be of size {expected_scores}, but is"
                f" {scores.size()}"
            )

        if attention_mask is not None:
            expected_mask = (batch_size, 1, q_len, k_v_seq_len)
            if attention_mask.size() != expected_mask:
                raise ValueError(
                    f"Attention mask should be of size {expected_mask}, but is {attention_mask.size()}"
                )
            scores = scores + attention_mask

        # Softmax is computed in fp32 and cast back to the query dtype.
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
        probs = probs.to(query_states.dtype)
        probs = nn.functional.dropout(probs, p=self.dropout, training=self.training)

        attn_output = torch.matmul(probs, value_states)
        expected_output = (batch_size, self.num_heads, q_len, self.head_size)
        if attn_output.size() != expected_output:
            raise ValueError(
                f"`attn_output` should be of size {expected_output}, but is"
                f" {attn_output.size()}"
            )

        # Merge the heads back into one feature axis before the output projection.
        merged = attn_output.transpose(1, 2).contiguous()
        merged = merged.reshape(batch_size, q_len, self.embed_dim)
        return self.out_proj(merged)
class Idefics2VisionMLP(nn.Module):
    """Two-layer feed-forward block of the vision encoder: fc1 -> activation -> fc2."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        # Activation is looked up by the name stored in `config.hidden_act`.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation and fc2 in sequence."""
        projected = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        return self.fc2(activated)
class Idefics2EncoderLayer(nn.Module):
    """Vision encoder block: layer norm + self-attention, then layer norm + MLP, each with a residual add."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps
        self.self_attn = Idefics2VisionAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", eps=norm_eps, weights=weights
        )
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", eps=norm_eps, weights=weights
        )
        self.mlp = Idefics2VisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run the block on `hidden_states`; `attention_mask` is handed to the attention module unchanged."""
        # Attention sub-block: normalize, attend, add back the input.
        attn_out = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
        )
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block: normalize, project, add back.
        mlp_out = self.mlp(self.layer_norm2(after_attn))
        return after_attn + mlp_out
class Idefics2Encoder(nn.Module):
    """Stack of `config.num_hidden_layers` vision encoder layers applied one after another."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        blocks = []
        for layer_id in range(config.num_hidden_layers):
            blocks.append(
                Idefics2EncoderLayer(
                    prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights
                )
            )
        self.layers = nn.ModuleList(blocks)

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Feed `inputs_embeds` through every layer in order, passing the same `attention_mask` to each."""
        states = inputs_embeds
        for block in self.layers:
            states = block(states, attention_mask)
        return states
class Idefics2VisionTransformer(nn.Module):
    """Vision tower: patch/position embeddings, the encoder stack, and a final layer norm."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embeddings = Idefics2VisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.encoder = Idefics2Encoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
    ):
        """
        Encode `pixel_values` (4-D, batch first, height and width last) into per-patch hidden states.

        Args:
            pixel_values: image batch; dims 2 and 3 are divided by `config.patch_size` to size the default mask.
            patch_attention_mask: optional boolean mask with one entry per patch. When omitted, every patch
                is treated as valid.

        Returns:
            The encoder output after `post_layernorm`.
        """
        batch_size = pixel_values.size(0)
        if patch_attention_mask is None:
            patch_size = self.config.patch_size
            # Allocate the all-valid mask directly as bool on the input's device instead of building
            # a float tensor on the default device and converting/copying it afterwards.
            patch_attention_mask = torch.ones(
                (
                    batch_size,
                    pixel_values.size(2) // patch_size,
                    pixel_values.size(3) // patch_size,
                ),
                dtype=torch.bool,
                device=pixel_values.device,
            )

        hidden_states = self.embeddings(
            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
        )

        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
        # When no patch is masked out, attending with no mask is equivalent to attending to the
        # full sequence, so the additive 4D mask is not built at all.
        if not torch.any(~patch_attention_mask):
            patch_attention_mask = None
        else:
            patch_attention_mask = _prepare_4d_attention_mask(
                patch_attention_mask, hidden_states.dtype
            )

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=patch_attention_mask,
        )
        return self.post_layernorm(encoder_outputs)
class Idefics2MLP(nn.Module):
    """
    Gated feed-forward block: down_proj(act(gate) * up).

    The gate and up projections are loaded as one fused column-parallel layer.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.text_config.hidden_act
        if "gelu" in act:
            # GELU variants go through torch's functional gelu; the "fast"/"pytorch_tanh" names
            # select the tanh approximation.
            approximate = (
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=approximate)
        else:
            self.act = ACT2FN[act]

        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )

    def forward(self, hidden_states):
        """Apply the gated MLP on the last axis; all leading axes of the input are preserved."""
        leading_shape = hidden_states.shape[:-1]
        fused = self.gate_up_proj(hidden_states)
        half = fused.shape[-1] // 2
        # Row layout after the view: index 0 holds the gate half, index 1 the up half.
        fused = fused.view(-1, 2, half)
        gate = fused[:, 0]
        up = fused[:, 1]
        projected = self.down_proj(self.act(gate) * up)
        return projected.view(*leading_shape, -1)
class Idefics2RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned per-feature scale."""

    def __init__(self, prefix, weights, eps):
        """
        Idefics2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        scale = weights.get_tensor(f"{prefix}.weight")
        self.weight = nn.Parameter(scale, requires_grad=False)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize in fp32, cast back to the input dtype, then multiply by the weight."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
class Idefics2PerceiverAttention(nn.Module):
    """
    Attention used by the perceiver resampler.

    Queries are computed from the latents only; keys and values are computed from the context
    concatenated with the latents. Key/value heads are repeated to match the number of query heads.
    Head counts are divided by the size of `weights.process_group`.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.layer_idx = None
        self.hidden_size = config.text_config.hidden_size
        self.num_heads = config.perceiver_config.resampler_n_heads
        self.head_size = config.perceiver_config.resampler_head_dim
        self.num_key_value_heads = config.perceiver_config.num_key_value_heads
        # Computed from the unsharded head counts, before both are divided below.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.attention_dropout = config.perceiver_config.attention_dropout

        world_size = weights.process_group.size()
        self.num_heads = self.num_heads // world_size
        self.num_key_value_heads = self.num_key_value_heads // world_size

        self.q_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        # Key and value projections are loaded as one fused layer.
        self.kv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
        )
        self.is_causal = False

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Let the latents attend over `[context, latents]`.

        Args:
            latents: 3-D tensor (batch, n_latents, hidden) used for the queries.
            context: 3-D tensor (batch, context_len, hidden); concatenated before the latents along
                the sequence axis to form the key/value input.
            attention_mask: optional additive mask of shape (batch, 1, n_latents, context_len + n_latents).

        Returns:
            A single tensor: the output of `o_proj` for the latent positions.

        Raises:
            ValueError: if the attention scores, the mask, or the attention output has an unexpected shape.
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        hidden_states = torch.concat([context, latents], dim=-2)
        query_states = self.q_proj(latents)
        kv = self.kv(hidden_states)
        kv_width = self.head_size * self.num_key_value_heads
        key_states, value_states = kv.split([kv_width, kv_width], dim=2)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_size
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
        ).transpose(1, 2)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_size)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )

        # Merge the heads back into one feature axis before the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)
        attn_output = self.o_proj(attn_output)
        return attn_output
class Idefics2PerceiverLayer(nn.Module):
    """Perceiver block: latents attend over the normalized context and latents, followed by a gated MLP."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        text_cfg = config.text_config
        perceiver_cfg = config.perceiver_config
        self.hidden_size = text_cfg.hidden_size
        self.n_latents = perceiver_cfg.resampler_n_latents
        self.depth = perceiver_cfg.resampler_depth
        self.rms_norm_eps = text_cfg.rms_norm_eps

        self.input_latents_norm = Idefics2RMSNorm(
            prefix=f"{prefix}.input_latents_norm",
            weights=weights,
            eps=self.rms_norm_eps,
        )
        self.input_context_norm = Idefics2RMSNorm(
            prefix=f"{prefix}.input_context_norm",
            weights=weights,
            eps=self.rms_norm_eps,
        )
        self.self_attn = Idefics2PerceiverAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.post_attention_layernorm = Idefics2RMSNorm(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=self.rms_norm_eps,
        )
        self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            latents (`torch.FloatTensor`): latent queries; the residual stream of this layer.
            context (`torch.FloatTensor`): context the latents attend over; only its normalized form is used.
            attention_mask (`torch.FloatTensor`, *optional*): forwarded unchanged to the attention module.

        Returns:
            The updated latents.
        """
        # Attention sub-block with residual around the (un-normalized) latents.
        attended = self.self_attn(
            latents=self.input_latents_norm(latents),
            context=self.input_context_norm(context),
            attention_mask=attention_mask,
        )
        after_attn = latents + attended

        # Feed-forward sub-block with its own residual.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
class Idefics2PerceiverResampler(nn.Module):
    """
    Compress a context sequence into `resampler_n_latents` vectors.

    Latents loaded from the checkpoint are refined by `resampler_depth` perceiver layers, each attending
    over the context, and then RMS-normalized.
    """

    def __init__(self, prefix, config, weights) -> None:
        super().__init__()
        text_cfg = config.text_config
        perceiver_cfg = config.perceiver_config
        self.hidden_size = text_cfg.hidden_size
        self.hidden_act = perceiver_cfg.hidden_act
        self.n_latents = perceiver_cfg.resampler_n_latents
        self.depth = perceiver_cfg.resampler_depth
        self.rms_norm_eps = text_cfg.rms_norm_eps

        # Latent queries, kept as a plain tensor attribute.
        self.latents = weights.get_tensor(f"{prefix}.latents")

        # Transformer blocks.
        blocks = []
        for idx in range(self.depth):
            blocks.append(
                Idefics2PerceiverLayer(
                    prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
                )
            )
        self.layers = nn.ModuleList(blocks)

        self.norm = Idefics2RMSNorm(
            prefix=f"{prefix}.norm",
            weights=weights,
            eps=config.text_config.rms_norm_eps,
        )

    def forward(
        self,
        context: torch.Tensor,
        attention_mask,
    ) -> torch.Tensor:
        """
        Resample `context` into the latents.

        `attention_mask` is a 2-D mask over the context positions (batch first); it is extended with ones
        for the latent positions and expanded to a 4D mask before being passed to every layer.
        """
        # seq embed -> bsz seq embed
        batch = context.shape[0]
        latents = self.latents.unsqueeze(0).expand((batch, *self.latents.size()))

        # Latent positions are always attendable.
        ones_for_latents = torch.ones(
            (attention_mask.size(0), latents.size(1)),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        full_mask = torch.cat([attention_mask, ones_for_latents], dim=-1)
        full_mask = _prepare_4d_attention_mask(
            full_mask, latents.dtype, tgt_len=self.n_latents
        )

        compressed_context = latents
        for layer in self.layers:
            compressed_context = layer(
                compressed_context,
                context,
                attention_mask=full_mask,
            )
        return self.norm(compressed_context)
class Idefics2Connector(nn.Module):
    """Bridge from the vision tower to the text model: modality projection, then perceiver resampling."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.modality_projection = Idefics2MLP(
            prefix=f"{prefix}.modality_projection", config=config, weights=weights
        )
        self.perceiver_resampler = Idefics2PerceiverResampler(
            prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
        )

    def forward(self, image_hidden_states, attention_mask):
        """Project `image_hidden_states`, then resample them using `attention_mask` over the projected sequence."""
        projected = self.modality_projection(image_hidden_states)
        return self.perceiver_resampler(
            context=projected, attention_mask=attention_mask
        )
class Idefics2ForConditionalGeneration(nn.Module):
    """
    Idefics2 for generation: a text model whose image-token embeddings are replaced by features
    produced by the vision tower and the connector.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        # Sub-configs inherit the quantization and speculator settings of the top-level config.
        config.vision_config.quantize = config.quantize
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator

        vision_config = config.vision_config
        self.text_model = load_text_model(
            prefix="model" if not prefix else f"{prefix}.model",
            config=config.text_config,
            weights=weights,
            name="text_model",
        )
        self.dtype = weights.dtype
        self.vision_model = Idefics2VisionTransformer(
            prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
            config=vision_config,
            weights=weights,
        )
        self.connector = Idefics2Connector(
            prefix=f"{prefix}.model.connector" if prefix else "model.connector",
            config=config,
            weights=weights,
        )
        self.config = config
        self.image_seq_len = config.perceiver_config.resampler_n_latents
        self.image_token_id = config.image_token_id
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def _merge_input_ids_with_image_features(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
    ):
        """
        In place merges in vision_embeddings with inputs_embeds.

        Every position where `input_ids` equals `config.image_token_id` receives one row of
        `image_features` (flattened to 2-D). The number of such positions must match the number of rows.
        """
        mask = input_ids == self.config.image_token_id
        # Let's pray we have enabled enough slots !
        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        return inputs_embeds

    def _encode_images(
        self,
        pixel_values: torch.Tensor,
        pixel_attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Run the vision tower and connector on every batch item and concatenate the results.

        `pixel_values` is 5-D: (batch, num_images, channels, height, width). All-zero images are treated
        as padding and dropped. The returned tensor stacks the features of all remaining images, in batch
        order, along dim 0.
        """
        batch_size, num_images, _, _, _ = pixel_values.shape
        # The dtype cast does not depend on the batch index, so it is done once for the whole tensor.
        all_pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
        patch_size = self.config.vision_config.patch_size

        all_states = []
        for i in range(batch_size):
            item_pixels = all_pixel_values[i : i + 1]
            item_pixels = item_pixels.view(num_images, *item_pixels.shape[2:])

            # Remove padding images - padding images are full 0.
            nb_values_per_image = item_pixels.shape[1:].numel()
            real_images_inds = (item_pixels == 0.0).sum(
                dim=(-1, -2, -3)
            ) != nb_values_per_image
            item_pixels = item_pixels[real_images_inds].contiguous()

            # Handle the vision attention mask.
            # The check is on the caller-provided argument, which is never rebound inside the loop,
            # so a missing mask is rebuilt for every batch item rather than only for the first one.
            if pixel_attention_mask is None:
                item_mask = torch.ones(
                    size=(
                        item_pixels.size(0),
                        item_pixels.size(2),
                        item_pixels.size(3),
                    ),
                    dtype=torch.bool,
                    device=item_pixels.device,
                )
            else:
                # Remove padding images from the mask
                item_mask = pixel_attention_mask[i : i + 1]
                item_mask = item_mask.view(1 * num_images, *item_mask.shape[2:])
                item_mask = item_mask[real_images_inds].contiguous()

            # A patch is kept when at least one of its pixels is unmasked.
            patches_subgrid = item_mask.unfold(
                dimension=1, size=patch_size, step=patch_size
            )
            patches_subgrid = patches_subgrid.unfold(
                dimension=2, size=patch_size, step=patch_size
            )
            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=item_pixels,
                patch_attention_mask=patch_attention_mask,
            )

            # Modality projection & resampling
            image_hidden_states = self.connector(
                image_hidden_states,
                attention_mask=patch_attention_mask.view(item_pixels.size(0), -1),
            )
            all_states.append(image_hidden_states)

        # Concatenating along the image axis yields the same flattened rows as stacking when every
        # item has the same number of real images, and also works when the counts differ.
        return torch.cat(all_states, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        # Unused here
        image_sizes: Optional[torch.Tensor] = None,
    ):
        """
        Embed the tokens, splice in image features when `pixel_values` is given, and run the text model.

        Returns:
            The pair `(logits, speculative_logits)` produced by the text model's `lm_head`.
        """
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None:
            image_hidden_states = self._encode_images(
                pixel_values, pixel_attention_mask
            )
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, image_hidden_states
            )

        # NOTE(review): `prefill_cache_indices` is accepted but None is forwarded to the text model,
        # as in the original code - confirm this is intentional.
        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            true_max_s=max_s,
            prefill_cache_indices=None,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)
        return logits, speculative_logits