Enable dbrx, remove some unused code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-03-19 03:16:41 -07:00
parent 2cde30de24
commit 2074d0516b
8 changed files with 12 additions and 4471 deletions

View File

@@ -77,6 +77,11 @@ class PositionRotaryEmbedding(nn.Module):
inv_freq = _create_inv_freq(dim, base, device)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if not hasattr(config, "max_position_embeddings") and hasattr(
config, "max_seq_len"
):
# handling for dbrx, whose config provides `max_seq_len` instead of `max_position_embeddings`
config.max_position_embeddings = config.max_seq_len
if rope_scaling is not None:
# `rope_type` is now standard in transformers, but some existing models
# have `type` instead.
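Note: a minimal sketch (illustrative only, not part of the diff) of what this fallback does. `DbrxLikeConfig` is a hypothetical stand-in, not the real transformers config class:

```python
# Hypothetical stand-in config: dbrx-style configs expose max_seq_len only.
class DbrxLikeConfig:
    max_seq_len = 32768


config = DbrxLikeConfig()
if not hasattr(config, "max_position_embeddings") and hasattr(config, "max_seq_len"):
    # same fallback as the patch above
    config.max_position_embeddings = config.max_seq_len

assert config.max_position_embeddings == 32768
```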

View File

@@ -286,16 +286,6 @@ class ModelType(enum.Enum):
"name": "Qwen 2.5 VL",
"url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
}
OPT = {
"type": "opt",
"name": "Opt",
"url": "https://huggingface.co/facebook/opt-6.7b",
}
T5 = {
"type": "t5",
"name": "T5",
"url": "https://huggingface.co/google/flan-t5-xxl",
}
GALACTICA = {
"type": "galactica",
"name": "Galactica",
@@ -306,16 +296,6 @@ class ModelType(enum.Enum):
"name": "SantaCoder",
"url": "https://huggingface.co/bigcode/santacoder",
}
BLOOM = {
"type": "bloom",
"name": "Bloom",
"url": "https://huggingface.co/bigscience/bloom-560m",
}
MPT = {
"type": "mpt",
"name": "Mpt",
"url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
}
GPT2 = {
"type": "gpt2",
"name": "Gpt2",

View File

@@ -43,9 +43,7 @@ from text_generation_server.layers.rotary import (
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
moe_kernels = None
from vllm_hpu_extension.ops import DynamicFusedMOE
class DbrxAttentionConfig(PretrainedConfig):
@@ -497,19 +495,15 @@ class BlockSparseMoE(nn.Module):
self.process_group = weights.process_group
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
for i in range(self.num_experts):
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i])
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i])
def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = moe_kernels.fused_moe(
x,
self.wv1,
self.w2,
router_logits,
self.top_k,
renormalize=self.moe_normalize_expert_weights,
inplace=True,
)
out = self.hpu_fused_moe(x, router_logits, self.top_k)
# Reduce sum
if self.process_group.size() > 1:

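Note: for readers comparing the removed `fused_moe` call with the new `DynamicFusedMOE` op, a minimal pure-PyTorch sketch of the top-k MoE computation both implement. This is illustrative only, assuming a SwiGLU expert layout with `wv1` holding the concatenated gate/up projections; it is not the HPU kernel:

```python
import torch
import torch.nn.functional as F


def moe_forward_reference(x, router_logits, wv1, w2, top_k, renormalize=True):
    """Naive top-k MoE, illustrative only.

    x:   (tokens, hidden)
    wv1: (experts, 2 * inter, hidden)  concatenated gate/up projections
    w2:  (experts, hidden, inter)      down projection
    """
    probs = F.softmax(router_logits, dim=-1)
    weights, experts = torch.topk(probs, top_k, dim=-1)  # (tokens, top_k)
    if renormalize:  # roughly what `renormalize=` controlled in the old call
        weights = weights / weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for w, e in zip(weights[t], experts[t]):
            gate, up = (wv1[e] @ x[t]).chunk(2)  # SwiGLU expert
            out[t] += w * (w2[e] @ (F.silu(gate) * up))
    return out
```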
View File

@@ -1,796 +0,0 @@
# coding=utf-8
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch GPTNeoX model."""
from typing import Optional, Tuple, Union
import os
import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
SpeculativeHead,
)
CUSTOM_KERNELS_ENABLED = False
if (
torch.cuda.is_available()
and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
):
try:
from custom_kernels import fused_attention_cuda
CUSTOM_KERNELS_ENABLED = True
except ImportError:
pass
def make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.ones(
(target_length, target_length + past_key_values_length),
dtype=torch.bool,
device=device,
)
mask = mask.triu(1 + past_key_values_length)
expanded_mask = mask.unsqueeze(0).expand(
batch_size, target_length, target_length + past_key_values_length
)
return expanded_mask
def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, tgt_length, src_length)
def prepare_attn_mask(
attention_mask: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
class GPTNeoXPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
class GPTNeoXAttention(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
# ??? TODO
# self.register_buffer(
# "bias",
# torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
# 1, 1, max_positions, max_positions
# ),
# )
# self.register_buffer("masked_bias", torch.tensor(-1e9))
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims,
config.max_position_embeddings,
base=config.rotary_emb_base,
)
self.rotary_emb.inv_freq = nn.Parameter(
weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
)
self.inv_norm_factor = 1.0 / torch.sqrt(
torch.tensor(self.head_size, dtype=torch.float32)
).to(torch.get_default_dtype())
if self.num_attention_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_attention_heads` must be divisible by `num_shards` "
f"(got `num_attention_heads`: {self.num_attention_heads} "
f"and `num_shards`: {weights.process_group.size()})"
)
self.num_attention_heads = (
self.num_attention_heads // weights.process_group.size()
)
self.query_key_value = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True
)
self.dense = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
def forward(
self,
hidden_states,
position_ids,
attention_mask,
head_mask=None,
layer_past=None,
use_cache=False,
output_attentions=False,
):
has_layer_past = layer_past is not None
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
qkv = self.query_key_value(hidden_states)
# [batch, seq_len, (num_heads * 3 * head_size)]
# --> [batch, seq_len, num_heads, 3 * head_size]
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3)
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
query, key, value = qkv.split(self.head_size, -1)
# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
# Compute rotary embeddings on rotary_ndims
query_rot = query[..., : self.rotary_ndims]
key_rot = key[..., : self.rotary_ndims]
query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len)
query[..., : self.rotary_ndims] = query_rot
key[..., : self.rotary_ndims] = key_rot
if CUSTOM_KERNELS_ENABLED:
attn_output, present, attn_weights = fused_attention_cuda.forward(
query,
key,
value,
layer_past,
attention_mask,
head_mask,
self.inv_norm_factor,
self.num_attention_heads,
use_cache,
)
else:
# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
# Compute attention
attn_output, attn_weights = self._attn(
query, key, value, attention_mask, head_mask
)
# Reshape outputs
attn_output = self._merge_heads(
attn_output, self.num_attention_heads, self.head_size
)
attn_output = self.dense(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
@classmethod
def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
# tensor: [bs, seq_len, hidden_size]
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(new_shape)
# -> [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3)
return tensor
@classmethod
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3).contiguous()
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(
tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
)
# -> [bs, seq_len, hidden_size]
return tensor
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
# compute causal mask from causal mask buffer
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
query = query.reshape(
batch_size * num_attention_heads, query_length, attn_head_size
)
key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros(
1,
dtype=query.dtype,
device=key.device,
).expand(batch_size * num_attention_heads, query_length, key_length)
attn_scores = torch.baddbmm(
attn_scores,
query,
key.transpose(1, 2),
beta=1.0,
alpha=self.inv_norm_factor,
)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attn_scores.dtype
if input_dtype in [torch.float16, torch.bfloat16]:
attn_scores = attn_scores.to(torch.float)
attn_scores = torch.where(
attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores
)
attn_scores = attn_scores.view(
batch_size, num_attention_heads, query_length, key_length
)
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
self.true_inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2).float().to(device) / dim)
)
self.register_buffer("inv_freq", self.true_inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
self.cos_cached = None
self.sin_cached = None
@staticmethod
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
@staticmethod
def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
t = torch.arange(
max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype
)
freqs = torch.einsum("i,j->ij", t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)
def forward(self, q, k, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if (
seq_len > self.max_seq_len_cached
or self.cos_cached is None
or self.sin_cached is None
):
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
self.cos_cached, self.sin_cached = self._create_cos_sin(
self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device
)
return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids)
@torch.jit.script
def rotary_forward(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
chunk_size = q.shape[-1] // 2
q1, q2 = q.split(chunk_size, -1)
q_rotated = torch.cat((-q2, q1), dim=-1)
k1, k2 = k.split(chunk_size, -1)
k_rotated = torch.cat((-k2, k1), dim=-1)
q_embed = (q * cos) + (q_rotated * sin)
k_embed = (k * cos) + (k_rotated * sin)
return q_embed, k_embed
class GPTNeoXMLP(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.act = (
ACT2FN[config.hidden_act]
if "gelu_fast" not in config.hidden_act
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
)
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
)
self.dense_4h_to_h = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
)
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
return hidden_states
class GPTNeoXLayer(nn.Module):
def __init__(self, layer_id, prefix: str, config, weights):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.layers.{layer_id}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.attention = GPTNeoXAttention(
config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights
)
self.mlp = GPTNeoXMLP(
config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights
)
def forward(
self,
hidden_states,
position_ids,
attention_mask=None,
head_mask=None,
use_cache=False,
layer_past=None,
output_attentions=False,
):
attention_layer_outputs = self.attention(
self.input_layernorm(hidden_states),
attention_mask=attention_mask,
position_ids=position_ids,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attention_layer_outputs[
0
] # output_attn: attn_output, present, (attn_weights)
outputs = attention_layer_outputs[1:]
if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
hidden_states = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
hidden_states = mlp_output + attn_output
if use_cache:
outputs = (
hidden_states,
) + outputs # hidden_states, present, (attn_weights)
else:
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
return outputs
class GPTNeoXModel(GPTNeoXPreTrainedModel):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.embed_in = TensorParallelEmbedding(
prefix=f"{prefix}.embed_in", weights=weights
)
self.layers = nn.ModuleList(
[
GPTNeoXLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
self.tp_world_size = weights.process_group.size()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids=None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * self.config.num_hidden_layers)
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_length, seq_length + past_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
hidden_states = inputs_embeds
# Attention mask.
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[-1]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), device=hidden_states.device
)
else:
attention_mask = attention_mask.to(hidden_states.device)
causal_mask = prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
assert self.num_attention_heads % self.tp_world_size == 0
block_size = self.num_attention_heads // self.tp_world_size
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = layer(
hidden_states,
position_ids=position_ids,
attention_mask=causal_mask,
head_mask=head_mask[i],
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.final_layer_norm(hidden_states)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_attentions]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, prefix: str, config, weights):
super().__init__(config)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = GPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
`past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config.is_decoder = True
>>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.gpt_neox(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
lm_logits, speculative_logits = self.embed_out(hidden_states)
lm_loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# we are doing next-token prediction; shift prediction scores and input ids by one
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output
return (
CausalLMOutputWithPast(
loss=lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
),
speculative_logits,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
input_shape = input_ids.shape
# cut decoder_input_ids if past is used
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
)
return model_inputs
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx)
for past_state in layer_past[:2]
)
+ layer_past[2:],
)
return reordered_past

View File

@@ -1,864 +0,0 @@
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OPT model."""
import random
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import OPTConfig
from text_generation_server.layers import (
FastLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
SpeculativeHead,
)
EPS = 1e-5
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full(
(tgt_len, tgt_len),
torch.tensor(torch.finfo(dtype).min, device=device),
device=device,
)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
class OPTLearnedPositionalEmbedding(nn.Module):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, prefix: str, weights):
super().__init__()
self.offset = 2
self.weight = nn.Parameter(
weights.get_tensor(
f"{prefix if prefix else ''}decoder.embed_positions.weight"
)
)
def forward(
self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
# create positions depending on attention_mask
positions = (
torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
).long() - 1
# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
return torch.nn.functional.embedding(positions + self.offset, self.weight)
class OPTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config,
prefix,
weights,
is_decoder: bool = False,
bias: bool = True,
process_group=None,
):
super().__init__()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
self.hidden_size = hidden_size
self.num_heads = num_heads
self.dropout = config.dropout
self.head_dim = hidden_size // num_heads
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
process_group = weights.process_group
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()})"
)
self.num_heads = self.num_heads // process_group.size()
self.hidden_size = self.hidden_size // process_group.size()
self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
)
self.k_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias
)
self.v_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias
)
self.out_proj = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attention_mask
)
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
if attn_weights.dtype == torch.float16:
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(torch.float16)
else:
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights_reshaped.view(
bsz * self.num_heads, tgt_len, src_len
)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
super().__init__()
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
weights=weights,
is_decoder=True,
bias=config.enable_bias,
)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.self_attn_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS
)
self.fc1 = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias
)
self.fc2 = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias
)
self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = (residual + hidden_states).view(hidden_states_shape)
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class OPTPreTrainedModel(PreTrainedModel):
config_class = OPTConfig
class OPTDecoder(OPTPreTrainedModel):
def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
prefix = prefix + "." if prefix else ""
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}decoder.embed_tokens", weights=weights
)
self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load(
config,
prefix=f"{prefix}decoder.project_out",
weights=weights,
bias=False,
)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load(
config,
prefix=f"{prefix}decoder.project_in",
weights=weights,
bias=False,
)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS
)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList(
[
OPTDecoderLayer(
layer_id,
prefix=f"{prefix}decoder.layers.{layer_id}",
config=config,
weights=weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
batch_size, mask_seq_length, device=inputs_embeds.device
)
causal_attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class OPTModel(OPTPreTrainedModel):
def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config)
self.decoder = OPTDecoder(prefix, config, weights)
# Initialize weights and apply final processing
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs
return BaseModelOutputWithPast(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
)
class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, prefix, config, weights):
super().__init__(config)
if not prefix and any(s.startswith("model") for s in weights.routing.keys()):
prefix = "model"
self.model = OPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens",
weights=weights,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
loss = None
return (
CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
),
speculative_logits,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx) for past_state in layer_past
),
)
return reordered_past

View File

@@ -1,336 +0,0 @@
# implementation of the PhiModel and PhiForCausalLM classes
import torch
import torch.distributed
import math
from torch import nn
from typing import Optional, List, Tuple
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
FastLinear,
)
# PhiConfig is the configuration class for the PhiModel.
class PhiConfig(PretrainedConfig):
def __init__(
self,
vocab_size=51200,
n_positions=2048,
n_embd=2560,
n_layer=32,
n_inner=None,
n_head=32,
rotary_dim=32,
layer_norm_epsilon=1e-5,
tie_word_embeddings=False,
pad_vocab_size_multiple=64,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
no_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.rotary_dim = rotary_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.tie_word_embeddings = tie_word_embeddings
self.pad_vocab_size_multiple = pad_vocab_size_multiple
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.no_bias = no_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# RotaryEmbedding is a class that implements the rotary embedding.
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)]
inv_freq_len = len(inv_freq)
inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
freqs = t.matmul(inv_freq)
self.sin = freqs.sin()
self.cos = freqs.cos()
def apply_rotary_emb_qkv(self, qkv, seqlen_offset):
b_size, seqlen, three, _, _headdim = qkv.shape
if three != 3:
raise Exception("unexpected shape for qkv")
_, rotary_dim = self.cos.shape
rotary_dim = rotary_dim * 2
q_rot = qkv[:, :, 0, :, :rotary_dim]
q_pass = qkv[:, :, 0, :, rotary_dim:]
k_rot = qkv[:, :, 1, :, :rotary_dim]
k_pass = qkv[:, :, 1, :, rotary_dim:]
q12 = torch.chunk(q_rot, 2, dim=-1)
k12 = torch.chunk(k_rot, 2, dim=-1)
q1, q2 = q12[0], q12[1]
k1, k2 = k12[0], k12[1]
c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
q_rot = torch.cat(
[
q1 * c - q2 * s,
q1 * s + q2 * c,
],
dim=-1,
)
k_rot = torch.cat(
[
k1 * c - k2 * s,
k1 * s + k2 * c,
],
dim=-1,
)
q = torch.cat([q_rot, q_pass], dim=-1)
k = torch.cat([k_rot, k_pass], dim=-1)
v = qkv[:, :, 2]
return q, k, v
# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm.
class PhiCausalLMHead(nn.Module):
def __init__(self, config, weights):
super().__init__()
self.ln = nn.LayerNorm.load(
prefix="lm_head.ln",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.linear = SpeculativeHead.load(
config=config, prefix="lm_head.linear", weights=weights
)
def forward(self, hidden_states):
hidden_states = self.ln(hidden_states)
hidden_states = self.linear(hidden_states)
return hidden_states
# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
class PhiMHA(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.Wqkv = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
)
self.out_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.out_proj",
weights=weights,
bias=not config.no_bias,
)
self.op_size = config.n_embd
self.head_dim = int(config.n_embd / config.n_head)
self.num_heads = config.n_head
self.rotary_emb = RotaryEmbedding(
config.rotary_dim,
config.n_positions,
)
self.softmax_scale = 1.0 / math.sqrt(self.head_dim)
def forward(
self,
hidden_states,
past_kv_cache,
attention_mask=None,
):
b_size, seq_len, _n_embd = hidden_states.shape
qkv = self.Wqkv(hidden_states)
qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim)
seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1]
q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset)
# if there is a kv_cache, then we need to concatenate
if past_kv_cache is not None:
prev_k, prev_v = past_kv_cache
k = torch.cat([prev_k, k], dim=1)
v = torch.cat([prev_v, v], dim=1)
past_kv_cache = [k, v]
attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale)
if attention_mask is not None:
seqlen_k = k.shape[1]
seqlen_q = q.shape[1]
causal_mask = torch.triu(
torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),
1,
)
attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
attn_output = (
attn_output.view((b_size, self.num_heads, seq_len, self.head_dim))
.transpose(1, 2)
.flatten(-2)
)
return self.out_proj(attn_output), past_kv_cache
# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.n_inner = config.n_inner
self.fc1 = FastLinear.load(
config=config,
prefix=f"{prefix}.fc1",
weights=weights,
bias=False,
)
self.fc2 = FastLinear.load(
config=config,
prefix=f"{prefix}.fc2",
weights=weights,
bias=False,
)
self.activation = torch.nn.functional.gelu
def forward(self, hidden_states):
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and a multi-layer perceptron.
class PhiBlock(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.layer_id = layer_id
self.layer_norm = nn.LayerNorm.load(
prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon
)
self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)
def forward(
self,
hidden_states,
kv_cache,
attention_mask,
):
residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
attn_outputs, past_kv_cache = self.mixer(
hidden_states, kv_cache, attention_mask
)
feed_forward_hidden_states = self.mlp(hidden_states)
out = attn_outputs + feed_forward_hidden_states + residual
return out, past_kv_cache
# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embd.wte", weights=weights
)
self.blocks = nn.ModuleList(
[
PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
for layer_id in range(config.n_layer)
]
)
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
hidden_states = self.embed_tokens(input_ids)
seq_len = hidden_states.shape[1]
mask = None if seq_len <= 1 else attention_mask
past_key_values = (
[None] * len(self.blocks) if past_key_values is None else past_key_values
)
for index, block in enumerate(self.blocks):
hidden_states, new_key_values = block(
hidden_states, past_key_values[index], mask
)
past_key_values[index] = new_key_values
return hidden_states, past_key_values
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = PhiModel(prefix, config, weights)
self.lm_head = PhiCausalLMHead(config, weights)
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
model_output = self.model(
input_ids, past_key_values, attention_mask, return_dict, use_cache
)
logits = self.lm_head(model_output[0])
loss = None
if labels is not None:
loss = nn.CrossEntropyLoss()(
logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1)
)
if not return_dict:
return (
((loss,) + (logits,) + model_output[1:])
if loss is not None
else (logits,) + model_output[1:]
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=model_output[1],
hidden_states=None,
attentions=None,
)