# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Idefics model. """
from typing import List , Optional , Tuple , Union
import torch
import torch . nn . functional as F
import torch . utils . checkpoint
from torch import nn
from torch . nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers . activations import ACT2FN
2023-09-27 10:22:09 +00:00
from transformers . modeling_outputs import (
BaseModelOutputWithPast ,
CausalLMOutputWithPast ,
dataclass ,
)
2023-08-17 12:38:49 +00:00
from transformers . modeling_utils import PretrainedConfig
from transformers . utils import (
add_start_docstrings ,
add_start_docstrings_to_model_forward ,
logging ,
replace_return_docstrings ,
)
from text_generation_server . models . custom_modeling . idefics_config import IdeficsConfig
2023-09-27 10:22:09 +00:00
from text_generation_server . models . custom_modeling . idefics_vision import (
IdeficsVisionTransformer ,
)
from text_generation_server . models . custom_modeling . idefics_perceiver import (
IdeficsPerceiverResampler ,
)
from text_generation_server.utils.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    SpeculativeHead,
    PositionRotaryEmbedding,
    FastLinear,
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

if IS_CUDA_SYSTEM:
    import dropout_layer_norm
elif IS_ROCM_SYSTEM:
    from vllm import layernorm_ops


@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
    image_hidden_states: Optional[torch.FloatTensor] = None


@dataclass
class CausalLMOutputWithPastImage(CausalLMOutputWithPast):
    image_hidden_states: Optional[torch.FloatTensor] = None


logger = logging.get_logger(__name__)

# _CONFIG_FOR_DOC = "IdeficsConfig"

# IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST = [
#     "HuggingFaceM4/idefics-9b",
#     "HuggingFaceM4/idefics-80b",
#     # See all Idefics models at https://huggingface.co/models?filter=idefics
# ]


def expand_inputs_for_generation(
    input_ids,
    expand_size=1,
    is_encoder_decoder=False,
    attention_mask=None,
    encoder_outputs=None,
    **model_kwargs,
):
    expanded_return_idx = (
        torch.arange(input_ids.shape[0])
        .view(-1, 1)
        .repeat(1, expand_size)
        .view(-1)
        .to(input_ids.device)
    )
    input_ids = input_ids.index_select(0, expanded_return_idx)

    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = token_type_ids.index_select(
            0, expanded_return_idx
        )

    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.index_select(
            0, expanded_return_idx
        )
        model_kwargs["image_attention_mask"] = model_kwargs[
            "image_attention_mask"
        ].index_select(0, expanded_return_idx)
        model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(
            0, expanded_return_idx
        )

    if is_encoder_decoder:
        if encoder_outputs is None:
            raise ValueError(
                "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
            )
        encoder_outputs["last_hidden_state"] = (
            encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
            )
        )
        model_kwargs["encoder_outputs"] = encoder_outputs
    return input_ids, model_kwargs
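
# Illustrative usage sketch (editorial addition; the tensors below are hypothetical and
# not part of this module): with `input_ids` of shape [2, seq_len] and `expand_size=3`,
# every batch row is repeated three times in order, yielding shape [6, seq_len], and the
# attention_mask / image_attention_mask / pixel_values entries of `model_kwargs` are
# expanded with the same row ordering so they stay aligned with `input_ids`.
#
#     input_ids = torch.tensor([[1, 2], [3, 4]])
#     input_ids, kwargs = expand_inputs_for_generation(
#         input_ids,
#         expand_size=3,
#         attention_mask=torch.ones(2, 2),
#         image_attention_mask=torch.ones(2, 2),
#         pixel_values=torch.zeros(2, 1, 3, 8, 8),
#     )
#     # input_ids is now [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]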


def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
    # must have this key set to at least None
    model_kwargs["past_key_values"] = model_kwargs.get("past_key_values", None)

    # update past
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs.past_key_values
    elif "mems" in outputs:
        model_kwargs["past"] = outputs.mems
    elif "past_buckets_states" in outputs:
        model_kwargs["past"] = outputs.past_buckets_states
    else:
        model_kwargs["past"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat(
            [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
        )

    # update attention masks
    if not is_encoder_decoder:
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=-1,
            )
        if "image_attention_mask" in model_kwargs:
            image_attention_mask = model_kwargs["image_attention_mask"]
            last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
            model_kwargs["image_attention_mask"] = last_mask

    return model_kwargs
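
# Editorial note (not from the original file): after each decoding step this helper keeps
# `model_kwargs` consistent with a sequence that just grew by one token -- the attention
# mask gains a column of ones, token_type_ids repeats its last value, and the image
# attention mask is reduced to its last row, since only the newest text position attends
# to the images on the next step.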


def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
    token_type_ids = kwargs.get("token_type_ids", None)
    # only last token for inputs_ids if past is defined in kwargs
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past:
            position_ids = position_ids[:, -1].unsqueeze(-1)

    pixel_values = kwargs.get("pixel_values", None)
    image_attention_mask = kwargs.get("image_attention_mask", None)
    # if pixel_values is None or image_attention_mask is None:
    #     raise ValueError("pixel values and image attention mask cannot be None")

    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "pixel_values": pixel_values,
        "image_attention_mask": image_attention_mask,
    }
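
# Illustrative sketch (hypothetical tensors, added for clarity): when `past` is set, only
# the newest token is fed to the model and `position_ids` are rebuilt from the attention
# mask, e.g. attention_mask = [[1, 1, 1, 0]] gives cumsum(-1) - 1 = [[0, 1, 2, 2]],
# padded positions are clamped to 1 giving [[0, 1, 2, 1]], and only the last column is
# kept for the incremental step.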


def freeze_model(model, module_exceptions=[]):
    mapping = {
        "LayerNorm": nn.LayerNorm,
        "Linear": nn.Linear,
        "Embedding": nn.Embedding,
    }
    module_exceptions_mapped = [mapping[m] for m in module_exceptions]
    for module in model.modules():
        if module_exceptions and any(
            [isinstance(module, t) for t in module_exceptions_mapped]
        ):
            module.requires_grad_(
                True
            )  # Explicitly setting it to true to avoid any mistakes
        else:
            module.requires_grad_(False)
    return model
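
# Usage sketch (illustrative; this helper is not called elsewhere in this file): freeze
# everything in a submodule except its embeddings, e.g.
#
#     freeze_model(some_module, module_exceptions=["Embedding"])
#
# Exception names must be keys of `mapping` above ("LayerNorm", "Linear", "Embedding").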


class IdeficsDecoupledPartialTPEmbedding(nn.Module):
    def __init__(
        self,
        config,
        weights,
    ):
        super().__init__()
        self.num_embeddings = config.vocab_size
        self.weight = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.additional_weight = nn.Parameter(
            weights.get_tensor(f"model.embed_tokens.additional_embedding.weight")
        )

    def forward(self, input_ids):
        # Clone so that we don't modify the original input_ids later on
        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        additional_embeddings = torch.nn.functional.embedding(
            input_ids_additional_vocab - self.num_embeddings, self.additional_weight
        )

        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
        input_ids[additional_vocab_indices] = 0
        full_vector = self.weight(input_ids)

        # overwrite the records with high indices
        full_vector[additional_vocab_indices] = additional_embeddings

        return full_vector
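
# Sketch of the lookup split (editorial comment; the token ids are hypothetical): with
# `config.vocab_size = 32000`, a token id 32003 is routed to row 3 of the additional
# embedding table, while ids below 32000 go through the tensor-parallel embedding; the
# two results are then merged back into a single [batch, seq, hidden] tensor.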


class IdeficsDecoupledTensorParallelLinear(nn.Module):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0,
    then it will create `out_additional_features * in_features` additional parameters that are always trained. If
    `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
    """

    def __init__(
        self,
        config,
        weights,
    ) -> None:
        super().__init__()
        self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
        self.additional_fc = FastLinear.load(
            config=config,
            prefix="lm_head.additional_fc",
            weights=weights,
            bias=False,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output, speculative_logits = self.fc(input)
        additional_features = self.additional_fc(input)
        output = torch.cat((output, additional_features), -1)
        return output, speculative_logits

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
            self.in_features,
            self.out_features,
            self.out_additional_features,
            self.bias is not None,
            self.partially_freeze,
        )
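
# Forward-pass sketch (editorial; shapes are illustrative): for hidden states of shape
# [batch, seq, hidden], `self.fc` (the speculative lm_head) returns logits over the base
# vocabulary plus optional speculative logits, `self.additional_fc` produces logits for
# the extra tokens, and the two are concatenated along the last dimension, mirroring the
# split embedding in IdeficsDecoupledPartialTPEmbedding above.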


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat(
            [
                torch.zeros(
                    tgt_len, past_key_values_length, dtype=dtype, device=device
                ),
                mask,
            ],
            dim=-1,
        )
    return mask[None, None, :, :].expand(
        bsz, 1, tgt_len, tgt_len + past_key_values_length
    )
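
# Worked example (editorial, for clarity): with tgt_len=3, past_key_values_length=1 and
# m = torch.finfo(dtype).min, the returned mask has shape [bsz, 1, 3, 4] and each batch
# entry looks like
#     [0, 0, m, m]
#     [0, 0, 0, m]
#     [0, 0, 0, 0]
# i.e. query i may attend to the cached position and to keys j <= i, never to the future.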


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(
        inverted_mask.to(torch.bool), torch.finfo(dtype).min
    )
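
# Worked example (editorial, for clarity): a padding mask [[1, 1, 0]] of shape [1, 3]
# becomes a [1, 1, tgt_len, 3] additive mask holding 0.0 at attended positions and
# torch.finfo(dtype).min at padded positions, so it can simply be added to the causal
# mask produced by `_make_causal_mask` above.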


class IdeficsRMSNorm(nn.Module):
    def __init__(self, prefix, weights, eps=1e-6):
        """
        IdeficsRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()

        weight = weights.get_tensor(f"{prefix}.weight")
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states, residual=None):
        if hidden_states.shape[-1] > 8192:
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(
                variance + self.variance_epsilon
            )

            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states
        elif IS_CUDA_SYSTEM:
            # faster post attention rms norm
            unwrap = False
            if len(hidden_states.shape) > 2:
                unwrap = True
                shape = hidden_states.shape
                hidden_states = hidden_states.reshape(-1, shape[-1])

            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            if res is None:
                res = hidden_states

            if unwrap:
                normed_hidden_states = normed_hidden_states.view(*shape)

            return normed_hidden_states
        elif IS_ROCM_SYSTEM:
            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            unwrap = False
            if len(hidden_states.shape) > 2:
                unwrap = True
                shape = hidden_states.shape
                hidden_states = hidden_states.reshape(-1, shape[-1])

            out = torch.empty_like(hidden_states)
            layernorm_ops.rms_norm(
                out,
                hidden_states,
                self.weight.data,
                self.variance_epsilon,
            )

            if unwrap:
                out = out.view(*shape)

            return out
        else:
            raise ValueError(
                "Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )
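
# Reference semantics (editorial sketch): all three branches above compute the same
# RMSNorm, y = weight * x / sqrt(mean(x^2, dim=-1) + eps), with an optional residual add
# folded in beforehand; only the first (pure PyTorch) branch spells this out directly,
# the other two dispatch to fused CUDA / ROCm kernels.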


# this was adapted from LlamaMLP
class IdeficsMLP(nn.Module):
    def __init__(
        self,
        config,
        prefix,
        weights,
    ):
        super().__init__()
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        shape = gate_up_states.shape
        gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)
        return self.down_proj(
            self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1]
        )
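
# Editorial note on the fused projection above: `gate_up_proj` stacks gate_proj and
# up_proj along dim 0, so its output of size 2 * intermediate_size is reshaped into
# [..., 2, intermediate_size]; index 0 is the gated branch (passed through the
# activation) and index 1 is the linear branch, giving the usual
# down_proj(act(gate) * up) SwiGLU-style MLP.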


# this was adapted from LlamaAttention
class IdeficsAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config,
        prefix,
        weights,
        qk_layer_norms: bool = False,
        is_cross_attention: bool = False,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.dropout = config.dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.is_cross_attention = is_cross_attention

        # if not hasattr(nn.functional, "scaled_dot_product_attention"):
        #     raise ValueError("this model requires pytorch 2.0 or higher")

        process_group = weights.process_group
        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()})"
            )
        self.num_heads //= weights.process_group.size()

        if self.is_cross_attention:
            # kv_input_dim = (
            #     self.hidden_size if not hasattr(config.vision_config, "embed_dim") else config.vision_config.embed_dim
            # )
            self.q_proj = TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
            )
            self.k_proj = TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
            )
            self.v_proj = TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
            )
        else:
            self.qkv = TensorParallelColumnLinear.load_multi(
                config,
                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
                dim=0,
                weights=weights,
                bias=False,
            )
        self.o_proj = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
        )
        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config, dim=self.head_dim, base=10000.0, device=weights.device
        )
        self.qk_layer_norms = qk_layer_norms
        if self.qk_layer_norms:
            self.q_layer_norm = IdeficsRMSNorm(
                prefix=f"{prefix}.q_layer_norm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
            self.k_layer_norm = IdeficsRMSNorm(
                prefix=f"{prefix}.k_layer_norm",
                weights=weights,
                eps=config.rms_norm_eps,
            )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = self.is_cross_attention or key_value_states is not None

        bsz, q_len, _ = hidden_states.size()

        if is_cross_attention:
            query_states = self.q_proj(hidden_states).view(
                bsz, q_len, self.num_heads, self.head_dim
            )  # .transpose(1, 2)
            query_states = query_states.transpose(1, 2)
            _, kv_len, _ = (
                key_value_states.size()
            )  # Note that, in this case, `kv_len` == `kv_seq_len`
            key_states = (
                self.k_proj(key_value_states)
                .view(bsz, kv_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
            value_states = (
                self.v_proj(key_value_states)
                .view(bsz, kv_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
        else:
            qkv = self.qkv(hidden_states)
            query_states, key_states, value_states = qkv.split(
                self.num_heads * self.head_dim, dim=2
            )

            query_states = query_states.view(
                bsz, q_len, self.num_heads, self.head_dim
            )  # .transpose(1, 2)
            key_states = key_states.view(
                bsz, q_len, self.num_heads, self.head_dim
            )  # .transpose(1, 2)
            value_states = value_states.view(
                bsz, q_len, self.num_heads, self.head_dim
            )  # .transpose(1, 2)
            kv_seq_len = q_len
            if past_key_value is not None:
                kv_seq_len += past_key_value[0].shape[-2]
            max_s = max(kv_seq_len, q_len)
            cos, sin = self.rotary_emb.get_cos_sin(
                position_ids.view(-1), max_s, hidden_states.dtype
            )

            query_shape = query_states.shape
            key_shape = key_states.shape
            self.rotary_emb(
                query_states.view(-1, *query_shape[2:]),
                key_states.reshape(-1, *key_shape[2:]),
                cos,
                sin,
            )

            query_states = query_states.view(query_shape)
            key_states = key_states.view(key_shape)

            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        if self.qk_layer_norms:
            query_states = self.q_layer_norm(query_states)
            key_states = self.k_layer_norm(key_states)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        attn_output = nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.dropout,
        )

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        attn_weights = None
        if output_attentions:
            logger.warning_once(
                "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
            )

        return attn_output, attn_weights, past_key_value
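
# Shape walkthrough for the forward pass above (editorial; dimensions are illustrative):
# hidden_states [bsz, q_len, hidden] is projected to per-head tensors of shape
# [bsz, num_heads, q_len, head_dim] (keys and values come from `key_value_states` in the
# cross-attention case), rotary position embeddings are applied for self-attention only,
# cached key/value states are concatenated along the sequence axis, and torch's
# scaled_dot_product_attention consumes the additive [bsz, 1, q_len, kv_seq_len] mask
# built by `_prepare_decoder_attention_mask` further below.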


# this was adapted from LlamaDecoderLayer
class IdeficsDecoderLayer(nn.Module):
    def __init__(self, layer_id: int, config: IdeficsConfig, weights):
        super().__init__()
        self.process_group = weights.process_group
        self.hidden_size = config.hidden_size
        prefix = f"model.layers.{layer_id}"
        self.self_attn = IdeficsAttention(
            config=config,
            prefix=f"{prefix}.self_attn",
            weights=weights,
            qk_layer_norms=False,
            is_cross_attention=False,
        )
        self.mlp = IdeficsMLP(
            config=config,
            prefix=f"{prefix}.mlp",
            weights=weights,
        )
        self.input_layernorm = IdeficsRMSNorm(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = IdeficsRMSNorm(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.dropout = config.dropout

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
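
# Editorial summary of the layer above: a standard pre-norm residual block,
#     h = x + SelfAttn(RMSNorm(x));  out = h + MLP(RMSNorm(h))
# with dropout left commented out because the module is only used for inference here.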


class IdeficsGatedCrossAttentionLayer(nn.Module):
    def __init__(self, layer_id, config: IdeficsConfig, weights):
        super().__init__()
        self.process_group = weights.process_group
        self.hidden_size = config.hidden_size
        prefix = f"model.gated_cross_attn_layers.{layer_id}"
        self.cross_attn = IdeficsAttention(
            config=config,
            prefix=f"{prefix}.cross_attn",
            weights=weights,
            qk_layer_norms=True,
            is_cross_attention=True,
        )
        self.mlp = IdeficsMLP(
            config=config,
            prefix=f"{prefix}.mlp",
            weights=weights,
        )
        self.input_layernorm = IdeficsRMSNorm(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = IdeficsRMSNorm(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.config = config.dropout

        self.act_cross_attn = nn.Tanh()
        self.act_dense = nn.Tanh()

        self.alpha_cross_attn = nn.Parameter(
            weights.get_tensor(f"{prefix}.alpha_cross_attn")
        )
        self.alpha_dense = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_dense"))

        if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")):
            raise ValueError("Alpha parameters not initialized correctly!")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_hidden_states: Optional[torch.Tensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        no_images: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            no_images (`bool`, *optional*, defaults to `False`): If `True` the vision part is ignored
        """
        if image_hidden_states is None:
            raise ValueError(
                "`image_hidden_states` is required for Idefics cross attention module which are visual features to be"
                " conditioned on."
            )

        if past_key_value is not None:
            raise NotImplementedError(
                "Past key value states are not implemented for Idefics cross attention module."
            )

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Cross Attention (text queries attend to the image hidden states)
        hidden_states, self_attn_weights, present_key_value = self.cross_attn(
            hidden_states=hidden_states,
            key_value_states=image_hidden_states,
            attention_mask=image_attention_mask,
            output_attentions=output_attentions,
        )
        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
        # when there are no images the model is used in pure language mode
        gate = 0 if no_images else 1
        hidden_states = (
            residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states
        )

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
        hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
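
# Editorial note on the gating above: the cross-attention and MLP branches are scaled by
# tanh(alpha_cross_attn) and tanh(alpha_dense) respectively (Flamingo-style learned
# gates), and `gate` is set to 0 when the batch carries no real images so that the image
# cross-attention contribution is dropped for that batch.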


LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`IdeficsConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


# @add_start_docstrings(
#     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
#     LLAMA_START_DOCSTRING,
# )
class IdeficsPreTrainedModel(PreTrainedModel):
    config_class = IdeficsConfig
    # base_model_prefix = "model"
    # supports_gradient_checkpointing = True
    # _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]

    # def _init_weights(self, module):
    #     # important: this ported version of Idefics isn't meant for training from scratch - only
    #     # inference and fine-tuning - so the proper init weights code has been removed - the m4 code
    #     # base should be used for training from scratch and it contains the correct code.
    #     std = self.config.initializer_range
    #     if isinstance(module, nn.Linear):
    #         module.weight.data.normal_(mean=0.0, std=std)
    #         if module.bias is not None:
    #             module.bias.data.zero_()
    #     elif isinstance(module, nn.Embedding):
    #         module.weight.data.normal_(mean=0.0, std=std)
    #         if module.padding_idx is not None:
    #             module.weight.data[module.padding_idx].zero_()

    # def _set_gradient_checkpointing(self, module, value=False):
    #     if isinstance(module, IdeficsModel):
    #         module.gradient_checkpointing = value
# LLAMA_INPUTS_DOCSTRING = r"""
# Args:
# input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
# Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
# it.
# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
# [`PreTrainedTokenizer.__call__`] for details.
# [What are input IDs?](../glossary#input-ids)
# attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
# Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
# - 1 for tokens that are **not masked**,
# - 0 for tokens that are **masked**.
# [What are attention masks?](../glossary#attention-mask)
# Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
# [`PreTrainedTokenizer.__call__`] for details.
# If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
# `past_key_values`).
# If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
# and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
# information on the default strategy.
# - 1 indicates the head is **not masked**,
# - 0 indicates the head is **masked**.
# position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
# Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
# config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
# past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
# Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
# `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
# `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
# Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
# blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
# If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
# don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
# `decoder_input_ids` of shape `(batch_size, sequence_length)`.
# inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
# Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
# is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
# model's internal embedding lookup matrix.
# use_cache (`bool`, *optional*):
# If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
# `past_key_values`).
# output_attentions (`bool`, *optional*):
# Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
# tensors for more detail.
# output_hidden_states (`bool`, *optional*):
# Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
# more detail.
# return_dict (`bool`, *optional*):
# Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
# """
# @add_start_docstrings(
# "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
# LLAMA_START_DOCSTRING,
# )
class IdeficsModel(IdeficsPreTrainedModel):
    # """
    # Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`IdeficsDecoderLayer`]

    # Args:
    #     config: IdeficsConfig
    # """

    def __init__(self, config: IdeficsConfig, weights):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = IdeficsDecoupledPartialTPEmbedding(
            config=config,
            weights=weights,
        )

        self.image_size = config.vision_config.image_size
        self.vision_config = config.vision_config
        self.vision_model = IdeficsVisionTransformer(
            prefix="model.vision_model",
            config=config.vision_config,
            weights=weights,
        )

        # Perceiver Resampler
        if config.use_resampler:
            perceiver_config = config.perceiver_config
            self.perceiver_resampler = IdeficsPerceiverResampler(
                prefix=f"model.perceiver_resampler",
                config=config,
                embed_dim=config.vision_config.embed_dim,
                depth=perceiver_config.resampler_depth,
                n_heads=perceiver_config.resampler_n_heads,
                head_dim=perceiver_config.resampler_head_dim,
                n_latents=perceiver_config.resampler_n_latents,
                weights=weights,
            )

        self.layers = nn.ModuleList(
            [
                IdeficsDecoderLayer(layer_id, config, weights)
                for layer_id in range(config.num_hidden_layers)
            ]
        )

        self.cross_layer_interval = config.cross_layer_interval
        num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
        self.gated_cross_attn_layers = nn.ModuleList(
            [
                IdeficsGatedCrossAttentionLayer(layer_id, config, weights)
                for layer_id in range(num_cross_layers)
            ]
        )
        # self.gradient_checkpointing = False

        self.norm = IdeficsRMSNorm(
            prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps
        )

        # self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        # self.post_init()

        # self.freeze_relevant_params(config)

    # def freeze_relevant_params(self, config=None):
    #     if config is None:
    #         config = self.config

    #     if config.freeze_text_layers:
    #         self.freeze_text_layers(config.freeze_text_module_exceptions)

    #     if config.freeze_vision_layers:
    #         freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)

    # def freeze_text_layers(self, module_exceptions=[]):
    #     for module in [self.layers, self.norm]:
    #         freeze_model(module, module_exceptions=module_exceptions)

    # def freeze_vision_layers(self, module_exceptions=[]):
    #     freeze_model(self.vision_model, module_exceptions=module_exceptions)

    # def get_input_embeddings(self):
    #     return self.embed_tokens

    # def set_input_embeddings(self, value):
    #     self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        image_embeddings: Optional[torch.FloatTensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastImage]:
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
        elif position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        no_images = False

        if image_hidden_states is None:
            if pixel_values is None and image_embeddings is None:
                raise ValueError(
                    "Either pixel_values or image_embeddings has to be provided (not None)."
                )
            elif pixel_values is not None and image_embeddings is not None:
                raise ValueError(
                    "You cannot specify both pixel_values and image_embeddings at the same time"
                )
            elif pixel_values is not None:
                no_images = len(torch.nonzero(pixel_values)) == 0
                pixel_values = pixel_values.to(
                    dtype=self.dtype, device=device
                )  # fp16 compatibility
                batch_size, num_images = pixel_values.shape[:2]
                pixel_values = pixel_values.contiguous().view(
                    batch_size * num_images, *pixel_values.shape[2:]
                )

                # Get sequence from the vision encoder
                image_hidden_states = self.vision_model(
                    pixel_values=pixel_values
                ).last_hidden_state
            elif image_embeddings is not None:
                (
                    batch_size,
                    num_images,
                    image_seq_len,
                    image_hidden_size,
                ) = image_embeddings.size()
                image_hidden_states = image_embeddings.to(
                    dtype=self.dtype, device=input_ids.device
                )
                image_hidden_states = image_hidden_states.view(
                    batch_size * num_images, image_seq_len, image_hidden_size
                )

            if self.config.use_resampler:
                image_hidden_states = self.perceiver_resampler(image_hidden_states)
            image_seq_len, image_hidden_size = image_hidden_states.size(
                1
            ), image_hidden_states.size(2)
            image_hidden_states = image_hidden_states.view(
                batch_size, num_images * image_seq_len, image_hidden_size
            )
        else:
            no_images = False
            num_images = pixel_values.shape[1]
            image_seq_len = image_hidden_states.shape[1] // num_images

        # # Hack to use the model in full language modeling mode
        # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
        # Make image_attention_mask compatible with hidden states
        text_seq_len = image_attention_mask.size(1)
        image_attention_mask = image_attention_mask.unsqueeze(-1)
        image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
        image_attention_mask = image_attention_mask.view(
            batch_size, text_seq_len, num_images * image_seq_len
        )
        image_batch_size, image_sequence_length, _ = image_hidden_states.size()
        image_hidden_shape = (image_batch_size, image_sequence_length)
        if image_attention_mask is None:
            image_attention_mask = torch.ones(image_hidden_shape, device=device)
        image_attention_mask = self.invert_attention_mask(image_attention_mask)

        # if list(image_attention_mask.shape) != [4, 1, 1024, 64]:
        #     raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}")

        # if image_hidden_states is not None:
        # else:
        #     image_attention_mask = None

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        # if self.gradient_checkpointing and self.training:
        #     if use_cache:
        #         logger.warning_once(
        #             "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
        #         )
        #         use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            def vblock(
                main_block,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                image_hidden_states,
                image_attention_mask,
                output_attentions,
                use_cache,
                no_images,
                layer_idx,
                cross_layer_interval,
                gated_cross_attn_layers,
            ):
                # TODO(ls): Add cross attention values to respective lists
                if layer_idx % cross_layer_interval == 0:
                    xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
                    outputs = xblock(
                        hidden_states,
                        attention_mask=attention_mask,
                        image_hidden_states=image_hidden_states,
                        image_attention_mask=image_attention_mask,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        past_key_value=None,  # not implemented
                        no_images=no_images,
                    )
                    hidden_states = outputs[0]

                layer_outputs = main_block(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

                return layer_outputs

            # if self.gradient_checkpointing and self.training:
            #     past_key_value = None
            #     if use_cache:
            #         logger.warning_once(
            #             "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            #         )
            #         use_cache = False

            #     layer_outputs = torch.utils.checkpoint.checkpoint(
            #         vblock,
            #         decoder_layer,
            #         hidden_states,
            #         attention_mask,
            #         position_ids,
            #         past_key_value,
            #         image_hidden_states,
            #         image_attention_mask,
            #         output_attentions,
            #         use_cache,
            #         no_images,
            #         idx,
            #         self.cross_layer_interval,
            #         self.gated_cross_attn_layers,
            #     )
            # else:
            layer_outputs = vblock(
                decoder_layer,
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                image_hidden_states=image_hidden_states,
                image_attention_mask=image_attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
                no_images=no_images,
                layer_idx=idx,
                cross_layer_interval=self.cross_layer_interval,
                gated_cross_attn_layers=self.gated_cross_attn_layers,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPastImage(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            image_hidden_states=image_hidden_states,
        )
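
# Editorial note on the decoder loop above: gated cross-attention blocks are interleaved
# every `cross_layer_interval` decoder layers, so decoder layer `idx` is preceded by
# gated_cross_attn_layers[idx // cross_layer_interval] whenever
# idx % cross_layer_interval == 0; for example, with 32 decoder layers and an interval
# of 4 there are 8 cross-attention blocks, attached before layers 0, 4, 8, ..., 28.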


class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
    def __init__(
        self,
        config,
        weights,
    ):
        super().__init__(config)
        self.model = IdeficsModel(
            config=config,
            weights=weights,
        )

        self.lm_head = IdeficsDecoupledTensorParallelLinear(
            config=config,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_embeddings: Optional[torch.FloatTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPastImage]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_embeddings=image_embeddings,
            image_hidden_states=image_hidden_states,
            image_attention_mask=image_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits, speculative_logits = self.lm_head(hidden_states)
        loss = None

        return (
            CausalLMOutputWithPastImage(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                image_hidden_states=outputs.image_hidden_states,
            ),
            speculative_logits,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
        unwanted_kwargs = ["token_type_ids"]
        for kwarg in unwanted_kwargs:
            inputs.pop(kwarg, None)
        return inputs

    @staticmethod
    def _expand_inputs_for_generation(
        *args,
        **model_kwargs,
    ):
        return expand_inputs_for_generation(*args, **model_kwargs)

    @staticmethod
    def _update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=False
    ):
        return update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx) for past_state in layer_past
                ),
            )
        return reordered_past