mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
1208 lines
46 KiB
Python
1208 lines
46 KiB
Python
# coding=utf-8
|
|
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" PyTorch T5 model."""
|
|
|
|
import copy
|
|
import math
|
|
import warnings
|
|
from typing import Optional, Tuple, Union
|
|
|
|
from loguru import logger
|
|
|
|
import torch
|
|
import torch.distributed
|
|
from torch import nn
|
|
from torch.nn import CrossEntropyLoss
|
|
|
|
from transformers.activations import ACT2FN
|
|
from transformers.modeling_outputs import (
|
|
BaseModelOutput,
|
|
BaseModelOutputWithPastAndCrossAttentions,
|
|
Seq2SeqLMOutput,
|
|
)
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
|
from transformers.utils import (
|
|
is_torch_fx_proxy,
|
|
)
|
|
from transformers import T5Config
|
|
from text_generation_server.utils.layers import (
|
|
TensorParallelColumnLinear,
|
|
TensorParallelEmbedding,
|
|
TensorParallelRowLinear,
|
|
TensorParallelHead,
|
|
)
|
|
|
|
|
|
class PartialTPEmbedding(nn.Module):
    """Embedding lookup over a table sharded along dim 1 (its columns).

    The local shard is fetched with ``weights.get_sharded(f"{prefix}.weight", dim=1)``,
    so each tensor-parallel rank looks up full rows but only its slice of columns.
    In this file it backs T5's relative attention bias, where the columns are heads.
    """

    def __init__(self, prefix: str, weights):
        super().__init__()
        self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=1))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Plain row gather: no padding index, no max-norm renormalisation.
        return nn.functional.embedding(input, self.weight)
|
|
|
|
|
|
@torch.jit.script
def layer_norm(hidden_states, weight, epsilon):
    # T5-style RMS normalisation (https://arxiv.org/abs/1910.07467): scale only,
    # no mean subtraction and no bias. The mean of squares is accumulated in
    # float32 so that half-precision inputs do not lose accuracy.
    mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normalized = hidden_states * torch.rsqrt(mean_square + epsilon)

    # Bring the result back to the weight's half-precision dtype before scaling.
    is_half_weight = weight.dtype == torch.float16 or weight.dtype == torch.bfloat16
    if is_half_weight:
        normalized = normalized.to(weight.dtype)

    return weight * normalized
|
|
|
|
|
|
class T5LayerNorm(nn.Module):
    def __init__(self, prefix, weights, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.

        The scale vector is read from ``{prefix}.weight`` via ``weights.get_tensor``;
        ``eps`` is added to the mean of squares before the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"))
        # Kept as a tensor: the scripted ``layer_norm`` declares untyped (hence
        # Tensor) arguments.
        self.variance_epsilon = torch.tensor(eps)

    def forward(self, hidden_states):
        """Apply RMS normalisation followed by the learned per-channel scale."""
        return layer_norm(hidden_states, self.weight, self.variance_epsilon)
|
|
|
|
|
|
# Upstream ``transformers`` replaces ``T5LayerNorm`` with apex's ``FusedRMSNorm``
# when apex is installed. That swap is not possible in this file:
# ``FusedRMSNorm`` is constructed as ``FusedRMSNorm(normalized_shape, eps, ...)``,
# whereas every call site here builds ``T5LayerNorm(prefix=..., weights=..., eps=...)``
# so the scale can be loaded from the checkpoint. Rebinding the name made model
# construction fail with a TypeError whenever apex was importable. The reference
# implementation above is therefore always used, and apex is only reported.
try:
    from apex.normalization import FusedRMSNorm  # noqa: F401

    logger.info(
        "Discovered apex.normalization.FusedRMSNorm, but it is not used: its "
        "constructor is incompatible with the weight-loading T5LayerNorm"
    )
except ImportError:
    # apex is not installed: nothing to report, T5LayerNorm is used as is.
    pass
except Exception:
    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")

# Register the norm class so transformers utilities that consult
# ALL_LAYERNORM_LAYERS treat it as a layer-norm module.
ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
|
|
|
|
|
|
class T5DenseActDense(nn.Module):
    """Non-gated T5 feed-forward block: ``wo(dropout(act(wi(x))))``.

    ``wo`` is always loaded unquantized with float32 weights, and its input is cast
    to float32. The result is intentionally not cast back to the model dtype (see
    the notes in ``forward``).
    """

    def __init__(self, config: T5Config, prefix, weights):
        super().__init__()
        self.wi = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.wi", weights=weights, bias=False
        )

        ### XXX: T5 models do not handle well both f16 and quantization.
        ### Overriding specifically this layer for that reason.
        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
        ### https://github.com/huggingface/transformers/issues/20287
        _q = config.quantize
        _dtype = weights.dtype
        weights.dtype = torch.float32
        config.quantize = None
        self.wo_cast = (torch.float32, _dtype)
        try:
            self.wo = TensorParallelRowLinear.load(
                config, prefix=f"{prefix}.wo", weights=weights, bias=False
            )
        finally:
            # ``weights`` and ``config`` are caller-owned objects reused for other
            # layers. Undo the temporary overrides even when loading fails, so
            # they are never left forcing float32 / no quantization.
            weights.dtype = _dtype
            config.quantize = _q

        self.dropout = nn.Dropout(config.dropout_rate)
        # Any "gelu*" activation is evaluated with the tanh approximation; every
        # other activation name is looked up in transformers' ACT2FN table.
        self.act = (
            ACT2FN[config.dense_act_fn]
            if "gelu" not in config.dense_act_fn
            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
        )

    def forward(self, hidden_states):
        """Project up with ``wi``, activate, apply dropout, project down with ``wo`` in float32."""
        hidden_states = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
        hidden_states = self.wo(hidden_states)
        # XXX: Recasting is already done within the layer norm.
        # Casting back to float16 here modifies results
        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
        return hidden_states
|
|
|
|
|
|
class T5DenseGatedActDense(nn.Module):
    """Gated T5 feed-forward block: ``wo(dropout(act(wi_0(x)) * wi_1(x)))``.

    ``wo`` is always loaded unquantized with float32 weights, and its input is cast
    to float32. The result is intentionally not cast back to the model dtype (see
    the notes in ``forward``).
    """

    def __init__(self, config: T5Config, prefix, weights):
        super().__init__()
        self.wi_0 = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.wi_0", weights=weights, bias=False
        )
        self.wi_1 = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.wi_1", weights=weights, bias=False
        )
        ### XXX: T5 models do not handle well both f16 and quantization.
        ### Overriding specifically this layer for that reason.
        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
        ### https://github.com/huggingface/transformers/issues/20287
        _q = config.quantize
        _dtype = weights.dtype
        weights.dtype = torch.float32
        config.quantize = None
        self.wo_cast = (torch.float32, _dtype)
        try:
            self.wo = TensorParallelRowLinear.load(
                config, prefix=f"{prefix}.wo", weights=weights, bias=False
            )
        finally:
            # ``weights`` and ``config`` are caller-owned objects reused for other
            # layers. Undo the temporary overrides even when loading fails, so
            # they are never left forcing float32 / no quantization.
            weights.dtype = _dtype
            config.quantize = _q

        self.dropout = nn.Dropout(config.dropout_rate)
        # Any "gelu*" activation is evaluated with the tanh approximation; every
        # other activation name is looked up in transformers' ACT2FN table.
        self.act = (
            ACT2FN[config.dense_act_fn]
            if "gelu" not in config.dense_act_fn
            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
        )

    def forward(self, hidden_states):
        """Gate ``act(wi_0(x))`` against ``wi_1(x)``, apply dropout, project down with ``wo`` in float32."""
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
        hidden_states = self.wo(hidden_states)
        # XXX: Recasting is already done within the layer norm.
        # Casting back to float16 here modifies results
        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])
        return hidden_states
|
|
|
|
|
|
class T5LayerFF(nn.Module):
    """Pre-norm residual feed-forward sub-layer: ``x + dropout(ff(layer_norm(x)))``.

    ``ff`` is the gated variant when ``config.is_gated_act`` is set, otherwise the
    plain dense-activation-dense variant.
    """

    def __init__(self, config: T5Config, prefix, weights):
        super().__init__()
        ff_class = T5DenseGatedActDense if config.is_gated_act else T5DenseActDense
        self.DenseReluDense = ff_class(
            config, prefix=f"{prefix}.DenseReluDense", weights=weights
        )

        self.layer_norm = T5LayerNorm(
            prefix=f"{prefix}.layer_norm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        """Add the dropped-out feed-forward output of the normed input to the residual stream."""
        ff_output = self.DenseReluDense(self.layer_norm(hidden_states))
        return hidden_states + self.dropout(ff_output)
|
|
|
|
|
|
class T5Attention(nn.Module):
    """Tensor-parallel T5 attention, used both as self-attention and cross-attention.

    ``q``/``k``/``v`` are column-parallel and ``o`` is row-parallel over
    ``weights.process_group``. After construction, ``n_heads`` and ``inner_dim``
    therefore describe only this rank's shard of heads. When
    ``has_relative_attention_bias`` is set, the relative position bias table is
    also loaded, sharded along its second dimension.

    Raises:
        ValueError: if ``config.num_heads`` is not divisible by the process-group size.
    """

    def __init__(
        self, config: T5Config, prefix, weights, has_relative_attention_bias=False
    ):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        process_group = weights.process_group
        num_shards = process_group.size()
        # Validate before loading any weights. Use a real exception with a
        # descriptive message; a bare ``assert`` would vanish under ``python -O``.
        if self.n_heads % num_shards != 0:
            raise ValueError(
                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
                f"and `num_shards`: {num_shards})"
            )

        self.q = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.q", weights=weights, bias=False
        )
        self.k = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.k", weights=weights, bias=False
        )
        self.v = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.v", weights=weights, bias=False
        )
        self.o = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.o", weights=weights, bias=False
        )

        # From here on, head counts/dimensions refer to this rank's shard only.
        self.n_heads = self.n_heads // num_shards
        self.inner_dim = self.inner_dim // num_shards

        if self.has_relative_attention_bias:
            self.relative_attention_bias = PartialTPEmbedding(
                prefix=f"{prefix}.relative_attention_bias", weights=weights
            )

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets are reserved for positive offsets.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(
            is_small, relative_position, relative_position_if_large
        )
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias of shape (1, n_heads, query_length, key_length)."""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[
            :, None
        ]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
            None, :
        ]
        relative_position = (
            memory_position - context_position
        )  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0
        )  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Returns ``(attn_output, present_key_value_state, position_bias)``, followed by
        ``attn_weights`` when ``output_attentions`` is true.
        ``present_key_value_state`` is ``(key_states, value_states)`` only for a
        decoder called with ``use_cache``; otherwise it is ``None``. The returned
        ``position_bias`` already includes ``mask`` when it was computed here.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
            real_seq_length += (
                past_key_value[0].shape[2] if query_length is None else query_length
            )

        key_length = (
            real_seq_length if key_value_states is None else key_value_states.shape[1]
        )

        def shape(states):
            """projection"""
            return states.view(
                batch_size, -1, self.n_heads, self.key_value_proj_dim
            ).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return (
                states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
            )

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                    # the provided `key_value_states` to support prefix tuning
                    # cross-attn
                    # (batch_size, n_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(key_value_states))
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(
            self.q(hidden_states)
        )  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states,
            self.k,
            key_value_states,
            past_key_value[0] if past_key_value is not None else None,
        )
        value_states = project(
            hidden_states,
            self.v,
            key_value_states,
            past_key_value[1] if past_key_value is not None else None,
        )

        # compute scores. No 1/sqrt(d_kv) scaling: T5 follows the Mesh TensorFlow
        # initialization, which avoids scaling before softmax.
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length),
                    device=scores.device,
                    dtype=scores.dtype,
                )
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device
                )

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                position_bias = (
                    position_bias + mask
                )  # (batch_size, n_heads, seq_length, key_length)

        scores += position_bias
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(
            torch.matmul(attn_weights, value_states)
        )  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (
            (key_states, value_states) if (self.is_decoder and use_cache) else None
        )
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
|
|
|
|
|
|
class T5LayerSelfAttention(nn.Module):
    """Pre-norm residual self-attention sub-layer.

    ``forward`` returns ``(x + dropout(attn(layer_norm(x))), *extras)``, where
    ``extras`` are the remaining outputs of :class:`T5Attention` (present
    key/value state, position bias, optionally attention weights).
    """

    def __init__(self, config, prefix, weights, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = T5Attention(
            config,
            prefix=f"{prefix}.SelfAttention",
            weights=weights,
            has_relative_attention_bias=has_relative_attention_bias,
        )
        self.layer_norm = T5LayerNorm(
            prefix=f"{prefix}.layer_norm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        attn_output, *extras = self.SelfAttention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        residual = hidden_states + self.dropout(attn_output)
        # Pass through present key/values, position bias (and weights, if requested).
        return (residual, *extras)
|
|
|
|
|
|
class T5LayerCrossAttention(nn.Module):
    """Pre-norm residual encoder-decoder (cross) attention sub-layer.

    Queries come from the normed decoder states; keys/values come from
    ``key_value_states``. ``forward`` returns ``(x + dropout(attn), *extras)``,
    where ``extras`` are the remaining :class:`T5Attention` outputs.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        self.EncDecAttention = T5Attention(
            config,
            prefix=f"{prefix}.EncDecAttention",
            weights=weights,
            has_relative_attention_bias=False,
        )
        self.layer_norm = T5LayerNorm(
            prefix=f"{prefix}.layer_norm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        attn_output, *extras = self.EncDecAttention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(attn_output)
        # Pass through present key/values, position bias (and weights, if requested).
        return (layer_output, *extras)
|
|
|
|
|
|
class T5Block(nn.Module):
    """One T5 layer: self-attention, then (decoder only) cross-attention, then feed-forward.

    Sub-layers are stored in ``self.layer`` and loaded from checkpoint prefixes:
    - encoder: ``layer.0`` (self-attention), ``layer.1`` (feed-forward);
    - decoder: ``layer.0``, ``layer.1`` (cross-attention), ``layer.2`` (feed-forward).
    """

    def __init__(self, config, prefix, weights, has_relative_attention_bias: bool):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(
            T5LayerSelfAttention(
                config,
                prefix=f"{prefix}.layer.0",
                weights=weights,
                has_relative_attention_bias=has_relative_attention_bias,
            )
        )
        if self.is_decoder:
            self.layer.append(
                T5LayerCrossAttention(
                    config, prefix=f"{prefix}.layer.1", weights=weights
                )
            )

        # The feed-forward always comes last, right after the attention sub-layers.
        ff_index = len(self.layer)
        self.layer.append(
            T5LayerFF(config, prefix=f"{prefix}.layer.{ff_index}", weights=weights)
        )

    @staticmethod
    def _clamp_inf_fp16(hidden_states):
        """Clamp float16 activations into the finite range (no-op for other dtypes)."""
        # clamp inf values to enable fp16 training
        if hidden_states.dtype != torch.float16:
            return hidden_states
        finfo = torch.finfo(hidden_states.dtype)
        clamp_value = torch.where(
            torch.isinf(hidden_states).any(),
            finfo.max - 1000,
            finfo.max,
        )
        return torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):
        """Run the block.

        Returns a tuple: hidden-states, present_key_value_states (only when
        ``use_cache``), self-attention position bias, (self-attention weights),
        then for cross-attention: cross-attention position bias,
        (cross-attention weights).

        Raises:
            ValueError: if ``past_key_value`` has the wrong number of entries.
        """
        self_attn_past_key_value = None
        cross_attn_past_key_value = None
        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning(
                    "`past_key_values` is passed to the encoder. Please make sure this is intended."
                )
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            # First pair is self-attention, the optional second pair is cross-attention.
            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = self._clamp_inf_fp16(self_attention_outputs[0])
        present_key_value_state = self_attention_outputs[1]
        # Keep self-attention outputs and relative position weights
        attention_outputs = self_attention_outputs[2:]

        if self.is_decoder and encoder_hidden_states is not None:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            query_length = (
                None
                if present_key_value_state is None
                else present_key_value_state[0].shape[2]
            )

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = self._clamp_inf_fp16(cross_attention_outputs[0])

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = (
                    present_key_value_state + cross_attention_outputs[1]
                )

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self._clamp_inf_fp16(self.layer[-1](hidden_states))

        if not use_cache:
            return (hidden_states,) + attention_outputs
        return (hidden_states, present_key_value_state) + attention_outputs
|
|
|
|
|
|
class T5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config

    def _shift_right(self, input_ids):
        """Build decoder inputs by shifting ``input_ids`` one step right along the last axis.

        Position 0 becomes ``config.decoder_start_token_id`` and the last token is
        dropped. Any ``-100`` (ignored-label marker) is replaced by
        ``config.pad_token_id``.
        """
        start_token_id = self.config.decoder_start_token_id
        pad_id = self.config.pad_token_id

        assert start_token_id is not None, (
            "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
            " See T5 docs for more information"
        )

        if not is_torch_fx_proxy(input_ids):
            shifted = input_ids.new_zeros(input_ids.shape)
            shifted[..., 1:] = input_ids[..., :-1].clone()
            shifted[..., 0] = start_token_id
        else:
            # Item assignment is not supported natively for proxies.
            start_column = torch.full(input_ids.shape[:-1] + (1,), start_token_id)
            shifted = torch.cat([start_column, input_ids[..., :-1]], dim=-1)

        assert pad_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted.masked_fill_(shifted == -100, pad_id)
        return shifted
|
|
|
|
|
|
class T5Stack(T5PreTrainedModel):
    """A stack of :class:`T5Block` layers, acting as the T5 encoder or decoder.

    ``config.is_decoder`` selects the role. Only block 0 is built with a relative
    attention bias; the position biases it returns are fed to all later blocks.
    """

    def __init__(self, config, prefix, weights, embed_tokens):
        super().__init__(config)

        self.is_decoder = config.is_decoder

        self.embed_tokens = embed_tokens
        self.block = nn.ModuleList(
            [
                T5Block(
                    config,
                    prefix=f"{prefix}.block.{layer_id}",
                    weights=weights,
                    has_relative_attention_bias=(layer_id == 0),
                )
                for layer_id in range(config.num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(
            prefix=f"{prefix}.final_layer_norm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the stack over ``input_ids`` (or ``inputs_embeds``).

        Returns a :class:`BaseModelOutputWithPastAndCrossAttentions`. When
        ``return_dict`` is false, it instead returns a tuple of the non-``None``
        values among: last hidden state, present key/values, all hidden states,
        self-attentions, cross-attentions (in that order).

        Raises:
            ValueError: if both, or neither, of ``input_ids`` / ``inputs_embeds`` are given.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
            )

        if inputs_embeds is None:
            assert (
                self.embed_tokens is not None
            ), "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = (
            past_key_values[0][0].shape[2] + seq_length
            if past_key_values is not None
            else seq_length
        )

        if use_cache is True:
            assert (
                self.is_decoder
            ), f"`use_cache` can only be set to `True` if {self} is used as a decoder"

        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, mask_seq_length, device=inputs_embeds.device
            )
        if (
            self.is_decoder
            and encoder_attention_mask is None
            and encoder_hidden_states is not None
        ):
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size,
                encoder_seq_length,
                device=inputs_embeds.device,
                dtype=torch.long,
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape
        )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length].
        # ``encoder_attention_mask`` cannot be None here: it was defaulted above
        # under the same condition.
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(
            cross_attn_head_mask, self.config.num_layers
        )
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(
            zip(self.block, past_key_values)
        ):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            # Without a cache the block omits key-value-states; re-insert a None
            # placeholder so the indices below are the same in both cases.
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer store them
            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[
                    4 if output_attentions else 3
                ]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (
                    present_key_value_state,
                )

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
|
class T5ForConditionalGeneration(T5PreTrainedModel):
    """T5 encoder-decoder with a tensor-parallel language-modeling head.

    The encoder and decoder ``T5Stack`` instances share a single
    ``TensorParallelEmbedding`` loaded from the ``shared`` weight prefix; the
    vocabulary projection is a ``TensorParallelHead`` loaded from ``lm_head``.
    """

    # Emitted by ``forward`` when ``head_mask`` is given without
    # ``decoder_head_mask``. Stored as a single-underscore class attribute on
    # purpose: a double-underscore name such as ``__HEAD_MASK_WARNING_MSG``
    # written inside a class body is name-mangled by Python to
    # ``_T5ForConditionalGeneration__HEAD_MASK_WARNING_MSG``, so referencing a
    # module-level constant of that name raised ``NameError`` instead of warning.
    _HEAD_MASK_WARNING_MSG = (
        "The input argument `head_mask` was split into two arguments `head_mask` "
        "and `decoder_head_mask`. Currently, `decoder_head_mask` is set to copy "
        "`head_mask`, but this feature is deprecated and will be removed in future "
        "versions. If you do not want to use any `decoder_head_mask` now, please set "
        "`decoder_head_mask = torch.ones(num_layers, num_heads)`."
    )

    def __init__(self, config: T5Config, weights):
        """Build the shared embedding, the encoder/decoder stacks and the LM head.

        Args:
            config: Model configuration. It is deep-copied separately for the
                encoder and the decoder so their flags can differ.
            weights: Weight loader handed to every tensor-parallel layer
                (presumably the project's sharded-weights helper; not visible here).
        """
        super().__init__(config)
        self.model_dim = config.d_model

        # One embedding table shared by encoder and decoder.
        self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(
            config=encoder_config,
            prefix="encoder",
            weights=weights,
            embed_tokens=self.shared,
        )

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        # The decoder depth may differ from the encoder depth.
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(
            config=decoder_config,
            prefix="decoder",
            weights=weights,
            embed_tokens=self.shared,
        )

        self.lm_head = TensorParallelHead.load(
            config, prefix="lm_head", weights=weights
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Run the encoder (unless ``encoder_outputs`` is given), the decoder and the LM head.

        When ``labels`` is given and no decoder inputs are, the decoder inputs are
        derived from ``labels`` via ``self._shift_right``. A cross-entropy loss
        (ignoring label ``-100``) is computed when ``labels`` is given.

        Returns:
            A ``Seq2SeqLMOutput`` when ``return_dict`` is truthy. Otherwise a tuple
            ``(loss?, lm_logits, *decoder_outputs[1:], *encoder_outputs)``, where
            ``loss`` is present only when ``labels`` was provided.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(self._HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # Wrap caller-provided tuple outputs so attribute access works below.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if (
            labels is not None
            and decoder_input_ids is None
            and decoder_inputs_embeds is None
        ):
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            # The caller may pass ``encoder_outputs`` as a ModelOutput while asking
            # for tuple output; ``tuple + ModelOutput`` raises TypeError, so
            # flatten it to its non-None values first.
            if hasattr(encoder_outputs, "to_tuple"):
                encoder_outputs = encoder_outputs.to_tuple()
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Map generation-loop state to ``forward`` keyword arguments.

        ``input_ids`` are passed on as ``decoder_input_ids``. When a cache is
        present, only the last position is kept, because earlier positions are
        already represented in ``past_key_values``. Extra ``kwargs`` are ignored.
        """
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Return decoder input ids obtained by right-shifting ``labels``."""
        return self._shift_right(labels)

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder every cached key/value tensor to follow the selected beams.

        Args:
            past_key_values: Per-layer tuples of cached tensors, or ``None``.
            beam_idx: Indices of the beams to keep, applied along dim 0.

        Returns:
            The same nested-tuple structure with each tensor re-indexed, or
            ``past_key_values`` unchanged (after a warning) when it is ``None``.
        """
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning(
                "You might want to consider setting `use_cache=True` to speed up decoding"
            )
            return past_key_values

        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # Select the surviving beams along dim 0 of each cached tensor (the
            # batch/beam axis), moving the indices to that tensor's device.
            reordered_layer_past_states = tuple(
                layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
                for layer_past_state in layer_past_states
            )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (
                reordered_layer_past_states,
            )
        return reordered_decoder_past
|