# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model. """
import copy
import math
import warnings
from typing import Optional, Tuple, Union
from loguru import logger
import torch
import torch.distributed
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
    is_torch_fx_proxy,
)
from transformers import T5Config
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    SpeculativeHead,
)
# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316
# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
class PartialTPEmbedding(nn.Module):
    def __init__(self, prefix: str, weights):
        super().__init__()
        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
        self.weight = nn.Parameter(weight)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.embedding(input, self.weight)
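
# Illustrative sketch (not part of the original module): `get_sharded(..., dim=1)` splits
# the relative-attention-bias table along its second axis, so with an assumed table of
# shape (num_buckets, num_heads) each rank only keeps the columns for its local heads.
# Roughly, under those assumptions:
#
#   world_size = weights.process_group.size()
#   rank = weights.process_group.rank()
#   full = weights.get_tensor(f"{prefix}.weight")        # (num_buckets, num_heads)
#   block = full.shape[1] // world_size
#   shard = full[:, rank * block : (rank + 1) * block]   # local heads only
#
# The exact semantics of the sharding helper are an assumption here.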
@torch.jit.script
def layer_norm(hidden_states, weight, epsilon):
    # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
    # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
    # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
    # half-precision inputs is done in fp32
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + epsilon)

    # convert into half-precision if necessary
    if weight.dtype in [torch.float16, torch.bfloat16]:
        hidden_states = hidden_states.to(weight.dtype)

    return weight * hidden_states
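
# Illustrative sketch (assumption, not in the original file): for a single vector x the
# function above computes x / sqrt(mean(x**2) + eps) * weight, e.g.
#
#   x = torch.tensor([3.0, 4.0])
#   w = torch.ones(2)
#   eps = torch.tensor(1e-6)
#   layer_norm(x, w, eps)   # ~= x / sqrt((9 + 16) / 2) -> tensor([0.8485, 1.1314])
#
# Unlike nn.LayerNorm there is no mean subtraction and no bias term.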
class T5LayerNorm(nn.Module):
    def __init__(self, prefix, weights, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        weight = weights.get_tensor(f"{prefix}.weight")
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = torch.tensor(eps)

    def forward(self, hidden_states):
        return layer_norm(hidden_states, self.weight, self.variance_epsilon)


try:
    from apex.normalization import FusedRMSNorm

    T5LayerNorm = FusedRMSNorm  # noqa

    logger.info(
        "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm"
    )
except ImportError:
    # using the normal T5LayerNorm
    pass
except Exception:
    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
    pass

ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
class T5DenseActDense(nn.Module):
    def __init__(self, config: T5Config, prefix, weights):
        super().__init__()
        self.wi = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.wi", weights=weights, bias=False
        )

        ### XXX: T5 models do not handle f16 and quantization well at the same time.
        ### Overriding specifically this layer for that reason.
        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
        ### https://github.com/huggingface/transformers/issues/20287
        _q = config.quantize
        _dtype = weights.dtype
        weights.dtype = torch.float32
        config.quantize = None
        self.wo_cast = (torch.float32, _dtype)
        self.wo = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.wo", weights=weights, bias=False
        )
        weights.dtype = _dtype
        config.quantize = _q

        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = (
            ACT2FN[config.dense_act_fn]
            if "gelu" not in config.dense_act_fn
            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
        )

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
        hidden_states = self.wo(hidden_states)
        # XXX: Recasting is already done within the layer norm.
        # Casting back to float16 here modifies results
        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
        return hidden_states
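
# Illustrative sketch (assumption, not in the original file): the `wo_cast` trick above
# keeps the down-projection in fp32 even when the rest of the model runs in fp16/quantized,
# because T5's wide intermediate activations overflow half precision. Conceptually:
#
#   h = h.to(torch.float32)   # up-cast activations before the sensitive matmul
#   h = wo_fp32(h)            # `wo` was loaded in fp32 with quantization disabled
#   # the following T5LayerNorm re-casts to the weight dtype, so no explicit down-cast here
#
# `wo_fp32` is a hypothetical name used only for this illustration of `self.wo`.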
class T5DenseGatedActDense(nn.Module):
    def __init__(self, config: T5Config, prefix, weights):
        super().__init__()
        self.wi_0 = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.wi_0", weights=weights, bias=False
        )
        self.wi_1 = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.wi_1", weights=weights, bias=False
        )

        ### XXX: T5 models do not handle f16 and quantization well at the same time.
        ### Overriding specifically this layer for that reason.
        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
        ### https://github.com/huggingface/transformers/issues/20287
        _q = config.quantize
        _dtype = weights.dtype
        weights.dtype = torch.float32
        config.quantize = None
        self.wo_cast = (torch.float32, _dtype)
        self.wo = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.wo", weights=weights, bias=False
        )
        weights.dtype = _dtype
        config.quantize = _q

        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = (
            ACT2FN[config.dense_act_fn]
            if "gelu" not in config.dense_act_fn
            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
        )

    def forward(self, hidden_states):
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
        hidden_states = self.wo(hidden_states)
        # XXX: Recasting is already done within the layer norm.
        # Casting back to float16 here modifies results
        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
        return hidden_states
class T5LayerFF(nn.Module):
    def __init__(self, config: T5Config, prefix, weights):
        super().__init__()
        if config.is_gated_act:
            self.DenseReluDense = T5DenseGatedActDense(
                config, prefix=f"{prefix}.DenseReluDense", weights=weights
            )
        else:
            self.DenseReluDense = T5DenseActDense(
                config, prefix=f"{prefix}.DenseReluDense", weights=weights
            )

        self.layer_norm = T5LayerNorm(
            prefix=f"{prefix}.layer_norm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states
class T5Attention(nn.Module):
    def __init__(
        self, config: T5Config, prefix, weights, has_relative_attention_bias=False
    ):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        process_group = weights.process_group
        # Mesh TensorFlow initialization to avoid scaling before softmax
        assert self.n_heads % process_group.size() == 0
        self.q = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.q", weights=weights, bias=False
        )
        self.k = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.k", weights=weights, bias=False
        )
        self.v = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.v", weights=weights, bias=False
        )
        self.o = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.o", weights=weights, bias=False
        )

        if self.n_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
                f"and `num_shards`: {weights.process_group.size()})"
            )
        self.n_heads = self.n_heads // process_group.size()
        self.inner_dim = self.inner_dim // process_group.size()

        if self.has_relative_attention_bias:
            self.relative_attention_bias = PartialTPEmbedding(
                prefix=f"{prefix}.relative_attention_bias", weights=weights
            )
    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(
            is_small, relative_position, relative_position_if_large
        )
        return relative_buckets
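
    # Illustrative sketch (assumption, not in the original file): with the defaults
    # (bidirectional=True, num_buckets=32, max_distance=128) the 32 buckets split into
    # 16 for "key after query" and 16 for "key at or before query"; within each half,
    # offsets 0..7 get their own bucket and larger offsets are binned logarithmically
    # up to max_distance. For example:
    #
    #   rel = torch.arange(-3, 4)                       # offsets -3..3
    #   T5Attention._relative_position_bucket(rel)      # small offsets map 1:1,
    #                                                   # positive offsets land in the upper half
    #
    # so nearby tokens are distinguished exactly while distant ones share buckets.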
    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[
            :, None
        ]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
            None, :
        ]
        relative_position = (
            memory_position - context_position
        )  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0
        )  # shape (1, num_heads, query_length, key_length)
        return values
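
    # Illustrative sketch (assumption, not in the original file): for query_length=2 and
    # key_length=3 the broadcasted subtraction above yields
    #
    #   relative_position = [[ 0, 1, 2],
    #                        [-1, 0, 1]]        # memory_position - context_position
    #
    # which is bucketed, looked up in the (sharded) bias embedding, and permuted into a
    # (1, local_num_heads, 2, 3) tensor that is simply added to the attention scores.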
    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
            real_seq_length += (
                past_key_value[0].shape[2] if query_length is None else query_length
            )

        key_length = (
            real_seq_length if key_value_states is None else key_value_states.shape[1]
        )

        def shape(states):
            """projection"""
            return states.view(
                batch_size, -1, self.n_heads, self.key_value_proj_dim
            ).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return (
                states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
            )

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                    # the provided `key_value_states` to support prefix tuning
                    # cross-attn
                    # (batch_size, n_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(key_value_states))
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(
            self.q(hidden_states)
        )  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states,
            self.k,
            key_value_states,
            past_key_value[0] if past_key_value is not None else None,
        )
        value_states = project(
            hidden_states,
            self.v,
            key_value_states,
            past_key_value[1] if past_key_value is not None else None,
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length),
                    device=scores.device,
                    dtype=scores.dtype,
                )
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device
                )

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                position_bias = (
                    position_bias + mask
                )  # (batch_size, n_heads, seq_length, key_length)

        position_bias_masked = position_bias

        scores += position_bias_masked
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(
            torch.matmul(attn_weights, value_states)
        )  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (
            (key_states, value_states) if (self.is_decoder and use_cache) else None
        )
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, prefix, weights, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = T5Attention(
            config,
            prefix=f"{prefix}.SelfAttention",
            weights=weights,
            has_relative_attention_bias=has_relative_attention_bias,
        )
        self.layer_norm = T5LayerNorm(
            prefix=f"{prefix}.layer_norm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[
            1:
        ]  # add attentions if we output them
        return outputs
class T5LayerCrossAttention(nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.EncDecAttention = T5Attention(
            config,
            prefix=f"{prefix}.EncDecAttention",
            weights=weights,
            has_relative_attention_bias=False,
        )
        self.layer_norm = T5LayerNorm(
            prefix=f"{prefix}.layer_norm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[
            1:
        ]  # add attentions if we output them
        return outputs
class T5Block(nn.Module):
    def __init__(self, config, prefix, weights, has_relative_attention_bias: bool):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(
            T5LayerSelfAttention(
                config,
                prefix=f"{prefix}.layer.0",
                weights=weights,
                has_relative_attention_bias=has_relative_attention_bias,
            )
        )
        if self.is_decoder:
            i = 2
            self.layer.append(
                T5LayerCrossAttention(
                    config, prefix=f"{prefix}.layer.1", weights=weights
                )
            )
        else:
            i = 1

        self.layer.append(
            T5LayerFF(config, prefix=f"{prefix}.layer.{i}", weights=weights)
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):
        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning(
                    "`past_key_values` is passed to the encoder. Please make sure this is intended."
                )
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[
            2:
        ]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(
                hidden_states, min=-clamp_value, max=clamp_value
            )

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16:
                clamp_value = torch.where(
                    torch.isinf(hidden_states).any(),
                    torch.finfo(hidden_states.dtype).max - 1000,
                    torch.finfo(hidden_states.dtype).max,
                )
                hidden_states = torch.clamp(
                    hidden_states, min=-clamp_value, max=clamp_value
                )

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = (
                    present_key_value_state + cross_attention_outputs[1]
                )

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(
                hidden_states, min=-clamp_value, max=clamp_value
            )

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
class T5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert decoder_start_token_id is not None, (
            "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
            " See T5 docs for more information"
        )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(
                input_ids.shape[:-1] + (1,), decoder_start_token_id
            )
            shifted_input_ids = torch.cat(
                [shifted_input_ids, input_ids[..., :-1]], dim=-1
            )
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        assert (
            pad_token_id is not None
        ), "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
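
    # Illustrative sketch (assumption, not in the original file): with pad_token_id=0 and
    # decoder_start_token_id=0 (the usual T5 settings),
    #
    #   labels            = [[a, b, -100, -100]]
    #   _shift_right(...) = [[0, a, b, 0]]   # start token prepended, remaining -100 -> pad_token_id
    #
    # i.e. the decoder is fed the labels shifted one position to the right.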
class T5Stack(T5PreTrainedModel):
    def __init__(self, config, prefix, weights, embed_tokens):
        super().__init__(config)

        self.is_decoder = config.is_decoder

        self.embed_tokens = embed_tokens
        self.block = nn.ModuleList(
            [
                T5Block(
                    config,
                    prefix=f"{prefix}.block.{layer_id}",
                    weights=weights,
                    has_relative_attention_bias=(layer_id == 0),
                )
                for layer_id in range(config.num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(
            prefix=f"{prefix}.final_layer_norm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Model parallel
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
            )

        if inputs_embeds is None:
            assert (
                self.embed_tokens is not None
            ), "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = (
            past_key_values[0][0].shape[2] + seq_length
            if past_key_values is not None
            else seq_length
        )

        if use_cache is True:
            assert (
                self.is_decoder
            ), f"`use_cache` can only be set to `True` if {self} is used as a decoder"

        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, mask_seq_length, device=inputs_embeds.device
            )
        if (
            self.is_decoder
            and encoder_attention_mask is None
            and encoder_hidden_states is not None
        ):
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size,
                encoder_seq_length,
                device=inputs_embeds.device,
                dtype=torch.long,
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape
        )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            (
                encoder_batch_size,
                encoder_sequence_length,
                _,
            ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    encoder_hidden_shape, device=inputs_embeds.device
                )
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(
            cross_attn_head_mask, self.config.num_layers
        )
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(
            zip(self.block, past_key_values)
        ):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them
            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[
                    4 if output_attentions else 3
                ]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (
                    present_key_value_state,
                )

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
class T5ForConditionalGeneration(T5PreTrainedModel):
    def __init__(self, config: T5Config, weights):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(
            config=encoder_config,
            prefix="encoder",
            weights=weights,
            embed_tokens=self.shared,
        )

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(
            config=decoder_config,
            prefix="decoder",
            weights=weights,
            embed_tokens=self.shared,
        )

        try:
            self.lm_head = SpeculativeHead.load(
                config, prefix="lm_head", weights=weights
            )
        except RuntimeError:
            # Some models like t5-small were saved with shared weights unlike flan
            # Since they are declared as the same arch we have no choice but hope
            # that this is OK instead of using a proper flag.
            self.lm_head = SpeculativeHead.load(
                config, prefix="shared", weights=weights
            )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if (
            labels is not None
            and decoder_input_ids is None
            and decoder_inputs_embeds is None
        ):
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        logits, speculative_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return (
            Seq2SeqLMOutput(
                loss=loss,
                logits=logits,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            ),
            speculative_logits,
        )
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }
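
    # Illustrative sketch (assumption, not in the original file): once `past_key_values`
    # exists, only the most recent decoder token needs to be re-embedded, e.g.
    #
    #   decoder_input_ids = [[0, 37, 2354, 15]]          # full sequence so far
    #   prepare_inputs_for_generation(decoder_input_ids, past_key_values=past, ...)
    #   # -> {"decoder_input_ids": [[15]], "past_key_values": past, ...}
    #
    # because keys/values for earlier positions are already cached inside `past`.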
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past_key_values, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning(
                "You might want to consider setting `use_cache=True` to speed up decoding"
            )
            return past_key_values

        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(
                        0, beam_idx.to(layer_past_state.device)
                    ),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (
                reordered_layer_past_states,
            )
        return reordered_decoder_past
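
    # Illustrative sketch (assumption, not in the original file): during beam search the
    # cache must follow the surviving beams. With beam_idx = tensor([1, 1, 0]) each cached
    # (batch*beams, n_heads, seq_len, head_dim) tensor is gathered along dim 0 so that
    # beams 0 and 1 continue from old beam 1 and beam 2 continues from old beam 0, e.g.
    #
    #   cache = torch.randn(3, 8, 10, 64)
    #   cache.index_select(0, torch.tensor([1, 1, 0]))   # same shape, rows re-ordered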