mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 06:12:07 +00:00
* Fix the bug * fix: run lints * fix: small syntax tweak --------- Co-authored-by: Sadra Barikbin <sadraqazvin1@yahoo.com>
858 lines
34 KiB
Python
858 lines
34 KiB
Python
# coding=utf-8
|
|
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" PyTorch OPT model."""
|
|
import random
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
|
|
from transformers.activations import ACT2FN
|
|
from transformers.modeling_outputs import (
|
|
BaseModelOutputWithPast,
|
|
CausalLMOutputWithPast,
|
|
)
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
from transformers import OPTConfig
|
|
from text_generation_server.layers import (
|
|
FastLinear,
|
|
TensorParallelColumnLinear,
|
|
TensorParallelEmbedding,
|
|
TensorParallelRowLinear,
|
|
SpeculativeHead,
|
|
)
|
|
|
|
# Epsilon handed to every `nn.LayerNorm.load(...)` call in this module.
EPS = 0.00001
|
|
|
|
|
|
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
|
def _make_causal_mask(
|
|
input_ids_shape: torch.Size,
|
|
dtype: torch.dtype,
|
|
device: torch.device,
|
|
past_key_values_length: int = 0,
|
|
):
|
|
"""
|
|
Make causal mask used for bi-directional self-attention.
|
|
"""
|
|
bsz, tgt_len = input_ids_shape
|
|
mask = torch.full(
|
|
(tgt_len, tgt_len),
|
|
torch.tensor(torch.finfo(dtype).min, device=device),
|
|
device=device,
|
|
)
|
|
mask_cond = torch.arange(mask.size(-1), device=device)
|
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
|
mask = mask.to(dtype)
|
|
|
|
if past_key_values_length > 0:
|
|
mask = torch.cat(
|
|
[
|
|
torch.zeros(
|
|
tgt_len, past_key_values_length, dtype=dtype, device=device
|
|
),
|
|
mask,
|
|
],
|
|
dim=-1,
|
|
)
|
|
return mask[None, None, :, :].expand(
|
|
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
|
)
|
|
|
|
|
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
|
"""
|
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
|
"""
|
|
bsz, src_len = mask.size()
|
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
|
|
|
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
|
|
|
inverted_mask = 1.0 - expanded_mask
|
|
|
|
return inverted_mask.masked_fill(
|
|
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
|
)
|
|
|
|
|
|
class OPTLearnedPositionalEmbedding(nn.Module):
    """
    Learned absolute position embeddings, read from the checkpoint tensor
    `[<prefix>.]decoder.embed_positions.weight`.
    """

    def __init__(self, prefix: str, weights):
        super().__init__()
        # Position ids are shifted by this amount before the table lookup.
        self.offset = 2
        scope = f"{prefix}." if prefix else ""
        table = weights.get_tensor(f"{scope}decoder.embed_positions.weight")
        self.weight = nn.Parameter(table)

    def forward(
        self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
    ):
        """
        Embed the positions implied by `attention_mask`.

        Args:
            attention_mask: `[bsz, seq_len]` mask where non-zero marks a real
                token. It is expected to cover cached and new positions alike,
                since the first `past_key_values_length` columns are dropped.
            past_key_values_length: number of leading (already cached)
                positions whose embeddings are not returned.

        Returns:
            Embeddings for the remaining `seq_len - past_key_values_length`
            positions of every row.
        """
        keep = attention_mask.long()
        # Running count of real tokens: the k-th real token gets position k-1,
        # every padding slot gets -1.
        position_ids = keep.cumsum(dim=1) * keep - 1
        position_ids = position_ids[:, past_key_values_length:]
        return nn.functional.embedding(position_ids + self.offset, self.weight)
|
|
|
|
|
|
class OPTAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    The q/k/v projections are column-parallel and the output projection is
    row-parallel. After construction, `num_heads` and `hidden_size` hold the
    per-rank (sharded) values.
    """

    def __init__(
        self,
        config,
        prefix,
        weights,
        is_decoder: bool = False,
        bias: bool = True,
        process_group=None,
    ):
        """
        Args:
            config: model config; `hidden_size`, `num_attention_heads` and
                `dropout` are read from it.
            prefix: checkpoint prefix of this module (`q_proj`, `k_proj`,
                `v_proj`, `out_proj` are loaded below it).
            weights: checkpoint loader; its `process_group` decides sharding.
            is_decoder: when True, `forward` returns the updated key/value cache.
            bias: whether the projections carry a bias.
            process_group: ignored; `weights.process_group` is always used.
                Kept for signature compatibility.

        Raises:
            ValueError: if `hidden_size` is not divisible by the number of
                heads, or the number of heads is not divisible by the number
                of shards.
        """
        super().__init__()
        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dropout = config.dropout
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        process_group = weights.process_group
        world_size = process_group.size()
        if self.num_heads % world_size != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {world_size})."
            )
        # Each rank owns an equal slice of the heads and of the projected width.
        self.num_heads = self.num_heads // world_size
        self.hidden_size = self.hidden_size // world_size

        self.q_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias
        )
        self.out_proj = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `[bsz, seq_len, num_heads * head_dim]` to `[bsz, num_heads, seq_len, head_dim]`."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: `[bsz, tgt_len, hidden]` query input (also the
                key/value input for self-attention).
            key_value_states: if given, keys/values are projected from it
                (cross-attention).
            past_key_value: cached `(key, value)` pair, each shaped
                `[bsz, num_heads, past_len, head_dim]`.
            attention_mask: additive mask of shape `[bsz, 1, tgt_len, src_len]`.
            layer_head_mask: multiplicative per-head mask of shape `[num_heads]`.
            output_attentions: also return the attention probabilities.

        Returns:
            `(attn_output, attn_weights, past_key_value)`. `attn_weights` is
            `[bsz, num_heads, tgt_len, src_len]` when `output_attentions` is
            set, otherwise None. `past_key_value` is the updated `(key, value)`
            pair when `is_decoder` is set, otherwise the input value.

        Raises:
            ValueError: if the scores, mask, head mask or output have an
                unexpected shape.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj (queries are pre-scaled instead of scaling the scores)
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention: new positions are appended after the cached ones
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # Return the full key/value states so the next call can extend them
            # (self-attention) or reuse them as-is (cross-attention).
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + attention_mask
            )
            # Clamp to the dtype's most negative finite value so masked scores
            # stay finite. The bound is created on the scores' device.
            attn_weights = torch.max(
                attn_weights,
                torch.tensor(
                    torch.finfo(attn_weights.dtype).min, device=attn_weights.device
                ),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
        if attn_weights.dtype == torch.float16:
            attn_weights = nn.functional.softmax(
                attn_weights, dim=-1, dtype=torch.float32
            ).to(torch.float16)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_weights = attn_weights_reshaped.view(
                bsz * self.num_heads, tgt_len, src_len
            )
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        attn_output = torch.bmm(attn_probs, value_states)

        expected_output_size = (bsz * self.num_heads, tgt_len, self.head_dim)
        if attn_output.size() != expected_output_size:
            raise ValueError(
                f"`attn_output` should be of size {expected_output_size}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
|
|
|
|
|
|
class OPTDecoderLayer(nn.Module):
    """One OPT block: self-attention, then a two-layer MLP.

    Each sub-block is wrapped in a residual connection. Its LayerNorm runs
    before the sub-block when `config.do_layer_norm_before` is set, and after
    the residual add otherwise.
    """

    def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
        super().__init__()
        self.process_group = weights.process_group
        self.hidden_size = config.hidden_size
        scope = f"{prefix}." if prefix else ""
        layer_prefix = f"{scope}decoder.layers.{layer_id}"

        self.self_attn = OPTAttention(
            config,
            prefix=f"{layer_prefix}.self_attn",
            weights=weights,
            is_decoder=True,
            bias=config.enable_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]

        self.self_attn_layer_norm = nn.LayerNorm.load(
            prefix=f"{layer_prefix}.self_attn_layer_norm", weights=weights, eps=EPS
        )
        self.fc1 = TensorParallelColumnLinear.load(
            config,
            prefix=f"{layer_prefix}.fc1",
            weights=weights,
            bias=config.enable_bias,
        )
        self.fc2 = TensorParallelRowLinear.load(
            config,
            prefix=f"{layer_prefix}.fc2",
            weights=weights,
            bias=config.enable_bias,
        )
        self.final_layer_norm = nn.LayerNorm.load(
            prefix=f"{layer_prefix}.final_layer_norm", weights=weights, eps=EPS
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Run one decoder block.

        Args:
            hidden_states (`torch.FloatTensor`): `(batch, seq_len, hidden_size)` input.
            attention_mask (`torch.FloatTensor`, *optional*): additive mask of shape
                `(batch, 1, tgt_len, src_len)`; blocked positions hold very large negative values.
            layer_head_mask (`torch.FloatTensor`, *optional*): per-head multiplier for this layer.
            output_attentions (`bool`, *optional*): append the attention probabilities to the outputs.
            use_cache (`bool`, *optional*): append the updated key/value cache to the outputs.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key and value projections.

        Returns:
            `(hidden_states,)`, followed by the attention weights if `output_attentions`,
            followed by the new key/value cache if `use_cache`.
        """
        pre_norm = self.do_layer_norm_before

        # Self-attention sub-block.
        shortcut = hidden_states
        attn_input = (
            self.self_attn_layer_norm(hidden_states) if pre_norm else hidden_states
        )
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        attn_out = nn.functional.dropout(
            attn_out, p=self.dropout, training=self.training
        )
        hidden_states = shortcut + attn_out
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward sub-block, computed on a flattened [tokens, hidden] view.
        original_shape = hidden_states.shape
        flat = hidden_states.reshape(-1, hidden_states.size(-1))
        mlp = self.final_layer_norm(flat) if pre_norm else flat
        mlp = self.fc2(self.activation_fn(self.fc1(mlp)))
        mlp = nn.functional.dropout(mlp, p=self.dropout, training=self.training)
        hidden_states = (flat + mlp).view(original_shape)
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (self_attn_weights,)
        if use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
|
|
|
|
|
|
class OPTPreTrainedModel(PreTrainedModel):
    """Shared base of the OPT modules in this file; declares `OPTConfig` as the
    `config_class` of the `transformers` `PreTrainedModel` it extends."""

    config_class = OPTConfig
|
|
|
|
|
|
class OPTDecoder(OPTPreTrainedModel):
    """Stack of `OPTDecoderLayer`s with token and learned position embeddings.

    It also holds optional `project_in`/`project_out` projections (when
    `word_embed_proj_dim != hidden_size`) and an optional final LayerNorm.
    """

    def __init__(self, prefix: str, config: OPTConfig, weights):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        # `OPTLearnedPositionalEmbedding` and `OPTDecoderLayer` append the "."
        # separator themselves, so they must get the raw `prefix`. Passing the
        # dotted form would yield "model..decoder..." names for a non-empty
        # prefix. The dotted form is used only for names built in this method.
        dotted = prefix + "." if prefix else ""

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{dotted}decoder.embed_tokens", weights=weights
        )
        self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = FastLinear.load(
                config,
                prefix=f"{dotted}decoder.project_out",
                weights=weights,
                bias=False,
            )
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = FastLinear.load(
                config,
                prefix=f"{dotted}decoder.project_in",
                weights=weights,
                bias=False,
            )
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm.load(
                prefix=f"{dotted}decoder.final_layer_norm", weights=weights, eps=EPS
            )
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList(
            [
                OPTDecoderLayer(layer_id, prefix, config, weights)
                for layer_id in range(config.num_hidden_layers)
            ]
        )

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        """Combine the causal mask (only needed when more than one new token is
        processed) with the expanded padding mask into one additive
        `[bsz, 1, tgt_seq_len, src_seq_len]` mask, or return None if neither applies."""
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, past_length + sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. When `past_key_values` is given it
                must also cover the cached positions. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                One entry per decoder layer, each a `(key, value)` pair of tensors of shape
                `(batch_size, num_heads, past_length, embed_size_per_head)`. Under tensor parallelism `num_heads`
                is the per-shard head count.

                Contains the pre-computed key and value states of the self-attention blocks, which can be used to
                speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)`
                instead of all `input_ids` of shape `(batch_size, sequence_length)`.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Raises:
            ValueError: if both or neither of `input_ids`/`inputs_embeds` are given, or `head_mask` does not
                have one row per layer.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape
        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )
        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values_length + seq_length

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, mask_seq_length, device=inputs_embeds.device
            )
        causal_attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        hidden_states = inputs_embeds + pos_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None and head_mask.size()[0] != len(self.layers):
            raise ValueError(
                f"The `head_mask` should be specified for {len(self.layers)} layers, but it is for"
                f" {head_mask.size()[0]}."
            )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # LayerDrop (see https://arxiv.org/abs/1909.11556): randomly skip whole
            # layers, in training only. The random draw is made only when training,
            # so inference does not consume the process-global `random` state.
            if self.training and random.uniform(0, 1) < self.layerdrop:
                continue

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            # Layer outputs are (hidden, [attn_weights], [cache]); the cache index
            # shifts by one when attention weights are present.
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
|
|
|
|
class OPTModel(OPTPreTrainedModel):
    """Wrapper that owns an `OPTDecoder` as `self.decoder` and forwards calls to it."""

    def __init__(self, prefix: str, config: OPTConfig, weights):
        super().__init__(config)
        self.decoder = OPTDecoder(prefix, config, weights)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Resolve unset flags from `self.config`, run the decoder, and return its
        outputs. When `return_dict` is set, they are re-packed as a
        `BaseModelOutputWithPast`; otherwise the decoder's tuple is returned."""
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = cfg.use_cache
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs

        return BaseModelOutputWithPast(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
        )
|
|
|
|
|
|
class OPTForCausalLM(OPTPreTrainedModel):
    """OPT decoder plus a (speculative) LM head loaded from the
    `decoder.embed_tokens` checkpoint tensor."""

    def __init__(self, prefix, config, weights):
        super().__init__(config)

        self.model = OPTModel(prefix, config, weights)

        # The LM head is loaded from the token-embedding tensor.
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and project its last hidden state with `self.lm_head`.

        `labels` and `return_dict` are accepted for API compatibility only.
        No loss is computed (`loss` is always None), and the result is always a
        2-tuple `(CausalLMOutputWithPast, speculative_logits)`, where
        `speculative_logits` is the second value returned by `self.lm_head`.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        # Always ask the decoder for a `BaseModelOutputWithPast`. The attributes
        # read below (`last_hidden_state`, `past_key_values`, ...) do not exist
        # on the plain tuple it returns for `return_dict=False`, which used to
        # raise AttributeError here.
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)

        loss = None

        return (
            CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            ),
            speculative_logits,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Build the keyword arguments for the next `forward` call during generation.

        With a cache, only the last token of `input_ids` is fed. `inputs_embeds`
        is used only on the first step (when there is no cache yet).
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select cache rows along the batch dimension (dim 0) according to
        `beam_idx`, returning a tuple of per-layer tuples.

        `beam_idx` is moved to each cache tensor's device first, so caches spread
        over several devices can be reordered.
        """
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )
|