mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
1557 lines
60 KiB
Python
1557 lines
60 KiB
Python
# coding=utf-8
|
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
# and OPT implementations in this library. It has been modified from its
|
|
# original forms to accommodate minor architectural differences compared
|
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" PyTorch Idefics model."""
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
|
|
from transformers import PreTrainedModel
|
|
from transformers.activations import ACT2FN
|
|
from transformers.modeling_outputs import (
|
|
BaseModelOutputWithPast,
|
|
CausalLMOutputWithPast,
|
|
dataclass,
|
|
)
|
|
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
|
|
from text_generation_server.models.custom_modeling.idefics_vision import (
|
|
IdeficsVisionTransformer,
|
|
)
|
|
from text_generation_server.models.custom_modeling.idefics_perceiver import (
|
|
IdeficsPerceiverResampler,
|
|
)
|
|
from text_generation_server.layers import (
|
|
TensorParallelColumnLinear,
|
|
TensorParallelEmbedding,
|
|
TensorParallelRowLinear,
|
|
SpeculativeHead,
|
|
FastLinear,
|
|
)
|
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
from loguru import logger
|
|
|
|
# Backend-specific fused kernels. Exactly one of them is imported, depending
# on the detected SYSTEM; the other name is bound to None so that both names
# always exist at module level regardless of the backend.
if SYSTEM == "cuda":
    import dropout_layer_norm

    ops = None
elif SYSTEM == "rocm":
    from vllm._C import ops

    dropout_layer_norm = None
else:
    dropout_layer_norm = None
    ops = None
|
|
|
|
|
|
@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
    """`BaseModelOutputWithPast` with an additional `image_hidden_states` field."""

    # Optional extra output; left as None when not provided.
    image_hidden_states: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
@dataclass
class CausalLMOutputWithPastImage(CausalLMOutputWithPast):
    """`CausalLMOutputWithPast` with an additional `image_hidden_states` field."""

    # Optional extra output; left as None when not provided.
    image_hidden_states: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
# logger = logging.get_logger(__name__)
|
|
|
|
# _CONFIG_FOR_DOC = "IdeficsConfig"
|
|
|
|
# IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|
# "HuggingFaceM4/idefics-9b",
|
|
# "HuggingFaceM4/idefics-80b",
|
|
# # See all Idefics models at https://huggingface.co/models?filter=idefics
|
|
# ]
|
|
|
|
|
|
def expand_inputs_for_generation(
    input_ids,
    expand_size=1,
    is_encoder_decoder=False,
    attention_mask=None,
    encoder_outputs=None,
    **model_kwargs,
):
    """Repeat every batch row `expand_size` times along dim 0.

    The same row selection is applied to `input_ids` and to the batch-aligned
    entries that are present: `token_type_ids`, `attention_mask`, and (only
    when an attention mask is given) `image_attention_mask` / `pixel_values`.

    Args:
        input_ids: tensor whose first dimension is the batch.
        expand_size: number of consecutive copies of each row.
        is_encoder_decoder: when True, `encoder_outputs.last_hidden_state` is
            expanded too and stored back into `model_kwargs`.
        attention_mask: optional mask expanded like `input_ids`.
        encoder_outputs: required when `is_encoder_decoder` is True.
        **model_kwargs: extra generation inputs; returned updated.

    Returns:
        Tuple `(input_ids, model_kwargs)` with the expanded tensors.

    Raises:
        ValueError: if `is_encoder_decoder` is True and `encoder_outputs` is None.
    """
    # [0, 0, ..., 1, 1, ...]: each row index repeated `expand_size` times.
    expanded_return_idx = (
        torch.arange(input_ids.shape[0])
        .view(-1, 1)
        .repeat(1, expand_size)
        .view(-1)
        .to(input_ids.device)
    )
    input_ids = input_ids.index_select(0, expanded_return_idx)

    if "token_type_ids" in model_kwargs:
        model_kwargs["token_type_ids"] = model_kwargs["token_type_ids"].index_select(
            0, expanded_return_idx
        )

    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.index_select(
            0, expanded_return_idx
        )
        # The image inputs are optional (text-only calls omit them or pass
        # None), so only expand the ones that actually hold a tensor instead
        # of failing with a KeyError / AttributeError.
        for image_key in ("image_attention_mask", "pixel_values"):
            image_value = model_kwargs.get(image_key)
            if image_value is not None:
                model_kwargs[image_key] = image_value.index_select(
                    0, expanded_return_idx
                )

    if is_encoder_decoder:
        if encoder_outputs is None:
            raise ValueError(
                "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
            )
        last_hidden_state = encoder_outputs.last_hidden_state
        encoder_outputs["last_hidden_state"] = last_hidden_state.index_select(
            0, expanded_return_idx.to(last_hidden_state.device)
        )
        model_kwargs["encoder_outputs"] = encoder_outputs
    return input_ids, model_kwargs
|
|
|
|
|
|
def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
    """Update `model_kwargs` in place after one generation step.

    Args:
        outputs: mapping-like model output; the first of the keys
            `past_key_values`, `mems`, `past_buckets_states` that it contains
            is stored under `model_kwargs["past"]` (None when none is present).
        model_kwargs: generation kwargs; mutated and returned.
        is_encoder_decoder: when True the attention masks are left untouched.

    Returns:
        The updated `model_kwargs`.
    """
    # must have this key set to at least None
    model_kwargs.setdefault("past_key_values", None)

    # update past: first matching cache key wins, in this priority order
    model_kwargs["past"] = None
    for cache_key in ("past_key_values", "mems", "past_buckets_states"):
        if cache_key in outputs:
            model_kwargs["past"] = outputs[cache_key]
            break

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat(
            [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
        )

    # update attention masks
    if not is_encoder_decoder:
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            # append one column of ones for the token produced by this step
            new_column = attention_mask.new_ones((attention_mask.shape[0], 1))
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, new_column], dim=-1
            )
        if "image_attention_mask" in model_kwargs:
            # keep only the entry at the last index of dim 1, with dim 1 kept
            image_attention_mask = model_kwargs["image_attention_mask"]
            model_kwargs["image_attention_mask"] = image_attention_mask[
                :, -1, :
            ].unsqueeze(1)

    return model_kwargs
|
|
|
|
|
|
def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
    """Assemble the keyword arguments for one forward pass during generation.

    Args:
        input_ids: token ids; reduced to the last column when `past` is truthy.
        past: cached state from earlier steps, returned as `past_key_values`.
        **kwargs: optional `token_type_ids`, `attention_mask`, `position_ids`,
            `use_cache`, `pixel_values` and `image_attention_mask`.

    Returns:
        Dict with the keys `input_ids`, `past_key_values`, `use_cache`,
        `position_ids`, `attention_mask`, `token_type_ids`, `pixel_values`
        and `image_attention_mask`; missing optional inputs are None.
    """
    token_type_ids = kwargs.get("token_type_ids")
    # only last token for inputs_ids if past is defined in kwargs
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask")
    position_ids = kwargs.get("position_ids")

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation:
        # running count of attended tokens minus one, with masked slots set to 1
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past:
            position_ids = position_ids[:, -1].unsqueeze(-1)

    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        # image inputs are passed through unchanged and may be None
        "pixel_values": kwargs.get("pixel_values"),
        "image_attention_mask": kwargs.get("image_attention_mask"),
    }
|
|
|
|
|
|
def freeze_model(model, module_exceptions=()):
    """Disable gradients for `model`, except for selected module types.

    Args:
        model: module whose sub-modules are visited via `model.modules()`.
        module_exceptions: iterable of type names that stay trainable; the
            supported names are "LayerNorm", "Linear" and "Embedding".

    Returns:
        The same `model`, modified in place.

    Raises:
        KeyError: if `module_exceptions` contains an unsupported name.
    """
    mapping = {
        "LayerNorm": nn.LayerNorm,
        "Linear": nn.Linear,
        "Embedding": nn.Embedding,
    }
    # A tuple works directly with isinstance(); an empty tuple matches nothing,
    # so with no exceptions every module is frozen.
    trainable_types = tuple(mapping[name] for name in module_exceptions)
    # modules() yields parents before children, so an excepted child is
    # re-enabled after its parent has frozen it.
    for module in model.modules():
        # Explicitly setting it to True for exceptions to avoid any mistakes
        module.requires_grad_(isinstance(module, trainable_types))
    return model
|
|
|
|
|
|
class IdeficsDecoupledPartialTPEmbedding(nn.Module):
    """Embedding lookup split between a base table and an additional table.

    Ids below `config.vocab_size` are looked up in the tensor-parallel
    `model.embed_tokens` table; ids at or above it are looked up in
    `model.embed_tokens.additional_embedding.weight` after subtracting
    `config.vocab_size`.
    """

    def __init__(
        self,
        config,
        weights,
    ):
        super().__init__()
        self.num_embeddings = config.vocab_size
        self.weight = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.additional_weight = nn.Parameter(
            weights.get_tensor("model.embed_tokens.additional_embedding.weight")
        )

    def forward(self, input_ids):
        """Return the embeddings for `input_ids` without modifying the input."""
        is_additional = input_ids >= self.num_embeddings
        additional_embeddings = torch.nn.functional.embedding(
            input_ids[is_additional] - self.num_embeddings, self.additional_weight
        )

        # for successful lookup replace the additional ids with 0, the results
        # of these will be discarded anyway; masked_fill is out-of-place so the
        # caller's tensor is left untouched
        base_ids = input_ids.masked_fill(is_additional, 0)
        full_vector = self.weight(base_ids)

        # overwrite the records with high indices
        full_vector[is_additional] = additional_embeddings

        return full_vector
|
|
|
|
|
|
class IdeficsDecoupledTensorParallelLinear(nn.Module):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0,
    then it will create `out_additional_features * in_features` additional parameters that are always trained. If
    `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.

    In this implementation the regular head is loaded from `lm_head` and the
    additional head from `lm_head.additional_fc`; their outputs are
    concatenated along the last dimension.
    """

    def __init__(
        self,
        config,
        weights,
    ) -> None:
        super().__init__()
        self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
        self.additional_fc = FastLinear.load(
            config=config,
            prefix="lm_head.additional_fc",
            weights=weights,
            bias=False,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return `(logits, speculative_logits)`.

        `logits` is the output of `fc` followed by the output of
        `additional_fc` on the last dimension; `speculative_logits` is
        returned exactly as produced by `fc`.
        """
        output, speculative_logits = self.fc(input)
        additional_features = self.additional_fc(input)
        output = torch.cat((output, additional_features), -1)

        return output, speculative_logits

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        # None of these attributes is set by __init__ above, so reading them
        # directly made `repr(module)` raise AttributeError. Fall back to None
        # for whatever is not defined on this instance.
        return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
            getattr(self, "in_features", None),
            getattr(self, "out_features", None),
            getattr(self, "out_additional_features", None),
            getattr(self, "bias", None) is not None,
            getattr(self, "partially_freeze", None),
        )
|
|
|
|
|
|
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
|
def _make_causal_mask(
|
|
input_ids_shape: torch.Size,
|
|
dtype: torch.dtype,
|
|
device: torch.device,
|
|
past_key_values_length: int = 0,
|
|
):
|
|
"""
|
|
Make causal mask used for bi-directional self-attention.
|
|
"""
|
|
bsz, tgt_len = input_ids_shape
|
|
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
|
mask_cond = torch.arange(mask.size(-1), device=device)
|
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
|
mask = mask.to(dtype)
|
|
|
|
if past_key_values_length > 0:
|
|
mask = torch.cat(
|
|
[
|
|
torch.zeros(
|
|
tgt_len, past_key_values_length, dtype=dtype, device=device
|
|
),
|
|
mask,
|
|
],
|
|
dim=-1,
|
|
)
|
|
return mask[None, None, :, :].expand(
|
|
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
|
)
|
|
|
|
|
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
|
"""
|
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
|
"""
|
|
bsz, src_len = mask.size()
|
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
|
|
|
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
|
|
|
inverted_mask = 1.0 - expanded_mask
|
|
|
|
return inverted_mask.masked_fill(
|
|
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
|
)
|
|
|
|
|
|
class IdeficsRMSNorm(nn.Module):
    """RMS normalisation that dispatches to a backend-specific kernel.

    The path is chosen in this order: the IPEX kernel when SYSTEM is "ipex",
    a plain PyTorch implementation when the last dimension exceeds 8192, the
    `dropout_layer_norm` kernel on "cuda", the vLLM `rms_norm` kernel on
    "rocm"; any other system raises ValueError.
    """

    def __init__(self, prefix, weights, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm

        Loads the scale from `{prefix}.weight`; `eps` is the value added to
        the variance before taking the reciprocal square root.
        """
        super().__init__()

        weight = weights.get_tensor(f"{prefix}.weight")
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states, residual=None):
        """Normalise `hidden_states`, optionally adding `residual` first.

        NOTE: on the PyTorch and ROCm paths the residual is added in place,
        so the tensor passed as `hidden_states` is modified.

        Raises:
            ValueError: if no path applies to the current system.
        """
        if SYSTEM == "ipex":
            import intel_extension_for_pytorch as ipex

            return ipex.llm.functional.add_rms_norm(
                residual,
                hidden_states,
                self.weight,
                None,
                self.variance_epsilon,
                residual is not None,
            )
        elif hidden_states.shape[-1] > 8192:
            if residual is not None:
                hidden_states += residual

            # compute the statistics in float32
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(
                variance + self.variance_epsilon
            )

            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states
        elif SYSTEM == "cuda":
            # faster post attention rms norm
            shape = hidden_states.shape
            # inputs with more than 2 dims are flattened to 2-D for the kernel
            # call and restored afterwards
            unwrap = len(shape) > 2
            if unwrap:
                hidden_states = hidden_states.reshape(-1, shape[-1])

            # only the first returned value (the normalised tensor) is used
            normed_hidden_states, *_ = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )

            if unwrap:
                normed_hidden_states = normed_hidden_states.view(*shape)

            return normed_hidden_states
        elif SYSTEM == "rocm":
            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
            if residual is not None:
                hidden_states += residual

            shape = hidden_states.shape
            unwrap = len(shape) > 2
            if unwrap:
                hidden_states = hidden_states.reshape(-1, shape[-1])

            # the kernel writes its result into the preallocated `out`
            out = torch.empty_like(hidden_states)
            ops.rms_norm(
                out,
                hidden_states,
                self.weight.data,
                self.variance_epsilon,
            )

            if unwrap:
                out = out.view(*shape)

            return out
        else:
            raise ValueError(
                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )
|
|
|
|
|
|
# this was adapted from LlamaMLP
|
|
class IdeficsMLP(nn.Module):
    """Gated MLP: `down_proj(act_fn(gate) * up)`.

    `gate_proj` and `up_proj` are loaded as one fused column-parallel
    projection whose output holds the gate half followed by the up half.
    """

    def __init__(
        self,
        config,
        prefix,
        weights,
    ):
        super().__init__()
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        """Apply the gated MLP over the last dimension of `hidden_states`."""
        gate_up_states = self.gate_up_proj(hidden_states)
        shape = gate_up_states.shape
        # split the fused output into its (gate, up) halves on a new axis
        gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)
        # Index the new axis from the end so inputs with any number of leading
        # dimensions work, not only (batch, seq, hidden).
        gate_states = gate_up_states[..., 0, :]
        up_states = gate_up_states[..., 1, :]
        return self.down_proj(self.act_fn(gate_states) * up_states)
|
|
|
|
|
|
# this was adapted from LlamaAttention
|
|
class IdeficsAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Used both for self-attention (fused q/k/v projection plus rotary position
    embeddings) and for cross-attention (separate q/k/v projections, with the
    keys and values computed from `key_value_states`). The heads are divided
    across the shards of `weights.process_group`.
    """

    def __init__(
        self,
        config,
        prefix,
        weights,
        qk_layer_norms: bool = False,
        is_cross_attention: bool = False,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.dropout = config.dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.is_cross_attention = is_cross_attention

        # if not hasattr(nn.functional, "scaled_dot_product_attention"):
        #     raise ValueError("this model requires pytorch 2.0 or higher")

        num_shards = weights.process_group.size()
        if self.num_heads % num_shards != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        # from here on `num_heads` is the number of heads held by this shard
        self.num_heads //= num_shards

        if self.is_cross_attention:
            # kv_input_dim = (
            #     self.hidden_size if not hasattr(config.vision_config, "embed_dim") else config.vision_config.embed_dim
            # )
            self.q_proj = TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
            )
            self.k_proj = TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
            )
            self.v_proj = TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
            )
        else:
            self.qkv = TensorParallelColumnLinear.load_multi(
                config,
                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
                dim=0,
                weights=weights,
                bias=False,
            )
        self.o_proj = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
        )
        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config, dim=self.head_dim, base=10000.0, device=weights.device
        )
        self.qk_layer_norms = qk_layer_norms
        if self.qk_layer_norms:
            self.q_layer_norm = IdeficsRMSNorm(
                prefix=f"{prefix}.q_layer_norm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
            # The key norm must load its own weights; it previously reused the
            # `q_layer_norm` prefix (copy-paste slip).
            self.k_layer_norm = IdeficsRMSNorm(
                prefix=f"{prefix}.k_layer_norm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
        # set once the "attention weights unavailable" warning has been logged
        self._warned_output_attentions = False

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape to `(bsz, num_heads, seq_len, head_dim)`, contiguous."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention over `hidden_states`.

        Args:
            hidden_states: tensor of shape `(bsz, q_len, hidden)`.
            key_value_states: source of keys/values for cross-attention, of
                shape `(bsz, kv_len, hidden)`; providing it selects the
                cross-attention path.
            attention_mask: optional additive mask of shape
                `(bsz, 1, q_len, kv_seq_len)`.
            position_ids: positions for the rotary embedding (self-attention
                path only).
            past_key_value: cached `(key, value)` tensors, concatenated before
                the new ones along the sequence dimension.
            output_attentions: attention weights are never returned; asking
                for them only logs a warning.
            use_cache: when True the updated `(key, value)` pair is returned.

        Returns:
            `(attn_output, None, past_key_value)`; the last item is None
            unless `use_cache` is True.

        Raises:
            ValueError: on an unexpected mask or attention output shape.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = self.is_cross_attention or key_value_states is not None

        bsz, q_len, _ = hidden_states.size()
        head_shape = (self.num_heads, self.head_dim)

        if is_cross_attention:
            query_states = (
                self.q_proj(hidden_states).view(bsz, q_len, *head_shape).transpose(1, 2)
            )
            # Note that, in this case, `kv_len` == `kv_seq_len`
            _, kv_len, _ = key_value_states.size()
            key_states = (
                self.k_proj(key_value_states)
                .view(bsz, kv_len, *head_shape)
                .transpose(1, 2)
            )
            value_states = (
                self.v_proj(key_value_states)
                .view(bsz, kv_len, *head_shape)
                .transpose(1, 2)
            )
        else:
            qkv = self.qkv(hidden_states)
            query_states, key_states, value_states = qkv.split(
                self.num_heads * self.head_dim, dim=2
            )

            # kept as (bsz, q_len, heads, head_dim) until the rotary step is done
            query_states = query_states.view(bsz, q_len, *head_shape)
            key_states = key_states.view(bsz, q_len, *head_shape)
            value_states = value_states.view(bsz, q_len, *head_shape)

            kv_seq_len = q_len
            if past_key_value is not None:
                kv_seq_len += past_key_value[0].shape[-2]
            max_s = max(kv_seq_len, q_len)
            cos, sin = self.rotary_emb.get_cos_sin(
                position_ids.view(-1), max_s, hidden_states.dtype
            )

            query_shape = query_states.shape
            key_shape = key_states.shape
            # the return value is unused: the rotation is expected to be applied
            # through the flattened (tokens, heads, head_dim) arguments
            self.rotary_emb(
                query_states.view(-1, *query_shape[2:]),
                key_states.reshape(-1, *key_shape[2:]),
                cos,
                sin,
            )

            query_states = query_states.view(query_shape)
            key_states = key_states.view(key_shape)

            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)

        # [bsz, nh, t, hd]
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        if self.qk_layer_norms:
            query_states = self.q_layer_norm(query_states)
            key_states = self.k_layer_norm(key_states)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        attn_output = nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            # the functional API applies dropout whenever dropout_p > 0, so it
            # has to be switched off explicitly outside of training
            dropout_p=self.dropout if self.training else 0.0,
        )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        attn_weights = None
        if output_attentions and not getattr(self, "_warned_output_attentions", False):
            # The loguru logger has no `warning_once`; emulate it with a
            # per-instance flag so the message is logged a single time.
            self._warned_output_attentions = True
            logger.warning(
                "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
            )

        return attn_output, attn_weights, past_key_value
|
|
|
|
|
|
# this was adapted from LlamaDecoderLayer
|
|
class IdeficsDecoderLayer(nn.Module):
    """Decoder block: pre-norm self-attention and pre-norm MLP, each wrapped
    in a residual connection. Weights are loaded from `model.layers.{layer_id}`.
    """

    def __init__(self, layer_id: int, config: IdeficsConfig, weights):
        super().__init__()
        self.process_group = weights.process_group
        self.hidden_size = config.hidden_size
        prefix = f"model.layers.{layer_id}"
        self.self_attn = IdeficsAttention(
            config=config,
            prefix=f"{prefix}.self_attn",
            weights=weights,
            qk_layer_norms=False,
            is_cross_attention=False,
        )
        self.mlp = IdeficsMLP(
            config=config,
            prefix=f"{prefix}.mlp",
            weights=weights,
        )
        self.input_layernorm = IdeficsRMSNorm(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = IdeficsRMSNorm(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        # stored for reference only: the dropout calls in forward() are disabled
        self.dropout = config.dropout

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to the self-attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states

        Returns:
            Tuple starting with the output hidden states, followed by the
            attention weights if `output_attentions` is set, followed by the
            present key/value pair if `use_cache` is set.
        """
        # Self Attention
        attn_residual = hidden_states
        normed_states = self.input_layernorm(hidden_states)
        attn_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=normed_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = attn_residual + attn_states

        # Fully Connected
        mlp_residual = hidden_states
        mlp_states = self.mlp(self.post_attention_layernorm(hidden_states))
        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = mlp_residual + mlp_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
|
|
|
|
|
|
class IdeficsGatedCrossAttentionLayer(nn.Module):
    """Cross-attention block with tanh-gated residual branches.

    Both the cross-attention output and the MLP output are scaled by
    `tanh(alpha)` of a loaded parameter before being added to the residual.
    Weights are loaded from `model.gated_cross_attn_layers.{layer_id}`.
    """

    def __init__(self, layer_id, config: IdeficsConfig, weights):
        super().__init__()
        self.process_group = weights.process_group
        self.hidden_size = config.hidden_size
        prefix = f"model.gated_cross_attn_layers.{layer_id}"
        self.cross_attn = IdeficsAttention(
            config=config,
            prefix=f"{prefix}.cross_attn",
            weights=weights,
            qk_layer_norms=True,
            is_cross_attention=True,
        )
        self.mlp = IdeficsMLP(
            config=config,
            prefix=f"{prefix}.mlp",
            weights=weights,
        )
        self.input_layernorm = IdeficsRMSNorm(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = IdeficsRMSNorm(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        # Dropout probability, named like in IdeficsDecoderLayer. It is unused
        # while the dropout calls in forward() stay disabled.
        self.dropout = config.dropout
        # NOTE: despite its name this attribute holds the dropout probability,
        # not the config object; kept only for backward compatibility.
        self.config = config.dropout

        self.act_cross_attn = nn.Tanh()
        self.act_dense = nn.Tanh()

        # Both gates are read straight from the checkpoint; no extra presence
        # check is needed after these assignments.
        self.alpha_cross_attn = nn.Parameter(
            weights.get_tensor(f"{prefix}.alpha_cross_attn")
        )
        self.alpha_dense = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_dense"))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_hidden_states: Optional[torch.Tensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        no_images: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                Accepted but not used by this layer.
            image_hidden_states (`torch.Tensor`): required; passed to the cross-attention as `key_value_states`.
            image_attention_mask (`torch.Tensor`, *optional*): passed to the cross-attention as its attention mask.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`). It is not forwarded to the cross-attention call.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            no_images (`bool`, *optional*, defaults to `False`): If `True` the vision part is ignored

        Raises:
            ValueError: if `image_hidden_states` is None.
            NotImplementedError: if `past_key_value` is provided.
        """
        if image_hidden_states is None:
            raise ValueError(
                "`image_hidden_states` is required for Idefics cross attention module which are visual features to be"
                " conditioned on."
            )

        if past_key_value is not None:
            raise NotImplementedError(
                "Past key value states are not implemented for Idefics cross attention module."
            )

        # Cross Attention
        attn_residual = hidden_states
        normed_states = self.input_layernorm(hidden_states)
        attn_states, self_attn_weights, present_key_value = self.cross_attn(
            hidden_states=normed_states,
            key_value_states=image_hidden_states,
            attention_mask=image_attention_mask,
            output_attentions=output_attentions,
        )
        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
        # when there are no images the model is used in pure language mode
        gate = 0 if no_images else 1
        hidden_states = (
            attn_residual + gate * self.act_cross_attn(self.alpha_cross_attn) * attn_states
        )

        # Fully Connected
        mlp_residual = hidden_states
        mlp_states = self.mlp(self.post_attention_layernorm(hidden_states))
        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
        hidden_states = mlp_residual + self.act_dense(self.alpha_dense) * mlp_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
|
|
|
|
|
|
# Shared introductory docstring text for the model classes.
LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`IdeficsConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
|
|
|
# @add_start_docstrings(
|
|
# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
|
# LLAMA_START_DOCSTRING,
|
|
# )
|
|
class IdeficsPreTrainedModel(PreTrainedModel):
    """Common base class of the Idefics models defined in this module.

    The only active behaviour is binding `IdeficsConfig` as the configuration
    class. The weight-initialisation and gradient-checkpointing hooks inherited
    from the upstream implementation are kept below as comments only.
    """

    # Configuration class used by `PreTrainedModel` for this model family.
    config_class = IdeficsConfig
    # base_model_prefix = "model"
    # supports_gradient_checkpointing = True
    # _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]

    # def _init_weights(self, module):
    #     # important: this ported version of Idefics isn't meant for training from scratch - only
    #     # inference and fine-tuning - so the proper init weights code has been removed - the m4 code
    #     # base should be used for training from scratch and it contains the correct code.
    #     std = self.config.initializer_range
    #     if isinstance(module, nn.Linear):
    #         module.weight.data.normal_(mean=0.0, std=std)
    #         if module.bias is not None:
    #             module.bias.data.zero_()
    #     elif isinstance(module, nn.Embedding):
    #         module.weight.data.normal_(mean=0.0, std=std)
    #         if module.padding_idx is not None:
    #             module.weight.data[module.padding_idx].zero_()

    # def _set_gradient_checkpointing(self, module, value=False):
    #     if isinstance(module, IdeficsModel):
    #         module.gradient_checkpointing = value
|
|
|
|
|
|
# LLAMA_INPUTS_DOCSTRING = r"""
|
|
# Args:
|
|
# input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
# Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
|
# it.
|
|
|
|
# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
# [`PreTrainedTokenizer.__call__`] for details.
|
|
|
|
# [What are input IDs?](../glossary#input-ids)
|
|
# attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
# Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
|
|
# - 1 for tokens that are **not masked**,
|
|
# - 0 for tokens that are **masked**.
|
|
|
|
# [What are attention masks?](../glossary#attention-mask)
|
|
|
|
# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
# [`PreTrainedTokenizer.__call__`] for details.
|
|
|
|
# If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
|
# `past_key_values`).
|
|
|
|
# If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
|
# and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
|
# information on the default strategy.
|
|
|
|
# - 1 indicates the head is **not masked**,
|
|
# - 0 indicates the head is **masked**.
|
|
# position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
# Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
|
# config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
|
# past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
|
# Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
|
# `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
|
# `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
|
|
|
# Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
|
# blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
|
|
|
# If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
|
# don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
|
# `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
# inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
# Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
|
# is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
|
# model's internal embedding lookup matrix.
|
|
# use_cache (`bool`, *optional*):
|
|
# If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
|
# `past_key_values`).
|
|
# output_attentions (`bool`, *optional*):
|
|
# Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
|
# tensors for more detail.
|
|
# output_hidden_states (`bool`, *optional*):
|
|
# Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
|
# more detail.
|
|
# return_dict (`bool`, *optional*):
|
|
# Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
# """
|
|
|
|
|
|
# @add_start_docstrings(
|
|
# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
|
# LLAMA_START_DOCSTRING,
|
|
# )
|
|
class IdeficsModel(IdeficsPreTrainedModel):
    """Idefics backbone: text decoder layers interleaved with gated cross-attention.

    The model owns a token embedding, a vision encoder, an optional perceiver
    resampler, `config.num_hidden_layers` decoder layers and
    `config.num_hidden_layers // config.cross_layer_interval` gated
    cross-attention layers. A cross-attention layer runs in front of every
    decoder layer whose index is a multiple of `cross_layer_interval`.

    NOTE: the training-only helpers of the upstream implementation (parameter
    freezing, gradient checkpointing, `post_init`, input-embedding accessors)
    were commented out in this port and are not reproduced here.

    Args:
        config: IdeficsConfig
        weights: weight provider forwarded unchanged to every sub-module.
    """

    def __init__(self, config: IdeficsConfig, weights):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = IdeficsDecoupledPartialTPEmbedding(
            config=config,
            weights=weights,
        )

        self.image_size = config.vision_config.image_size
        self.vision_config = config.vision_config
        self.vision_model = IdeficsVisionTransformer(
            prefix="model.vision_model",
            config=config.vision_config,
            weights=weights,
        )

        # Perceiver Resampler (only instantiated when enabled in the config).
        if config.use_resampler:
            perceiver_config = config.perceiver_config
            self.perceiver_resampler = IdeficsPerceiverResampler(
                prefix="model.perceiver_resampler",
                config=config,
                embed_dim=config.vision_config.embed_dim,
                depth=perceiver_config.resampler_depth,
                n_heads=perceiver_config.resampler_n_heads,
                head_dim=perceiver_config.resampler_head_dim,
                n_latents=perceiver_config.resampler_n_latents,
                weights=weights,
            )

        self.layers = nn.ModuleList(
            [
                IdeficsDecoderLayer(layer_id, config, weights)
                for layer_id in range(config.num_hidden_layers)
            ]
        )

        self.cross_layer_interval = config.cross_layer_interval
        num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
        self.gated_cross_attn_layers = nn.ModuleList(
            [
                IdeficsGatedCrossAttentionLayer(layer_id, config, weights)
                for layer_id in range(num_cross_layers)
            ]
        )

        self.norm = IdeficsRMSNorm(
            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
        )

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        """Combine the causal mask with the expanded padding mask.

        A causal mask is only built when more than one target position is
        processed (`input_shape[-1] > 1`). When `attention_mask` is given it is
        expanded with `_expand_mask` and added to the causal mask (or returned
        alone if there is no causal mask). Returns `None` when neither exists.
        """
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def _run_layer(
        self,
        layer_idx,
        decoder_layer,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        image_hidden_states,
        image_attention_mask,
        output_attentions,
        use_cache,
        no_images,
    ):
        """Run one decoder layer, preceded by its gated cross-attention layer if any.

        Returns the raw output tuple of `decoder_layer`; the cross-attention
        outputs other than the hidden states are discarded.
        """
        # TODO(ls): Add cross attention values to respective lists
        if layer_idx % self.cross_layer_interval == 0:
            xblock = self.gated_cross_attn_layers[
                layer_idx // self.cross_layer_interval
            ]
            outputs = xblock(
                hidden_states,
                attention_mask=attention_mask,
                image_hidden_states=image_hidden_states,
                image_attention_mask=image_attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
                past_key_value=None,  # not implemented
                no_images=no_images,
            )
            hidden_states = outputs[0]

        return decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        image_embeddings: Optional[torch.FloatTensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastImage]:
        """Run the decoder stack over text tokens conditioned on image features.

        Args:
            input_ids: token ids, unpacked as `(batch_size, seq_length)`.
                Mutually exclusive with `inputs_embeds`.
            attention_mask: padding mask for the text tokens. When omitted an
                all-ones mask of shape `(batch_size, seq_length + past_length)`
                is used.
            position_ids: optional explicit positions. Derived from
                `attention_mask` (or from the cache length) when omitted.
            past_key_values: per-layer cache. The cached length is read from
                `past_key_values[0][0].shape[2]`.
            inputs_embeds: pre-computed embeddings, unpacked as
                `(batch_size, seq_length, hidden)`.
            pixel_values: raw images, unpacked as `(batch_size, num_images, ...)`.
                Encoded with the vision model when `image_hidden_states` is not
                supplied.
            image_hidden_states: already computed image features; when given,
                the vision model and resampler are skipped.
            image_embeddings: vision-encoder outputs, unpacked as
                `(batch_size, num_images, image_seq_len, hidden)`. Mutually
                exclusive with `pixel_values`.
            image_attention_mask: required mask whose last dimension indexes
                images; it is repeated `image_seq_len` times per image.
            use_cache, output_attentions, output_hidden_states, return_dict:
                fall back to the corresponding config values when `None`.

        Returns:
            A `BaseModelOutputWithPastImage`, or, when `return_dict` is false,
            a tuple of the non-`None` values among
            `(hidden_states, cache, all_hidden_states, all_self_attns)`.

        Raises:
            ValueError: on inconsistent text inputs, on missing or conflicting
                image inputs, or when `image_attention_mask` is missing.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        # The device is resolved only after validation so that passing neither
        # input raises the intended ValueError instead of an AttributeError.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
            device = input_ids.device
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
            device = inputs_embeds.device
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
        elif position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # The mask is dereferenced unconditionally below; fail early with a
        # clear message rather than with an AttributeError on None.
        if image_attention_mask is None:
            raise ValueError("image_attention_mask has to be not-None.")

        no_images = False

        if image_hidden_states is None:
            if pixel_values is None and image_embeddings is None:
                raise ValueError(
                    "Either pixel_values and image_embeddings have to be not-None."
                )
            elif pixel_values is not None and image_embeddings is not None:
                raise ValueError(
                    "You cannot specify both pixel_values and image_embeddings at the same time"
                )
            elif pixel_values is not None:
                # All-zero pixel values mean "no image": cross-attention is gated off.
                no_images = len(torch.nonzero(pixel_values)) == 0
                pixel_values = pixel_values.to(
                    dtype=self.dtype, device=device
                )  # fp16 compatibility
                batch_size, num_images = pixel_values.shape[:2]
                pixel_values = pixel_values.contiguous().view(
                    batch_size * num_images, *pixel_values.shape[2:]
                )

                # Get sequence from the vision encoder
                image_hidden_states = self.vision_model(
                    pixel_values=pixel_values
                ).last_hidden_state
            else:
                (
                    batch_size,
                    num_images,
                    image_seq_len,
                    image_hidden_size,
                ) = image_embeddings.size()
                # Use the resolved device: `input_ids` may be None when the
                # caller supplies `inputs_embeds`.
                image_hidden_states = image_embeddings.to(
                    dtype=self.dtype, device=device
                )
                image_hidden_states = image_hidden_states.view(
                    batch_size * num_images, image_seq_len, image_hidden_size
                )

            if self.config.use_resampler:
                image_hidden_states = self.perceiver_resampler(image_hidden_states)
            image_seq_len = image_hidden_states.size(1)
            image_hidden_size = image_hidden_states.size(2)
            image_hidden_states = image_hidden_states.view(
                batch_size, num_images * image_seq_len, image_hidden_size
            )
        else:
            no_images = False
            if pixel_values is not None:
                num_images = pixel_values.shape[1]
            else:
                # Cached image features without the original pixels: the mask's
                # last dimension is the number of images (it is expanded per
                # image right below).
                num_images = image_attention_mask.size(-1)
            image_seq_len = image_hidden_states.shape[1] // num_images

        # Make image_attention_mask compatible with hidden states:
        # one mask entry per image -> one entry per image token.
        text_seq_len = image_attention_mask.size(1)
        image_attention_mask = image_attention_mask.unsqueeze(-1)
        image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
        image_attention_mask = image_attention_mask.view(
            batch_size, text_seq_len, num_images * image_seq_len
        )
        image_attention_mask = self.invert_attention_mask(image_attention_mask)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            layer_outputs = self._run_layer(
                idx,
                decoder_layer,
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                image_hidden_states=image_hidden_states,
                image_attention_mask=image_attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
                no_images=no_images,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry follows the attention weights when those are returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPastImage(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            image_hidden_states=image_hidden_states,
        )
|
|
|
|
|
|
class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
    """Idefics backbone (`IdeficsModel`) with a language-modeling head on top."""

    def __init__(
        self,
        config,
        weights,
    ):
        super().__init__(config)
        self.model = IdeficsModel(
            config=config,
            weights=weights,
        )

        self.lm_head = IdeficsDecoupledTensorParallelLinear(
            config=config,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_embeddings: Optional[torch.FloatTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPastImage]:
        r"""
        Run the backbone and project its final hidden states to logits.

        All arguments except `labels` and `return_dict` are forwarded unchanged
        to `IdeficsModel.forward`.

        Args:
            labels (`torch.LongTensor`, *optional*):
                Accepted for signature compatibility only. No loss is computed
                here; the returned `loss` is always `None`.
            return_dict (`bool`, *optional*):
                Accepted for signature compatibility only. The backbone is
                always called with `return_dict=True` because the result is
                read by attribute below, and this method always returns the
                same structure.

        Returns:
            A 2-tuple `(CausalLMOutputWithPastImage, speculative_logits)`, where
            both `logits` and `speculative_logits` come from `self.lm_head`.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        # `return_dict=True` is forced: with a plain tuple the attribute reads
        # in the return statement would raise AttributeError.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_embeddings=image_embeddings,
            image_hidden_states=image_hidden_states,
            image_attention_mask=image_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = outputs[0]
        logits, speculative_logits = self.lm_head(hidden_states)

        loss = None

        return (
            CausalLMOutputWithPastImage(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                image_hidden_states=outputs.image_hidden_states,
            ),
            speculative_logits,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        """Build generation inputs via the module-level helper and drop unsupported keys."""
        inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
        # `forward` has no `token_type_ids` parameter, so it must not be passed on.
        unwanted_kwargs = ["token_type_ids"]
        for kwarg in unwanted_kwargs:
            inputs.pop(kwarg, None)
        return inputs

    @staticmethod
    def _expand_inputs_for_generation(
        *args,
        **model_kwargs,
    ):
        """Delegate to the module-level `expand_inputs_for_generation`."""
        return expand_inputs_for_generation(*args, **model_kwargs)

    @staticmethod
    def _update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=False
    ):
        """Delegate to the module-level `update_model_kwargs_for_generation`."""
        return update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Re-index every cached tensor along dim 0 with `beam_idx`.

        Returns a tuple (one entry per layer) of tuples of reordered tensors.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past
        )
|