mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
1206 lines
46 KiB
Python
1206 lines
46 KiB
Python
"""A simple, flexible implementation of a GPT model.
|
|
|
|
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
|
|
"""
|
|
|
|
import math
|
|
import os
|
|
import warnings
|
|
from typing import List, Optional, Tuple, Union
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
from transformers.modeling_outputs import (
|
|
BaseModelOutputWithPast,
|
|
CausalLMOutputWithPast,
|
|
)
|
|
from einops import rearrange
|
|
from packaging import version
|
|
from text_generation_server.utils.layers import (
|
|
TensorParallelEmbedding,
|
|
TensorParallelColumnLinear,
|
|
TensorParallelRowLinear,
|
|
TensorParallelHead,
|
|
get_linear,
|
|
)
|
|
|
|
# Epsilon passed as `eps` when the LayerNorm modules of this file are loaded.
EPS = 1e-5
|
|
|
|
|
|
def load_col(config, prefix, weights, bias):
    """Load this rank's shard of a fused query/key/value projection.

    The fused tensor is treated as three equally sized sections stacked along
    dim 0 (query, key, value). Every rank takes the same band of
    ``h // world_size`` rows out of each section and concatenates the three
    bands, so the sharded tensor is still laid out as (q, k, v).

    Args:
        config: Model config; only ``config.quantize`` is read here.
        prefix: Tensor name prefix; ``{prefix}.weight`` and (optionally)
            ``{prefix}.bias`` are read from ``weights``.
        weights: Weight loader exposing ``_get_slice``, ``process_group``,
            ``dtype`` and ``device``.
        bias: Whether a bias tensor should be loaded as well.

    Returns:
        A ``TensorParallelColumnLinear`` wrapping the sharded linear layer.

    Raises:
        NotImplementedError: If ``config.quantize`` is ``"gptq"``.
    """
    if config.quantize == "gptq":
        raise NotImplementedError(
            "GPTQ quantization is not supported for the fused QKV projection."
        )

    rank = weights.process_group.rank()
    world_size = weights.process_group.size()

    def _shard_qkv(slice_, section_size):
        # Take this rank's band out of each of the q / k / v sections.
        block_size = section_size // world_size
        start = rank * block_size
        stop = start + block_size
        parts = [
            slice_[offset + start : offset + stop]
            for offset in (0, section_size, 2 * section_size)
        ]
        tensor = torch.cat(parts, dim=0)
        # int32 tensors are left in their stored dtype.
        if tensor.dtype != torch.int32:
            tensor = tensor.to(dtype=weights.dtype)
        return tensor.to(device=weights.device)

    weight_slice = weights._get_slice(f"{prefix}.weight")
    # The weight is unpacked as (3 * h, h): dim 1 gives the per-section size.
    _, h = weight_slice.get_shape()
    weight = _shard_qkv(weight_slice, h)

    if bias:
        bias_slice = weights._get_slice(f"{prefix}.bias")
        # The bias holds the q, k and v sections back to back, so one section
        # is a third of its length. Using the full length as the section size
        # made the k / v slices start beyond the end of the tensor.
        bias_h = bias_slice.get_shape()[0] // 3
        bias = _shard_qkv(bias_slice, bias_h)
    else:
        bias = None

    linear = get_linear(weight, bias, config.quantize)
    return TensorParallelColumnLinear(linear)
|
|
|
|
|
|
def _reset_is_causal(
    num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
):
    """Decide whether causal masking should stay enabled for this call.

    The requested flag is returned unchanged unless causal attention was
    requested and the query and key lengths differ. In that case a single
    query token disables causal masking, and any other query length is
    rejected.

    Raises:
        NotImplementedError: If causal attention is requested with differing
            query/key lengths and more than one query token.
    """
    lengths_differ = num_query_tokens != num_key_tokens
    if not (original_is_causal and lengths_differ):
        return original_is_causal
    if num_query_tokens == 1:
        return False
    raise NotImplementedError(
        "MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
    )
|
|
|
|
|
|
def scaled_multihead_dot_product_attention(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    """Plain torch implementation of (multi-head or multi-query) attention.

    Args:
        query, key, value: Tensors laid out as ``b s (h d)``; key and value use
            a single head when ``multiquery`` is set.
        n_heads: Number of query heads.
        past_key_value: Optional cache. When non-empty, element 0 is
            concatenated to the keys along dim 3 and element 1 to the values
            along dim 2, i.e. keys are cached as ``b h d s`` and values as
            ``b h s d``.
        softmax_scale: Multiplier applied to the raw scores. Defaults to
            ``1 / sqrt(head_dim)`` when ``None``.
        attn_bias: Optional additive bias; cropped to its trailing
            ``s_q`` x ``s_k`` window and required to broadcast to the scores.
        key_padding_mask: Optional boolean mask viewed as ``(b, 1, 1, s_k)``;
            positions where it is False are filled with the dtype minimum.
        is_causal: Apply a causal mask (skipped when there is one query token).
        dropout_p: Dropout probability applied to the attention weights.
        training: Forwarded to dropout.
        needs_weights: Return the attention weights as the second element.
        multiquery: Keys/values have a single head shared by all query heads.

    Returns:
        ``(output, attn_weights_or_None, past_key_value)``; ``past_key_value``
        is refreshed with the full keys/values when a cache was passed in.

    Raises:
        RuntimeError: If ``attn_bias`` cannot broadcast to the score shape.
    """
    q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
    kv_n_heads = 1 if multiquery else n_heads
    # Keys are transposed up front so that q @ k directly yields the scores.
    k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
    v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)

    if past_key_value is not None:
        if len(past_key_value) != 0:
            k = torch.cat([past_key_value[0], k], dim=3)
            v = torch.cat([past_key_value[1], v], dim=2)
        past_key_value = (k, v)

    (b, _, s_q, d) = q.shape
    s_k = k.size(-1)

    # The signature advertises None as the default; without this fallback the
    # multiplication below raised a TypeError.
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)

    attn_weight = q.matmul(k) * softmax_scale

    if attn_bias is not None:
        # Keep only the trailing window that lines up with the current tokens.
        _s_q = max(0, attn_bias.size(2) - s_q)
        _s_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        key_dim_mismatch = attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k
        query_dim_mismatch = attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q
        if key_dim_mismatch or query_dim_mismatch:
            raise RuntimeError(
                f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
            )
        attn_weight = attn_weight + attn_bias

    min_val = torch.finfo(q.dtype).min

    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn(
                "Propogating key_padding_mask to the attention module "
                + "and applying it within the attention module can cause "
                + "unneccessary computation/memory usage. Consider integrating "
                + "into attn_bias once and passing that to each attention "
                + "module instead."
            )
        attn_weight = attn_weight.masked_fill(
            ~key_padding_mask.view((b, 1, 1, s_k)), min_val
        )

    if is_causal and (not q.size(2) == 1):
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        # True marks the positions that must NOT be attended to.
        causal_mask = ~causal_mask
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)

    attn_weight = torch.softmax(attn_weight, dim=-1)

    if dropout_p:
        attn_weight = torch.nn.functional.dropout(
            attn_weight, p=dropout_p, training=training, inplace=True
        )

    out = attn_weight.to(v.dtype).matmul(v)
    out = rearrange(out, "b h s d -> b s (h d)")

    if needs_weights:
        return (out, attn_weight, past_key_value)
    return (out, None, past_key_value)
|
|
|
|
|
|
def check_valid_inputs(*tensors, valid_dtypes=None):
    """Validate that every tensor has an accepted dtype and lives on CUDA.

    Args:
        *tensors: Tensors to check.
        valid_dtypes: Accepted dtypes. Defaults to
            ``[torch.float16, torch.bfloat16]`` when ``None``.

    Raises:
        TypeError: If a tensor has a dtype outside ``valid_dtypes`` or is not
            a CUDA tensor.
    """
    # Built per call instead of using a shared mutable default argument.
    if valid_dtypes is None:
        valid_dtypes = [torch.float16, torch.bfloat16]
    for tensor in tensors:
        if tensor.dtype not in valid_dtypes:
            raise TypeError(
                f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
            )
        if not tensor.is_cuda:
            raise TypeError(
                f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
            )
|
|
|
|
|
|
def flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    """Attention backed by the ``flash_attn`` unpadded kernel.

    Inputs are laid out as ``b s (h d)``. Padding is stripped with
    ``bert_padding.unpad_input`` before the kernel call and restored with
    ``bert_padding.pad_input`` afterwards.

    Args:
        query, key, value: Input tensors; validated by ``check_valid_inputs``.
        n_heads: Number of query heads.
        past_key_value: Optional cache; when non-empty its two elements are
            concatenated in front of ``key`` / ``value`` along dim 1.
        softmax_scale: Forwarded to the kernel.
        attn_bias: Must be ``None``; an additive bias is not supported here.
        key_padding_mask: Optional boolean mask over key positions. Defaults
            to all-True.
        is_causal: Requested causal flag, adjusted by ``_reset_is_causal``.
        dropout_p: Dropout probability; forced to 0 when not training.
        training: Whether dropout is active.
        needs_weights: Forwarded as ``return_attn_probs``.
        multiquery: Keys/values carry a single head that is expanded to
            ``n_heads``.

    Returns:
        ``(output, None, past_key_value)``.

    Raises:
        RuntimeError: If ``flash_attn`` cannot be imported.
        NotImplementedError: If ``attn_bias`` is not ``None``.
    """
    try:
        from flash_attn import bert_padding, flash_attn_interface
    except Exception as err:
        # Keep the original failure attached instead of discarding it.
        raise RuntimeError("Please install flash-attn==1.0.3.post0") from err

    check_valid_inputs(query, key, value)

    if past_key_value is not None:
        if len(past_key_value) != 0:
            key = torch.cat([past_key_value[0], key], dim=1)
            value = torch.cat([past_key_value[1], value], dim=1)
        past_key_value = (key, value)

    # Any bias is rejected, so there is no point in cropping it first.
    if attn_bias is not None:
        raise NotImplementedError("attn_bias not implemented for flash attn.")

    (batch_size, seqlen) = query.shape[:2]
    if key_padding_mask is None:
        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
    # Queries correspond to the trailing positions of the key sequence.
    query_padding_mask = key_padding_mask[:, -query.size(1) :]

    kv_n_heads = 1 if multiquery else n_heads

    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
        query, query_padding_mask
    )
    query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)

    (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
        key, key_padding_mask
    )
    key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=kv_n_heads)

    (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
    value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=kv_n_heads)

    if multiquery:
        # Broadcast the single key/value head across all query heads.
        key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
        value_unpad = value_unpad.expand(
            value_unpad.size(0), n_heads, value_unpad.size(-1)
        )

    dropout_p = dropout_p if training else 0.0
    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)

    # NOTE(review): with needs_weights=True the kernel presumably returns more
    # than the output tensor, which the rearrange below does not expect - confirm.
    output_unpad = flash_attn_interface.flash_attn_unpadded_func(
        query_unpad,
        key_unpad,
        value_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=reset_is_causal,
        return_attn_probs=needs_weights,
    )
    output = bert_padding.pad_input(
        rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
    )
    return (output, None, past_key_value)
|
|
|
|
|
|
def triton_flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    """Attention backed by a triton flash-attention kernel.

    The kernel is imported from the sibling ``flash_attn_triton`` module; on
    torch versions below 2.0.0 the ``flash_attn`` package is tried as a
    fallback.

    Args:
        query, key, value: Tensors laid out as ``b s (h d)``; validated by
            ``check_valid_inputs``.
        n_heads: Number of query heads.
        past_key_value: Optional cache; when non-empty its two elements are
            concatenated in front of ``key`` / ``value`` along dim 1.
        softmax_scale: Forwarded to the kernel.
        attn_bias: Optional additive bias, cropped to its trailing window.
        key_padding_mask: Optional boolean mask, folded into ``attn_bias``.
        is_causal: Requested causal flag, adjusted by ``_reset_is_causal``.
        dropout_p: Must be falsy.
        training: Unused by this implementation.
        needs_weights: Must be falsy.
        multiquery: Keys/values carry a single head that is expanded to
            ``n_heads``.

    Returns:
        ``(output, None, past_key_value)``.

    Raises:
        RuntimeError: If no triton flash-attention kernel can be imported.
        NotImplementedError: If dropout or attention weights are requested.
    """
    try:
        from .flash_attn_triton import flash_attn_func
    except Exception as err:
        _installed = False
        if version.parse(torch.__version__) < version.parse("2.0.0"):
            _installed = True
            try:
                from flash_attn.flash_attn_triton import flash_attn_func
            except Exception:
                _installed = False
        if not _installed:
            raise RuntimeError(
                "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
            ) from err

    check_valid_inputs(query, key, value)

    if past_key_value is not None:
        if len(past_key_value) != 0:
            key = torch.cat([past_key_value[0], key], dim=1)
            value = torch.cat([past_key_value[1], value], dim=1)
        past_key_value = (key, value)

    if attn_bias is not None:
        # Keep only the trailing window that lines up with the current tokens.
        _s_q = max(0, attn_bias.size(2) - query.size(1))
        _s_k = max(0, attn_bias.size(3) - key.size(1))
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]

    if dropout_p:
        raise NotImplementedError("Dropout not implemented for attn_impl: triton.")
    if needs_weights:
        raise NotImplementedError("attn_impl: triton cannot return attn weights.")

    if key_padding_mask is not None:
        warnings.warn(
            "Propagating key_padding_mask to the attention module "
            + "and applying it within the attention module can cause "
            + "unnecessary computation/memory usage. Consider integrating "
            + "into attn_bias once and passing that to each attention "
            + "module instead."
        )
        (b_size, s_k) = key_padding_mask.shape[:2]
        if attn_bias is None:
            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
        # Padding is expressed through the bias handed to the kernel.
        attn_bias = attn_bias.masked_fill(
            ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
        )

    kv_n_heads = 1 if multiquery else n_heads
    query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
    key = rearrange(key, "b s (h d) -> b s h d", h=kv_n_heads)
    value = rearrange(value, "b s (h d) -> b s h d", h=kv_n_heads)

    if multiquery:
        # Broadcast the single key/value head across all query heads.
        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
        value = value.expand(*value.shape[:2], n_heads, value.size(-1))

    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    attn_output = flash_attn_func(
        query, key, value, attn_bias, reset_is_causal, softmax_scale
    )
    # Merge the head and head-dim axes back into one feature axis.
    output = attn_output.view(*attn_output.shape[:2], -1)
    return (output, None, past_key_value)
|
|
|
|
|
|
class MultiheadAttention(nn.Module):
    """Multi-head self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.

    The fused QKV projection is loaded with ``load_col`` and the output
    projection with ``TensorParallelRowLinear.load``; ``n_heads`` is divided
    by the process-group size after construction.
    """

    def __init__(
        self,
        config,
        prefix,
        weights,
    ):
        """Build the attention module from ``config`` and load its weights.

        Args:
            config: Model config; reads ``attn_config``, ``d_model``,
                ``n_heads`` and ``no_bias``.
            prefix: Name prefix of this module's tensors in ``weights``.
            weights: Weight loader exposing ``process_group``.

        Raises:
            ValueError: If ``n_heads`` is not divisible by the process-group
                size, or ``attn_impl`` is not one of flash/triton/torch.
        """
        super().__init__()
        attn_config = config.attn_config
        attn_impl = attn_config["attn_impl"]
        self.attn_impl = attn_impl
        self.clip_qkv = attn_config["clip_qkv"]
        self.qk_ln = attn_config["qk_ln"]
        self.d_model = config.d_model
        d_model = config.d_model
        self.n_heads = config.n_heads
        self.softmax_scale = attn_config["softmax_scale"]
        if self.softmax_scale is None:
            # Computed with the full head count, i.e. before sharding below.
            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
        self.attn_dropout_p = attn_config["attn_pdrop"]

        world_size = weights.process_group.size()
        if self.n_heads % world_size != 0:
            raise ValueError(
                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
                f"and `num_shards`: {world_size}"
            )
        # From here on n_heads is the number of heads handled by this rank.
        self.n_heads = self.n_heads // world_size

        use_bias = not config.no_bias
        self.Wqkv = load_col(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=use_bias
        )
        if self.qk_ln:
            head_dim = d_model // self.n_heads
            self.q_ln = LPLayerNorm(
                d_model, bias=use_bias, prefix=f"{prefix}.q_ln", weights=weights
            )
            # k_ln previously fell back to bias=True and therefore tried to
            # load a bias tensor even when the model is configured without
            # biases; it now follows the same setting as q_ln.
            self.k_ln = LPLayerNorm(
                self.n_heads * head_dim,
                bias=use_bias,
                prefix=f"{prefix}.k_ln",
                weights=weights,
            )

        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")

        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=use_bias,
        )

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        """Project ``x`` to q/k/v, run the configured attention, project out.

        Returns:
            ``(output, attn_weights, past_key_value)`` as produced by the
            attention function, with ``output`` passed through ``out_proj``.
        """
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        # The fused projection is split into three equal parts along dim 2.
        (query, key, value) = qkv.chunk(3, dim=2)

        key_padding_mask = attention_mask
        if self.qk_ln:
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)

        (context, attn_weights, past_key_value) = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
        )
        out = self.out_proj(context)
        return (out, attn_weights, past_key_value)
|
|
|
|
|
|
class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.

    The fused projection output is split into ``d_model`` query features and
    one ``head_dim``-sized key and value shared by all heads.
    """

    def __init__(self, config, prefix, weights):
        """Build the attention module from ``config`` and load its weights.

        Args:
            config: Model config; reads ``attn_config``, ``d_model``,
                ``n_heads``, ``no_bias`` and, when present, ``verbose``.
            prefix: Name prefix of this module's tensors in ``weights``.
            weights: Weight loader handed to the tensor-parallel layers.

        Raises:
            NotImplementedError: If ``qk_ln`` is enabled.
            ValueError: If ``attn_impl`` is not one of flash/triton/torch.
        """
        super().__init__()
        attn_config = config.attn_config
        attn_impl = attn_config["attn_impl"]
        self.attn_impl = attn_impl
        self.clip_qkv = attn_config["clip_qkv"]
        self.qk_ln = attn_config["qk_ln"]
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        # head_dim was read below and in forward() but never assigned.
        self.head_dim = self.d_model // self.n_heads
        self.softmax_scale = attn_config["softmax_scale"]
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = attn_config["attn_pdrop"]
        # `verbose` used to be an undefined name in this scope; a config
        # without the attribute is treated as non-verbose.
        verbose = getattr(config, "verbose", 0)

        use_bias = not config.no_bias
        # Plain equivalent: nn.Linear(d_model, d_model + 2 * head_dim).
        self.Wqkv = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=use_bias
        )
        if self.qk_ln:
            raise NotImplementedError("qk_ln not supported")

        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    "While `attn_impl: triton` can be faster than `attn_impl: flash` "
                    + "it uses more memory. When training larger models this can trigger "
                    + "alloc retries which hurts performance. If encountered, we recommend "
                    + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
                )
        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    "Using `attn_impl: torch`. If your model does not use `alibi` or "
                    + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
                    + "we recommend using `attn_impl: triton`."
                )
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")

        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=use_bias,
        )

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        """Project ``x`` to q/k/v, run multi-query attention, project out.

        Returns:
            ``(output, attn_weights, past_key_value)`` as produced by the
            attention function, with ``output`` passed through ``out_proj``.
        """
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        (query, key, value) = qkv.split(
            [self.d_model, self.head_dim, self.head_dim], dim=2
        )
        key_padding_mask = attention_mask
        # No q/k layer norm here: __init__ rejects qk_ln, so q_ln / k_ln are
        # never created for this class.
        (context, attn_weights, past_key_value) = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,
        )
        return (self.out_proj(context), attn_weights, past_key_value)
|
|
|
|
|
|
def attn_bias_shape(
    attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id
):
    """Return the shape of the attention bias buffer, or None if none is needed.

    ``flash`` never uses a bias. For ``torch`` / ``triton`` the bias has one
    slot per head when ALiBi is on, and a full ``seq_len`` x ``seq_len``
    matrix whenever prefix-LM or sequence-id masking is active (or, with
    ALiBi, when attention is not causal).

    Raises:
        ValueError: If ``attn_impl`` is not flash, torch or triton.
    """
    if attn_impl not in ("flash", "torch", "triton"):
        raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
    if attn_impl == "flash":
        return None

    per_pair_masking = prefix_lm or use_sequence_id
    if alibi:
        query_extent = seq_len if (per_pair_masking or not causal) else 1
        return (1, n_heads, query_extent, seq_len)
    if per_pair_masking:
        return (1, 1, seq_len, seq_len)
    return None
|
|
|
|
|
|
def build_attn_bias(
    attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8
):
    """Add the ALiBi term to ``attn_bias`` when requested.

    Returns None for ``flash``. For ``torch`` / ``triton`` the incoming bias
    is returned as is unless ``alibi`` is set, in which case the ALiBi bias
    (built on the device and dtype of ``attn_bias``) is added to it.

    Raises:
        ValueError: If ``attn_impl`` is not flash, torch or triton.
    """
    if attn_impl == "flash":
        return None
    if attn_impl not in ["torch", "triton"]:
        raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
    if not alibi:
        return attn_bias

    alibi_term = build_alibi_bias(
        n_heads,
        seq_len,
        full=not causal,
        alibi_bias_max=alibi_bias_max,
        device=attn_bias.device,
        dtype=attn_bias.dtype,
    )
    return attn_bias.add(alibi_term)
|
|
|
|
|
|
def gen_slopes(n_heads, alibi_bias_max=8, device=None):
    """Return the per-head ALiBi slopes shaped ``(1, n_heads, 1, 1)``.

    Slopes are generated for the next power of two >= ``n_heads``. When
    ``n_heads`` is not itself a power of two, the odd-indexed slopes are put
    before the even-indexed ones and the first ``n_heads`` are kept.
    """
    padded_heads = 2 ** math.ceil(math.log2(n_heads))
    exponents = torch.arange(1, padded_heads + 1, dtype=torch.float32, device=device)
    exponents = exponents.mul(alibi_bias_max / padded_heads)
    slopes = 1.0 / torch.pow(2, exponents)
    if padded_heads != n_heads:
        odd_indexed = slopes[1::2]
        even_indexed = slopes[::2]
        slopes = torch.concat([odd_indexed, even_indexed])[:n_heads]
    return slopes.view(1, n_heads, 1, 1)
|
|
|
|
|
|
def build_alibi_bias(
    n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None
):
    """Build the ALiBi bias: relative positions scaled by the per-head slopes.

    With ``full=False`` the bias has a single query row holding the offsets
    ``1 - seq_len .. 0``. With ``full=True`` it holds the negated absolute
    distance between every query/key pair. The result is cast to ``dtype``.
    """
    offsets = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device)
    alibi_bias = offsets.view(1, 1, 1, seq_len)
    if full:
        # Key offset minus query offset, turned into a non-positive distance.
        alibi_bias = alibi_bias - offsets.view(1, 1, seq_len, 1)
        alibi_bias = alibi_bias.abs().mul(-1)
    head_slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
    alibi_bias = alibi_bias * head_slopes
    return alibi_bias.to(dtype=dtype)
|
|
|
|
|
|
# Maps an attention type name to the module class implementing it.
# NOTE(review): presumably keyed by config.attn_config["attn_type"]; the
# registry is not referenced in the visible part of this file - confirm.
ATTN_CLASS_REGISTRY = {
    "multihead_attention": MultiheadAttention,
    "multiquery_attention": MultiQueryAttention,
}
|
|
|
|
"""GPT Blocks used for the GPT Model."""
|
|
|
|
|
|
class MPTMLP(nn.Module):
    """Feed-forward sub-layer: up projection, exact GELU, down projection."""

    def __init__(self, config, prefix, weights):
        """Load ``{prefix}.up_proj`` and ``{prefix}.down_proj`` from ``weights``."""
        super().__init__()
        use_bias = not config.no_bias
        # Column-parallel counterpart of nn.Linear(d_model, expansion_ratio * d_model).
        self.up_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.up_proj",
            weights=weights,
            bias=use_bias,
        )
        self.act = nn.GELU(approximate="none")
        # Row-parallel counterpart of nn.Linear(expansion_ratio * d_model, d_model).
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=use_bias,
        )

    def forward(self, x):
        """Apply ``down_proj(act(up_proj(x)))``."""
        hidden = self.up_proj(x)
        hidden = self.act(hidden)
        return self.down_proj(hidden)
|
|
|
|
|
|
class MPTBlock(nn.Module):
    """One transformer block: pre-norm attention and pre-norm MLP with residuals."""

    def __init__(self, config, prefix, weights):
        """Load the block's norms, attention and MLP from ``weights``.

        Args:
            config: Model config; reads ``attn_config["attn_type"]``,
                ``resid_pdrop`` and ``no_bias``.
            prefix: Name prefix of this block's tensors in ``weights``.
            weights: Weight loader handed to the sub-modules.

        Raises:
            NotImplementedError: If the attention type is not
                ``"multihead_attention"``.
        """
        super().__init__()
        self.prefix = prefix
        if config.attn_config["attn_type"] != "multihead_attention":
            raise NotImplementedError(
                f"""Not implemented attn {config.attn_config["attn_type"]}"""
            )
        resid_pdrop = config.resid_pdrop

        # Both norms use the same loader; only the bias handling depends on
        # the config.
        load_norm = nn.LayerNorm.load_no_bias if config.no_bias else nn.LayerNorm.load
        self.norm_1 = load_norm(prefix=f"{prefix}.norm_1", weights=weights, eps=EPS)
        self.norm_2 = load_norm(prefix=f"{prefix}.norm_2", weights=weights, eps=EPS)

        self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights)
        self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights)
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention and MLP, each behind a norm and added back to ``x``.

        Returns:
            ``(hidden_states, attn_weights, past_key_value)``; the last two
            are passed through from the attention module.
        """
        a = self.norm_1(x)
        (b, attn_weights, past_key_value) = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=is_causal,
        )
        x = x + self.resid_attn_dropout(b)
        m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return (x, attn_weights, past_key_value)
|
|
|
|
|
|
def _cast_if_autocast_enabled(tensor):
    """Cast ``tensor`` to the active autocast dtype of its device, if any.

    The tensor is returned unchanged when ``torch.is_autocast_enabled()`` is
    False.

    Raises:
        NotImplementedError: If autocast is enabled and the tensor lives on a
            device other than cuda or cpu.
    """
    if not torch.is_autocast_enabled():
        return tensor

    device_type = tensor.device.type
    if device_type == "cuda":
        target_dtype = torch.get_autocast_gpu_dtype()
    elif device_type == "cpu":
        target_dtype = torch.get_autocast_cpu_dtype()
    else:
        raise NotImplementedError()
    return tensor.to(dtype=target_dtype)
|
|
|
|
|
|
class LPLayerNorm(torch.nn.LayerNorm):
|
|
def __init__(
|
|
self,
|
|
normalized_shape,
|
|
eps=1e-05,
|
|
elementwise_affine=True,
|
|
device=None,
|
|
dtype=None,
|
|
bias: Optional[bool] = True,
|
|
prefix=None,
|
|
weights=None,
|
|
):
|
|
super().__init__(
|
|
normalized_shape=normalized_shape,
|
|
eps=eps,
|
|
elementwise_affine=elementwise_affine,
|
|
device=device,
|
|
dtype=dtype,
|
|
bias=bias,
|
|
)
|
|
if weights is not None:
|
|
self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0))
|
|
if bias:
|
|
self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
|
|
self.normalized_shape = self.weight.shape
|
|
|
|
def forward(self, x):
|
|
module_device = x.device
|
|
downcast_x = _cast_if_autocast_enabled(x)
|
|
downcast_weight = (
|
|
_cast_if_autocast_enabled(self.weight)
|
|
if self.weight is not None
|
|
else self.weight
|
|
)
|
|
downcast_bias = (
|
|
_cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
|
|
)
|
|
with torch.autocast(enabled=False, device_type=module_device.type):
|
|
return torch.nn.functional.layer_norm(
|
|
downcast_x,
|
|
self.normalized_shape,
|
|
downcast_weight,
|
|
downcast_bias,
|
|
self.eps,
|
|
)
|
|
|
|
|
|
def rms_norm(x, weight=None, eps=1e-05):
    """Scale ``x`` by the reciprocal root-mean-square of its last dimension.

    ``eps`` is added to the mean square before the reciprocal square root.
    The result is multiplied by ``weight`` when one is given.
    """
    mean_square = x.pow(2).mean(-1, keepdim=True)
    normalized = x * torch.rsqrt(mean_square + eps)
    if weight is None:
        return normalized
    return normalized * weight
|
|
|
|
|
|
class RMSNorm(torch.nn.Module):
    """RMS normalization computed in float32 with an optional learned scale."""

    def __init__(
        self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
    ):
        """Store ``eps`` and create a ones-initialized weight unless disabled."""
        super().__init__()
        self.eps = eps
        if not weight:
            self.register_parameter("weight", None)
            return
        initial_scale = torch.ones(normalized_shape, dtype=dtype, device=device)
        self.weight = torch.nn.Parameter(initial_scale)

    def forward(self, x):
        """Normalize ``x`` in float32 and cast the result back to ``x.dtype``."""
        normalized = rms_norm(x.float(), self.weight, self.eps)
        return normalized.to(dtype=x.dtype)
|
|
|
|
|
|
class LPRMSNorm(RMSNorm):
    """RMSNorm evaluated in the autocast dtype with autocast switched off."""

    def __init__(
        self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
    ):
        """Forward every argument unchanged to ``RMSNorm``."""
        base_kwargs = dict(
            normalized_shape=normalized_shape,
            eps=eps,
            weight=weight,
            dtype=dtype,
            device=device,
        )
        super().__init__(**base_kwargs)

    def forward(self, x):
        """Normalize ``x`` in the autocast dtype and cast back to ``x.dtype``."""
        inputs = _cast_if_autocast_enabled(x)
        scale = self.weight
        if scale is not None:
            scale = _cast_if_autocast_enabled(scale)
        # Autocast is disabled so the op runs in the dtype chosen above.
        with torch.autocast(enabled=False, device_type=x.device.type):
            normalized = rms_norm(inputs, scale, self.eps)
            return normalized.to(dtype=x.dtype)
|
|
|
|
|
|
# Maps a lower-cased norm type name to the normalization class implementing
# it. MPTModel validates config.norm_type against these keys.
NORM_CLASS_REGISTRY = {
    "layernorm": torch.nn.LayerNorm,
    "low_precision_layernorm": LPLayerNorm,
    "rmsnorm": RMSNorm,
    "low_precision_rmsnorm": LPRMSNorm,
}
|
|
|
|
# Type alias covering both tokenizer base classes imported from transformers.
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
|
|
|
|
|
class MPTPreTrainedModel(PreTrainedModel):
    """Common base of the MPT model classes; only sets class-level attributes."""

    # NOTE(review): both attributes are presumably consumed by the
    # transformers PreTrainedModel machinery - confirm against its docs.
    base_model_prefix = "model"
    _no_split_modules = ["MPTBlock"]
|
|
|
|
|
|
class MPTModel(MPTPreTrainedModel):
|
|
def __init__(self, config, weights):
    """Build the MPT decoder stack and load all of its weights.

    Args:
        config: Model config. Reads ``attn_config``, ``n_heads``,
            ``n_layers``, ``max_seq_len``, ``norm_type``, ``no_bias``,
            ``init_device``, ``init_config`` and, when present, ``verbose``.
        weights: Weight loader exposing ``process_group``.

    Raises:
        NotImplementedError: If ``init_device`` is ``"mixed"`` or
            ``norm_type`` is anything other than ``low_precision_layernorm``.
    """
    # NOTE: config._validate_config() is intentionally not called here.
    super().__init__(config)
    self.world_size = weights.process_group.size()
    self.rank = weights.process_group.rank()
    self.n_heads = config.n_heads

    attn_config = config.attn_config
    self.attn_impl = attn_config["attn_impl"]
    self.prefix_lm = attn_config["prefix_lm"]
    self.attn_uses_sequence_id = attn_config["attn_uses_sequence_id"]
    self.alibi = attn_config["alibi"]
    self.alibi_bias_max = attn_config["alibi_bias_max"]

    # A config without `verbose` is treated as non-verbose everywhere below;
    # previously only one of the three reads was guarded.
    verbose = getattr(config, "verbose", 0)

    if config.init_device == "mixed":
        # This branch used to call `dist.get_local_rank()`, but no `dist`
        # name is available in this module, so it could only fail with a
        # NameError. Fail with an explicit message instead.
        raise NotImplementedError(
            "init_device='mixed' is not supported: it requires a distributed "
            "local-rank helper that is not available in this module."
        )

    norm_type = config.norm_type.lower()
    if norm_type not in NORM_CLASS_REGISTRY:
        norm_options = " | ".join(NORM_CLASS_REGISTRY)
        raise NotImplementedError(
            f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
        )
    if norm_type != "low_precision_layernorm":
        raise NotImplementedError(
            f"Requested norm type ({config.norm_type}) is not implemented within this repo."
        )

    self.wte = TensorParallelEmbedding("transformer.wte", weights)

    # Learned position embeddings are only loaded when ALiBi is off.
    if not self.alibi:
        self.wpe = TensorParallelEmbedding("transformer.wpe", weights)

    self.blocks = nn.ModuleList(
        [
            MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
            for i in range(config.n_layers)
        ]
    )

    load_norm = nn.LayerNorm.load_no_bias if config.no_bias else nn.LayerNorm.load
    self.norm_f = load_norm(prefix="transformer.norm_f", weights=weights, eps=EPS)

    self.is_causal = not self.prefix_lm
    # The bias buffer itself is built on the first _attn_bias() call.
    self._attn_bias_initialized = False
    self.attn_bias = None
    self.attn_bias_shape = attn_bias_shape(
        self.attn_impl,
        config.n_heads,
        config.max_seq_len,
        self.alibi,
        prefix_lm=self.prefix_lm,
        causal=self.is_causal,
        use_sequence_id=self.attn_uses_sequence_id,
    )

    if config.no_bias:
        # Drop any bias parameter that a sub-module still carries.
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
                if verbose:
                    warnings.warn(f"Removing bias ({module.bias}) from {module}.")
                module.register_parameter("bias", None)

    if verbose and verbose > 2:
        print(self)

    if "verbose" not in self.config.init_config:
        self.config.init_config["verbose"] = verbose
    if self.config.init_config["verbose"] > 1:
        init_fn_name = self.config.init_config["name"]
        warnings.warn(f"Using {init_fn_name} initialization.")
|
|
|
|
@torch.no_grad()
def _attn_bias(
    self,
    device,
    dtype,
    attention_mask: Optional[torch.ByteTensor] = None,
    prefix_mask: Optional[torch.ByteTensor] = None,
    sequence_id: Optional[torch.LongTensor] = None,
):
    """Return the attention bias and padding mask to hand to the blocks.

    On the first call the bias buffer is created from ``attn_bias_shape``,
    filled by ``build_attn_bias`` and reduced to this rank's heads. For
    ``flash`` the buffer and the untouched ``attention_mask`` are returned.
    Otherwise prefix, sequence-id and padding masking are folded into the
    bias and the returned mask is None.

    Args:
        device: Device for the bias.
        dtype: Dtype for the bias.
        attention_mask: Optional boolean padding mask over key positions.
        prefix_mask: Required when the model is configured with prefix_lm.
        sequence_id: Optional sequence ids, used when the model is configured
            with attn_uses_sequence_id.

    Returns:
        ``(attn_bias, attention_mask)`` for flash, ``(attn_bias, None)``
        otherwise.

    Raises:
        ValueError: If ``attention_mask`` and ``prefix_mask`` differ in shape.
    """
    if not self._attn_bias_initialized:
        if self.attn_bias_shape:
            self.attn_bias = torch.zeros(
                self.attn_bias_shape, device=device, dtype=dtype
            )
            self.attn_bias = build_attn_bias(
                self.attn_impl,
                self.attn_bias,
                self.config.n_heads,
                self.config.max_seq_len,
                causal=self.is_causal,
                alibi=self.alibi,
                alibi_bias_max=self.alibi_bias_max,
            )
            assert self.n_heads % self.world_size == 0
            # Keep only this rank's heads. A bias whose head dimension is 1
            # broadcasts over heads and must not be sliced: doing so left
            # every rank > 0 with an empty tensor.
            if self.attn_bias.size(1) == self.n_heads:
                block_size = self.n_heads // self.world_size
                self.attn_bias = self.attn_bias[
                    :, self.rank * block_size : (self.rank + 1) * block_size
                ]
        self._attn_bias_initialized = True

    if self.attn_impl == "flash":
        return (self.attn_bias, attention_mask)

    if self.attn_bias is not None:
        self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
    attn_bias = self.attn_bias

    if self.prefix_lm:
        assert isinstance(attn_bias, torch.Tensor)
        assert isinstance(prefix_mask, torch.Tensor)
        attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)

    if self.attn_uses_sequence_id and sequence_id is not None:
        assert isinstance(attn_bias, torch.Tensor)
        attn_bias = self._apply_sequence_id(attn_bias, sequence_id)

    if attention_mask is not None:
        s_k = attention_mask.shape[-1]
        if attn_bias is None:
            attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
        else:
            # Keep the trailing s_k key positions of the bias.
            _s_k = max(0, attn_bias.size(-1) - s_k)
            attn_bias = attn_bias[:, :, :, _s_k:]
        if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
            raise ValueError(
                f"attention_mask shape={attention_mask.shape} "
                + f"and prefix_mask shape={prefix_mask.shape} are not equal."
            )
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(
            ~attention_mask.view(-1, 1, 1, s_k), min_val
        )

    return (attn_bias, None)
|
|
|
|
def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
    """Mask out every position that is neither causal nor inside the prefix.

    Args:
        attn_bias: Bias whose last two dimensions both equal
            ``config.max_seq_len``.
        prefix_mask: Mask whose last dimension is the sequence length; truthy
            entries mark prefix positions that every query may attend to.

    Returns:
        The bias cropped to the sequence length, with disallowed positions
        filled with the dtype minimum.

    Raises:
        ValueError: If ``attn_bias`` does not have the expected shape or the
            prefix mask is longer than ``config.max_seq_len``.
    """
    max_seq_len = self.config.max_seq_len
    (s_q, s_k) = attn_bias.shape[-2:]
    if s_q != max_seq_len or s_k != max_seq_len:
        # The message used to read config.max_length, which is not the value
        # the check above compares against.
        raise ValueError(
            "attn_bias does not match the expected shape. "
            + f"The last two dimensions should both be {max_seq_len} "
            + f"but are {s_q} and {s_k}."
        )
    seq_len = prefix_mask.shape[-1]
    if seq_len > max_seq_len:
        raise ValueError(
            f"prefix_mask sequence length cannot exceed max_seq_len={max_seq_len}"
        )
    attn_bias = attn_bias[..., :seq_len, :seq_len]
    causal = torch.tril(
        torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)
    ).view(1, 1, seq_len, seq_len)
    prefix = prefix_mask.view(-1, 1, 1, seq_len)
    # A key position is visible if it is causal OR belongs to the prefix.
    cannot_attend = ~torch.logical_or(causal, prefix.bool())
    min_val = torch.finfo(attn_bias.dtype).min
    attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
    return attn_bias
|
|
|
|
def _apply_sequence_id(
    self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor
):
    """Mask out attention between positions carrying different sequence ids.

    The bias is cropped to the length of ``sequence_id`` and every
    query/key pair whose ids differ is filled with the dtype minimum.

    Raises:
        ValueError: If ``sequence_id`` is longer than ``config.max_seq_len``.
    """
    n_tokens = sequence_id.shape[-1]
    if n_tokens > self.config.max_seq_len:
        raise ValueError(
            f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
        )

    cropped_bias = attn_bias[..., :n_tokens, :n_tokens]
    ids_per_query = sequence_id.view(-1, n_tokens, 1)
    ids_per_key = sequence_id.view(-1, 1, n_tokens)
    same_sequence = torch.eq(ids_per_query, ids_per_key)
    # Add the head axis so the mask broadcasts over heads.
    cannot_attend = torch.logical_not(same_sequence).unsqueeze(1)
    fill_value = torch.finfo(cropped_bias.dtype).min
    return cropped_bias.masked_fill(cannot_attend, fill_value)
|
|
|
|
def forward(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
    attention_mask: Optional[torch.ByteTensor] = None,
    prefix_mask: Optional[torch.ByteTensor] = None,
    sequence_id: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    use_cache: Optional[bool] = None,
):
    """Embed ``input_ids``, run every block, and apply the final norm.

    Args:
        input_ids: token ids; dimension 1 is used as the sequence length.
        past_key_values: optional per-layer cache. When given it must have
            ``config.n_layers`` entries and is updated in place with the
            cache returned by each block.
        attention_mask: optional mask, converted to bool.
        prefix_mask: optional mask, converted to bool; required when the
            model is configured with ``prefix_lm``.
        sequence_id: optional ids forwarded to ``self._attn_bias``; required
            in training mode when ``attn_uses_sequence_id`` is set.
        return_dict: falls back to ``config.return_dict``; a falsy value is
            not supported.
        output_attentions: collect the attention weights returned by each
            block (only supported when ``attn_impl == "torch"``).
        output_hidden_states: collect the input to every block plus the
            final normalized output.
        use_cache: falls back to ``config.use_cache``; when truthy and no
            cache was passed, an empty per-layer cache is created.

    Returns:
        ``BaseModelOutputWithPast`` with the normalized hidden state, the
        (possibly updated) cache, and the optional hidden states/attentions.

    Raises:
        NotImplementedError: for ``return_dict=False``, for
            ``output_attentions`` with a non-torch attention implementation,
            or for training with a mask whose first column is not all true.
        ValueError: for a missing ``prefix_mask``/``sequence_id`` when the
            configuration requires one, a cache with the wrong number of
            layers, or a total (past + current) length above
            ``config.max_seq_len``.
        AssertionError: if the current sequence is longer than
            ``config.max_seq_len``.
    """
    return_dict = (
        return_dict if return_dict is not None else self.config.return_dict
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    if attention_mask is not None:
        attention_mask = attention_mask.bool()
    if prefix_mask is not None:
        prefix_mask = prefix_mask.bool()

    # --- argument validation -------------------------------------------
    if not return_dict:
        raise NotImplementedError(
            "return_dict False is not implemented yet for MPT"
        )
    if output_attentions:
        if self.attn_impl != "torch":
            raise NotImplementedError(
                "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`."
            )
    if (
        attention_mask is not None
        and attention_mask[:, 0].sum() != attention_mask.shape[0]
        and self.training
    ):
        # Some row has a masked first position, i.e. the batch is left padded.
        raise NotImplementedError(
            "MPT does not support training with left padding."
        )
    if self.prefix_lm and prefix_mask is None:
        raise ValueError(
            "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
        )
    if self.training:
        if self.attn_uses_sequence_id and sequence_id is None:
            raise ValueError(
                "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
                + "and the model is in train mode."
            )
        elif self.attn_uses_sequence_id is False and sequence_id is not None:
            warnings.warn(
                "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
                + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
            )

    S = input_ids.size(1)
    # Kept as an assert so callers keep seeing AssertionError for this case.
    assert (
        S <= self.config.max_seq_len
    ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"

    # --- embeddings ----------------------------------------------------
    tok_emb = self.wte(input_ids)
    if self.alibi:
        # With alibi no learned position embedding is added here.
        x = tok_emb
    else:
        past_position = 0
        if past_key_values is not None:
            if len(past_key_values) != self.config.n_layers:
                raise ValueError(
                    "past_key_values must provide a past_key_value for each attention "
                    + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
                )
            # The cached length is read from dim 1 of the first layer's first
            # cached tensor, or from dim 3 for the "torch" implementation.
            past_position = past_key_values[0][0].size(1)
            if self.attn_impl == "torch":
                past_position = past_key_values[0][0].size(3)
        if S + past_position > self.config.max_seq_len:
            # Report the real current length S (the old message said S + 1).
            raise ValueError(
                f"Cannot forward input with past sequence length {past_position} and current sequence length {S}, this model only supports total sequence length <= {self.config.max_seq_len}."
            )
        pos = torch.arange(
            past_position,
            S + past_position,
            dtype=torch.long,
            device=input_ids.device,
        ).unsqueeze(0)
        if attention_mask is not None:
            # Shift positions down by the running count of masked-out
            # tokens, never going below zero.
            pos = torch.clamp(
                pos
                - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
                    :, past_position:
                ],
                min=0,
            )
        pos_emb = self.wpe(pos)
        x = tok_emb + pos_emb

    (attn_bias, attention_mask) = self._attn_bias(
        device=x.device,
        dtype=torch.float32,
        attention_mask=attention_mask,
        prefix_mask=prefix_mask,
        sequence_id=sequence_id,
    )

    # --- transformer blocks --------------------------------------------
    if use_cache and past_key_values is None:
        past_key_values = [() for _ in range(self.config.n_layers)]
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    for b_idx, block in enumerate(self.blocks):
        if output_hidden_states:
            assert all_hidden_states is not None
            all_hidden_states = all_hidden_states + (x,)
        past_key_value = (
            past_key_values[b_idx] if past_key_values is not None else None
        )
        (x, attn_weights, past_key_value) = block(
            x,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=self.is_causal,
        )
        if past_key_values is not None:
            past_key_values[b_idx] = past_key_value
        if output_attentions:
            assert all_self_attns is not None
            all_self_attns = all_self_attns + (attn_weights,)

    x = self.norm_f(x)
    if output_hidden_states:
        assert all_hidden_states is not None
        all_hidden_states = all_hidden_states + (x,)
    return BaseModelOutputWithPast(
        last_hidden_state=x,
        past_key_values=past_key_values,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
|
|
|
|
|
|
class MPTForCausalLM(MPTPreTrainedModel):
    """MPT model with a language-modeling head loaded from the token embedding.

    The head is loaded from the ``transformer.wte`` weights, so only
    configurations with ``tie_word_embeddings`` enabled are accepted.
    """

    def __init__(self, config, weights):
        """Build the transformer, the tied head, and resolve ``logit_scale``.

        Raises:
            ValueError: if word embeddings are not tied, or if
                ``config.logit_scale`` is a string other than
                ``"inv_sqrt_d_model"``.
        """
        super().__init__(config)
        if not config.tie_word_embeddings:
            raise ValueError("MPTForCausalLM only supports tied word embeddings")

        self.transformer = MPTModel(config, weights)
        self.lm_head = TensorParallelHead.load(
            config, prefix="transformer.wte", weights=weights
        )

        # ``None`` means "no scaling"; a string selects a named preset.
        logit_scale = config.logit_scale
        if isinstance(logit_scale, str):
            if logit_scale != "inv_sqrt_d_model":
                raise ValueError(
                    f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
                )
            logit_scale = 1 / math.sqrt(config.d_model)
        self.logit_scale = logit_scale

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ):
        """Run the transformer, project to logits, and optionally compute loss.

        When ``labels`` is given it is rolled by one position, its last
        column is set to ``-100``, and cross-entropy against the logits is
        returned as ``loss``.
        """
        if return_dict is None:
            return_dict = self.config.return_dict
        if use_cache is None:
            use_cache = self.config.use_cache

        outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
            return_dict=return_dict,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
        )

        logits = self.lm_head(outputs.last_hidden_state)
        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(
                    f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
                )
            logits *= self.logit_scale

        loss = None
        if labels is not None:
            # Next-token targets: shift left by one and mark the final slot
            # with -100 in every row.
            targets = torch.roll(labels, shifts=-1)
            targets[:, -1] = -100
            flat_logits = logits.view(-1, logits.size(-1))
            flat_targets = targets.to(logits.device).view(-1)
            loss = F.cross_entropy(flat_logits, flat_targets)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        Requires ``attention_mask`` in ``kwargs``. With a cache present only
        the last token of ``input_ids`` is forwarded.

        Raises:
            NotImplementedError: for ``inputs_embeds``, for a mask whose
                last column is not all true (right padding), or for
                ``prefix_lm`` combined with ``use_cache=False``.
        """
        if inputs_embeds is not None:
            raise NotImplementedError("inputs_embeds is not implemented for MPT yet")

        attention_mask = kwargs["attention_mask"].bool()
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError(
                "MPT does not support generation with right padding."
            )

        sequence_id = None
        if self.transformer.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])

        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        prefix_mask = None
        if self.transformer.prefix_lm:
            prefix_mask = torch.ones_like(attention_mask)
            if kwargs.get("use_cache") == False:
                raise NotImplementedError(
                    "MPT with prefix_lm=True does not support use_cache=False."
                )

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prefix_mask": prefix_mask,
            "sequence_id": sequence_id,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Used by HuggingFace generate when using beam search with kv-caching.

        Every cached tensor of every layer is re-indexed along dimension 0
        with ``beam_idx``.

        See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
        for an example in transformers.
        """
        return [
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        ]
|