Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-20 14:22:08 +00:00)
"""A simple, flexible implementation of a GPT model.
|
|
|
|
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
|
|
"""
|
|
|
|
import math
|
|
import os
|
|
import warnings
|
|
from typing import List, Optional, Tuple, Union
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
from transformers.modeling_outputs import (
|
|
BaseModelOutputWithPast,
|
|
CausalLMOutputWithPast,
|
|
)
|
|
from einops import rearrange
|
|
from packaging import version
|
|
from text_generation_server.utils.layers import (
|
|
TensorParallelEmbedding,
|
|
TensorParallelColumnLinear,
|
|
TensorParallelRowLinear,
|
|
SpeculativeHead,
|
|
get_linear,
|
|
)
|
|
|
|
EPS = 1e-5
|
|
|
|
|
|
def load_col(config, prefix, weights, bias):
    assert config.quantize != "gptq", NotImplementedError
    slice_ = weights._get_slice(f"{prefix}.weight")
    rank = weights.process_group.rank()
    size = weights.process_group.size()

    h3, h = slice_.get_shape()
    block_size = h // size

    q_part = slice_[rank * block_size : (rank + 1) * block_size]
    k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size]
    v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size]

    weight = torch.cat([q_part, k_part, v_part], dim=0)
    if weight.dtype != torch.int32:
        weight = weight.to(dtype=weights.dtype)
    weight = weight.to(device=weights.device)

    if bias:
        bias_slice_ = weights._get_slice(f"{prefix}.bias")
        bias_rank = weights.process_group.rank()
        bias_size = weights.process_group.size()

        bias_h = bias_slice_.get_shape()
        bias_h = bias_h[0]
        bias_block_size = bias_h // bias_size

        bias_q_part = bias_slice_[
            bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size
        ]
        bias_k_part = bias_slice_[
            bias_h
            + bias_rank * bias_block_size : bias_h
            + (bias_rank + 1) * bias_block_size
        ]
        bias_v_part = bias_slice_[
            2 * bias_h
            + bias_rank * bias_block_size : 2 * bias_h
            + (bias_rank + 1) * bias_block_size
        ]

        bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)
        if bias.dtype != torch.int32:
            bias = bias.to(dtype=weights.dtype)
        bias = bias.to(device=weights.device)
    else:
        bias = None
    linear = get_linear(weight, bias, config.quantize)
    return TensorParallelColumnLinear(linear)


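# A minimal sketch of the sharding arithmetic used by `load_col` above, assuming a
# fused Wqkv weight of shape (3 * h, h) split across `size` tensor-parallel ranks:
# each rank keeps a contiguous block of h // size rows from each of the Q, K and V
# sections. The helper name below is illustrative only, not part of this module's API.
def _example_qkv_shard_rows(h: int, rank: int, size: int):
    block = h // size
    q_rows = (rank * block, (rank + 1) * block)
    k_rows = (h + rank * block, h + (rank + 1) * block)
    v_rows = (2 * h + rank * block, 2 * h + (rank + 1) * block)
    return q_rows, k_rows, v_rows

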
def _reset_is_causal(
    num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
):
    if original_is_causal and num_query_tokens != num_key_tokens:
        if num_query_tokens != 1:
            raise NotImplementedError(
                "MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
            )
        else:
            return False
    return original_is_causal


def scaled_multihead_dot_product_attention(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
    kv_n_heads = 1 if multiquery else n_heads
    k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
    v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
    if past_key_value is not None:
        if len(past_key_value) != 0:
            k = torch.cat([past_key_value[0], k], dim=3)
            v = torch.cat([past_key_value[1], v], dim=2)
        past_key_value = (k, v)
    (b, _, s_q, d) = q.shape
    s_k = k.size(-1)
    attn_weight = q.matmul(k) * softmax_scale
    if attn_bias is not None:
        _s_q = max(0, attn_bias.size(2) - s_q)
        _s_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        if (
            attn_bias.size(-1) != 1
            and attn_bias.size(-1) != s_k
            or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
        ):
            raise RuntimeError(
                f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
            )
        attn_weight = attn_weight + attn_bias
    min_val = torch.finfo(q.dtype).min
    if key_padding_mask is not None:
        if attn_bias is not None:
            warnings.warn(
                "Propagating key_padding_mask to the attention module "
                + "and applying it within the attention module can cause "
                + "unnecessary computation/memory usage. Consider integrating "
                + "into attn_bias once and passing that to each attention "
                + "module instead."
            )
        attn_weight = attn_weight.masked_fill(
            ~key_padding_mask.view((b, 1, 1, s_k)), min_val
        )
    if is_causal and (not q.size(2) == 1):
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        causal_mask = ~causal_mask
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p:
        attn_weight = torch.nn.functional.dropout(
            attn_weight, p=dropout_p, training=training, inplace=True
        )
    out = attn_weight.to(v.dtype).matmul(v)
    out = rearrange(out, "b h s d -> b s (h d)")
    if needs_weights:
        return (out, attn_weight, past_key_value)
    return (out, None, past_key_value)


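# The torch path above computes standard scaled dot-product attention,
# softmax(q @ k^T * softmax_scale + attn_bias) @ v, with the KV cache concatenated
# along the sequence axis. A minimal shape sketch (illustrative values only):
#
#   q = k = v = torch.randn(2, 5, 32)   # batch=2, seq=5, n_heads=4, head_dim=8
#   out, _, _ = scaled_multihead_dot_product_attention(
#       q, k, v, n_heads=4, softmax_scale=8 ** -0.5
#   )
#   # out has shape (2, 5, 32); attention weights are returned only when needs_weights=True

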
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
    for tensor in tensors:
        if tensor.dtype not in valid_dtypes:
            raise TypeError(
                f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
            )
        if not tensor.is_cuda:
            raise TypeError(
                f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
            )


def flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    try:
        from flash_attn import bert_padding, flash_attn_interface
    except:
        raise RuntimeError("Please install flash-attn==1.0.3.post0")
    check_valid_inputs(query, key, value)
    if past_key_value is not None:
        if len(past_key_value) != 0:
            key = torch.cat([past_key_value[0], key], dim=1)
            value = torch.cat([past_key_value[1], value], dim=1)
        past_key_value = (key, value)
    if attn_bias is not None:
        _s_q = max(0, attn_bias.size(2) - query.size(1))
        _s_k = max(0, attn_bias.size(3) - key.size(1))
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
    if attn_bias is not None:
        raise NotImplementedError(f"attn_bias not implemented for flash attn.")
    (batch_size, seqlen) = query.shape[:2]
    if key_padding_mask is None:
        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
    query_padding_mask = key_padding_mask[:, -query.size(1) :]
    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
        query, query_padding_mask
    )
    query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
    (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
        key, key_padding_mask
    )
    key_unpad = rearrange(
        key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
    )
    (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
    value_unpad = rearrange(
        value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
    )
    if multiquery:
        key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
        value_unpad = value_unpad.expand(
            value_unpad.size(0), n_heads, value_unpad.size(-1)
        )
    dropout_p = dropout_p if training else 0.0
    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    output_unpad = flash_attn_interface.flash_attn_unpadded_func(
        query_unpad,
        key_unpad,
        value_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=reset_is_causal,
        return_attn_probs=needs_weights,
    )
    output = bert_padding.pad_input(
        rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
    )
    return (output, None, past_key_value)


def triton_flash_attn_fn(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    try:
        from .flash_attn_triton import flash_attn_func
    except:
        _installed = False
        if version.parse(torch.__version__) < version.parse("2.0.0"):
            _installed = True
            try:
                from flash_attn.flash_attn_triton import flash_attn_func
            except:
                _installed = False
        if not _installed:
            raise RuntimeError(
                "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
            )
    check_valid_inputs(query, key, value)
    if past_key_value is not None:
        if len(past_key_value) != 0:
            key = torch.cat([past_key_value[0], key], dim=1)
            value = torch.cat([past_key_value[1], value], dim=1)
        past_key_value = (key, value)
    if attn_bias is not None:
        _s_q = max(0, attn_bias.size(2) - query.size(1))
        _s_k = max(0, attn_bias.size(3) - key.size(1))
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
    if dropout_p:
        raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
    if needs_weights:
        raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
    if key_padding_mask is not None:
        warnings.warn(
            "Propagating key_padding_mask to the attention module "
            + "and applying it within the attention module can cause "
            + "unnecessary computation/memory usage. Consider integrating "
            + "into attn_bias once and passing that to each attention "
            + "module instead."
        )
        (b_size, s_k) = key_padding_mask.shape[:2]
        if attn_bias is None:
            attn_bias = query.new_zeros(b_size, 1, 1, s_k)
        attn_bias = attn_bias.masked_fill(
            ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
        )
    query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
    key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
    value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
    if multiquery:
        key = key.expand(*key.shape[:2], n_heads, key.size(-1))
        value = value.expand(*value.shape[:2], n_heads, value.size(-1))
    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
    attn_output = flash_attn_func(
        query, key, value, attn_bias, reset_is_causal, softmax_scale
    )
    output = attn_output.view(*attn_output.shape[:2], -1)
    return (output, None, past_key_value)


class MultiheadAttention(nn.Module):
    """Multi-head self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.
    """

    def __init__(
        self,
        config,
        prefix,
        weights,
    ):
        super().__init__()
        attn_impl = config.attn_config["attn_impl"]
        self.attn_impl = config.attn_config["attn_impl"]
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.d_model = config.d_model
        d_model = config.d_model
        self.n_heads = config.n_heads
        self.softmax_scale = config.attn_config["softmax_scale"]
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
        self.attn_dropout_p = config.attn_config["attn_pdrop"]

        if self.n_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
                f"and `num_shards`: {weights.process_group.size()})"
            )
        self.n_heads = self.n_heads // weights.process_group.size()
        self.Wqkv = load_col(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
        )
        if self.qk_ln:
            bias = not config.no_bias
            hidden_size = config.d_model
            head_dim = hidden_size // self.n_heads

            self.q_ln = LPLayerNorm(
                d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights
            )
            self.k_ln = LPLayerNorm(
                self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights
            )
        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=not config.no_bias,
        )

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        (query, key, value) = qkv.chunk(3, dim=2)

        key_padding_mask = attention_mask
        if self.qk_ln:
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)
        (context, attn_weights, past_key_value) = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
        )
        out = self.out_proj(context)
        return (out, attn_weights, past_key_value)


class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        attn_impl = config.attn_config["attn_impl"]
        self.attn_impl = config.attn_config["attn_impl"]
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.d_model = config.d_model
        d_model = config.d_model
        self.n_heads = config.n_heads
        # head_dim is needed below (softmax scale, fused-weight split); deriving it
        # from d_model and n_heads follows the upstream llm-foundry implementation.
        self.head_dim = self.d_model // self.n_heads
        self.softmax_scale = config.attn_config["softmax_scale"]
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = config.attn_config["attn_pdrop"]
        # `verbose` is not defined elsewhere in this adaptation; fall back to the
        # config flag (assumed to exist, as it does in MPTModel) and stay quiet otherwise.
        verbose = getattr(config, "verbose", 0)
        # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
        self.Wqkv = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
        )
        fuse_splits = (d_model, d_model + self.head_dim)
        if self.qk_ln:
            raise NotImplementedError("qk_ln not supported")
        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
            self.attn_fn = triton_flash_attn_fn
            if verbose:
                warnings.warn(
                    "While `attn_impl: triton` can be faster than `attn_impl: flash` "
                    + "it uses more memory. When training larger models this can trigger "
                    + "alloc retries which hurts performance. If encountered, we recommend "
                    + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
                )
        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available() and verbose:
                warnings.warn(
                    "Using `attn_impl: torch`. If your model does not use `alibi` or "
                    + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
                    + "we recommend using `attn_impl: triton`."
                )
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=not config.no_bias,
        )
        # self.out_proj._is_residual = True

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        (query, key, value) = qkv.split(
            [self.d_model, self.head_dim, self.head_dim], dim=2
        )
        key_padding_mask = attention_mask
        if self.qk_ln:
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)
        (context, attn_weights, past_key_value) = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,
        )
        return (self.out_proj(context), attn_weights, past_key_value)


def attn_bias_shape(
    attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id
):
    if attn_impl == "flash":
        return None
    elif attn_impl in ["torch", "triton"]:
        if alibi:
            if (prefix_lm or not causal) or use_sequence_id:
                return (1, n_heads, seq_len, seq_len)
            return (1, n_heads, 1, seq_len)
        elif prefix_lm or use_sequence_id:
            return (1, 1, seq_len, seq_len)
        return None
    else:
        raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")


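# Summary of the bias shapes returned above (torch/triton only; the flash path
# never materialises a bias tensor):
#   alibi, causal, no prefix_lm/sequence_id          -> (1, n_heads, 1, seq_len)
#   alibi with prefix_lm, non-causal, or sequence_id -> (1, n_heads, seq_len, seq_len)
#   no alibi, but prefix_lm or sequence_id           -> (1, 1, seq_len, seq_len)
#   otherwise                                        -> None (no additive bias needed)

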
def build_attn_bias(
    attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8
):
    if attn_impl == "flash":
        return None
    elif attn_impl in ["torch", "triton"]:
        if alibi:
            (device, dtype) = (attn_bias.device, attn_bias.dtype)
            attn_bias = attn_bias.add(
                build_alibi_bias(
                    n_heads,
                    seq_len,
                    full=not causal,
                    alibi_bias_max=alibi_bias_max,
                    device=device,
                    dtype=dtype,
                )
            )
        return attn_bias
    else:
        raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")


def gen_slopes(n_heads, alibi_bias_max=8, device=None):
    _n_heads = 2 ** math.ceil(math.log2(n_heads))
    m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
    m = m.mul(alibi_bias_max / _n_heads)
    slopes = 1.0 / torch.pow(2, m)
    if _n_heads != n_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
    return slopes.view(1, n_heads, 1, 1)


def build_alibi_bias(
    n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None
):
    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(
        1, 1, 1, seq_len
    )
    if full:
        alibi_bias = alibi_bias - torch.arange(
            1 - seq_len, 1, dtype=torch.int32, device=device
        ).view(1, 1, seq_len, 1)
        alibi_bias = alibi_bias.abs().mul(-1)
    slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
    alibi_bias = alibi_bias * slopes
    return alibi_bias.to(dtype=dtype)


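# Worked example for the ALiBi helpers above (comments only, not executed):
# with n_heads=8 and alibi_bias_max=8, gen_slopes returns
#   [1/2, 1/4, 1/8, ..., 1/256]        (the slope for head h is 2 ** (-8 * h / 8))
# and build_alibi_bias(8, seq_len=4) yields, for each head with slope s, the row
#   [-3 * s, -2 * s, -1 * s, 0]
# i.e. a linear penalty that grows with the distance between key and query positions.

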
ATTN_CLASS_REGISTRY = {
    "multihead_attention": MultiheadAttention,
    "multiquery_attention": MultiQueryAttention,
}

"""GPT Blocks used for the GPT Model."""


class MPTMLP(nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        # self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
        self.up_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.up_proj", weights=weights, bias=not config.no_bias
        )
        self.act = nn.GELU(approximate="none")
        # self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=not config.no_bias,
        )
        # self.down_proj._is_residual = True

    def forward(self, x):
        return self.down_proj(self.act(self.up_proj(x)))


class MPTBlock(nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.prefix = prefix
        if config.attn_config["attn_type"] != "multihead_attention":
            raise NotImplementedError(
                f"""Not implemented attn {config.attn_config["attn_type"]}"""
            )
        resid_pdrop = config.resid_pdrop
        if config.no_bias:
            self.norm_1 = nn.LayerNorm.load_no_bias(
                prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
            )
            self.norm_2 = nn.LayerNorm.load_no_bias(
                prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
            )
        else:
            self.norm_1 = nn.LayerNorm.load(
                prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
            )
            self.norm_2 = nn.LayerNorm.load(
                prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
            )
        self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights)
        self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights)
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
        a = self.norm_1(x)
        (b, attn_weights, past_key_value) = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=is_causal,
        )
        x = x + self.resid_attn_dropout(b)
        m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return (x, attn_weights, past_key_value)


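# MPTBlock.forward above is a standard pre-LayerNorm residual transformer layer:
#   x <- x + Dropout(Attention(LayerNorm_1(x)))
#   x <- x + Dropout(MLP(LayerNorm_2(x)))
# with the per-layer KV cache threaded through the attention call.

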
def _cast_if_autocast_enabled(tensor):
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor


class LPLayerNorm(torch.nn.LayerNorm):
    def __init__(
        self,
        normalized_shape,
        eps=1e-05,
        elementwise_affine=True,
        device=None,
        dtype=None,
        bias: Optional[bool] = True,
        prefix=None,
        weights=None,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
            bias=bias,
        )
        if weights is not None:
            self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0))
            if bias:
                self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
            self.normalized_shape = self.weight.shape

    def forward(self, x):
        module_device = x.device
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight)
            if self.weight is not None
            else self.weight
        )
        downcast_bias = (
            _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
        )
        with torch.autocast(enabled=False, device_type=module_device.type):
            return torch.nn.functional.layer_norm(
                downcast_x,
                self.normalized_shape,
                downcast_weight,
                downcast_bias,
                self.eps,
            )


def rms_norm(x, weight=None, eps=1e-05):
    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        return output * weight
    return output


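# rms_norm above computes RMS normalisation: y = x / sqrt(mean(x ** 2) + eps),
# optionally rescaled by a learned per-feature weight. Unlike LayerNorm it does
# not subtract the mean and carries no bias term.

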
class RMSNorm(torch.nn.Module):
    def __init__(
        self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
    ):
        super().__init__()
        self.eps = eps
        if weight:
            self.weight = torch.nn.Parameter(
                torch.ones(normalized_shape, dtype=dtype, device=device)
            )
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)


class LPRMSNorm(RMSNorm):
    def __init__(
        self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            weight=weight,
            dtype=dtype,
            device=device,
        )

    def forward(self, x):
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight)
            if self.weight is not None
            else self.weight
        )
        with torch.autocast(enabled=False, device_type=x.device.type):
            return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)


NORM_CLASS_REGISTRY = {
    "layernorm": torch.nn.LayerNorm,
    "low_precision_layernorm": LPLayerNorm,
    "rmsnorm": RMSNorm,
    "low_precision_rmsnorm": LPRMSNorm,
}

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


class MPTPreTrainedModel(PreTrainedModel):
    base_model_prefix = "model"
    _no_split_modules = ["MPTBlock"]


class MPTModel(MPTPreTrainedModel):
    def __init__(self, config, weights):
        # config._validate_config()
        super().__init__(config)
        self.world_size = weights.process_group.size()
        self.rank = weights.process_group.rank()
        self.n_heads = config.n_heads
        self.attn_impl = config.attn_config["attn_impl"]
        self.prefix_lm = config.attn_config["prefix_lm"]
        self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
        self.alibi = config.attn_config["alibi"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        if config.init_device == "mixed":
            if dist.get_local_rank() == 0:
                config.init_device = "cpu"
            else:
                config.init_device = "meta"
        if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
            norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
            raise NotImplementedError(
                f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
            )
        if config.norm_type.lower() != "low_precision_layernorm":
            raise NotImplementedError(
                f"Requested norm type ({config.norm_type}) is not implemented within this repo."
            )

        self.wte = TensorParallelEmbedding("transformer.wte", weights)

        if not self.alibi:
            self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
        self.blocks = nn.ModuleList(
            [
                MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
                for i in range(config.n_layers)
            ]
        )
        if config.no_bias:
            self.norm_f = nn.LayerNorm.load_no_bias(
                prefix="transformer.norm_f", weights=weights, eps=EPS
            )
        else:
            self.norm_f = nn.LayerNorm.load(
                prefix="transformer.norm_f", weights=weights, eps=EPS
            )
        self.is_causal = not self.prefix_lm
        self._attn_bias_initialized = False
        self.attn_bias = None
        self.attn_bias_shape = attn_bias_shape(
            self.attn_impl,
            config.n_heads,
            config.max_seq_len,
            self.alibi,
            prefix_lm=self.prefix_lm,
            causal=self.is_causal,
            use_sequence_id=self.attn_uses_sequence_id,
        )
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
                    if config.verbose:
                        warnings.warn(f"Removing bias ({module.bias}) from {module}.")
                    module.register_parameter("bias", None)
        if hasattr(self.config, "verbose"):
            if config.verbose and config.verbose > 2:
                print(self)
        if "verbose" not in self.config.init_config:
            self.config.init_config["verbose"] = self.config.verbose
        if self.config.init_config["verbose"] > 1:
            init_fn_name = self.config.init_config["name"]
            warnings.warn(f"Using {init_fn_name} initialization.")

    @torch.no_grad()
    def _attn_bias(
        self,
        device,
        dtype,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
    ):
        if not self._attn_bias_initialized:
            if self.attn_bias_shape:
                self.attn_bias = torch.zeros(
                    self.attn_bias_shape, device=device, dtype=dtype
                )
                self.attn_bias = build_attn_bias(
                    self.attn_impl,
                    self.attn_bias,
                    self.config.n_heads,
                    self.config.max_seq_len,
                    causal=self.is_causal,
                    alibi=self.alibi,
                    alibi_bias_max=self.alibi_bias_max,
                )
                assert self.n_heads % self.world_size == 0
                block_size = self.n_heads // self.world_size
                self.attn_bias = self.attn_bias[
                    :, self.rank * block_size : (self.rank + 1) * block_size
                ]
            self._attn_bias_initialized = True
        if self.attn_impl == "flash":
            return (self.attn_bias, attention_mask)
        if self.attn_bias is not None:
            self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
        attn_bias = self.attn_bias
        if self.prefix_lm:
            assert isinstance(attn_bias, torch.Tensor)
            assert isinstance(prefix_mask, torch.Tensor)
            attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
        if self.attn_uses_sequence_id and sequence_id is not None:
            assert isinstance(attn_bias, torch.Tensor)
            attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
        if attention_mask is not None:
            s_k = attention_mask.shape[-1]
            if attn_bias is None:
                attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
            else:
                _s_k = max(0, attn_bias.size(-1) - s_k)
                attn_bias = attn_bias[:, :, :, _s_k:]
            if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
                raise ValueError(
                    f"attention_mask shape={attention_mask.shape} "
                    + f"and prefix_mask shape={prefix_mask.shape} are not equal."
                )
            min_val = torch.finfo(attn_bias.dtype).min
            attn_bias = attn_bias.masked_fill(
                ~attention_mask.view(-1, 1, 1, s_k), min_val
            )
        return (attn_bias, None)

    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
        (s_k, s_q) = attn_bias.shape[-2:]
        if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
            raise ValueError(
                "attn_bias does not match the expected shape. "
                + f"The last two dimensions should both be {self.config.max_seq_len} "
                + f"but are {s_k} and {s_q}."
            )
        seq_len = prefix_mask.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
            )
        attn_bias = attn_bias[..., :seq_len, :seq_len]
        causal = torch.tril(
            torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)
        ).view(1, 1, seq_len, seq_len)
        prefix = prefix_mask.view(-1, 1, 1, seq_len)
        cannot_attend = ~torch.logical_or(causal, prefix.bool())
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def _apply_sequence_id(
        self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor
    ):
        seq_len = sequence_id.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
            )
        attn_bias = attn_bias[..., :seq_len, :seq_len]
        cannot_attend = torch.logical_not(
            torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
        ).unsqueeze(1)
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
        if prefix_mask is not None:
            prefix_mask = prefix_mask.bool()
        if not return_dict:
            raise NotImplementedError(
                "return_dict False is not implemented yet for MPT"
            )
        if output_attentions:
            if self.attn_impl != "torch":
                raise NotImplementedError(
                    "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`."
                )
        if (
            attention_mask is not None
            and attention_mask[:, 0].sum() != attention_mask.shape[0]
            and self.training
        ):
            raise NotImplementedError(
                "MPT does not support training with left padding."
            )
        if self.prefix_lm and prefix_mask is None:
            raise ValueError(
                "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
            )
        if self.training:
            if self.attn_uses_sequence_id and sequence_id is None:
                raise ValueError(
                    "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
                    + "and the model is in train mode."
                )
            elif self.attn_uses_sequence_id is False and sequence_id is not None:
                warnings.warn(
                    "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
                    + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
                )
        S = input_ids.size(1)
        assert (
            S <= self.config.max_seq_len
        ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
        tok_emb = self.wte(input_ids)
        if self.alibi:
            x = tok_emb
        else:
            past_position = 0
            if past_key_values is not None:
                if len(past_key_values) != self.config.n_layers:
                    raise ValueError(
                        f"past_key_values must provide a past_key_value for each attention "
                        + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
                    )
                past_position = past_key_values[0][0].size(1)
                if self.attn_impl == "torch":
                    past_position = past_key_values[0][0].size(3)
            if S + past_position > self.config.max_seq_len:
                raise ValueError(
                    f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
                )
            pos = torch.arange(
                past_position,
                S + past_position,
                dtype=torch.long,
                device=input_ids.device,
            ).unsqueeze(0)
            if attention_mask is not None:
                pos = torch.clamp(
                    pos
                    - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
                        :, past_position:
                    ],
                    min=0,
                )
            pos_emb = self.wpe(pos)
            x = tok_emb + pos_emb
        (attn_bias, attention_mask) = self._attn_bias(
            device=x.device,
            dtype=torch.float32,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
        )
        if use_cache and past_key_values is None:
            past_key_values = [() for _ in range(self.config.n_layers)]
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        for b_idx, block in enumerate(self.blocks):
            if output_hidden_states:
                assert all_hidden_states is not None
                all_hidden_states = all_hidden_states + (x,)
            past_key_value = (
                past_key_values[b_idx] if past_key_values is not None else None
            )
            (x, attn_weights, past_key_value) = block(
                x,
                past_key_value=past_key_value,
                attn_bias=attn_bias,
                attention_mask=attention_mask,
                is_causal=self.is_causal,
            )
            if past_key_values is not None:
                past_key_values[b_idx] = past_key_value
            if output_attentions:
                assert all_self_attns is not None
                all_self_attns = all_self_attns + (attn_weights,)
        x = self.norm_f(x)
        if output_hidden_states:
            assert all_hidden_states is not None
            all_hidden_states = all_hidden_states + (x,)
        return BaseModelOutputWithPast(
            last_hidden_state=x,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class MPTForCausalLM(MPTPreTrainedModel):
    def __init__(self, config, weights):
        super().__init__(config)
        if not config.tie_word_embeddings:
            raise ValueError("MPTForCausalLM only supports tied word embeddings")
        self.transformer = MPTModel(config, weights)
        self.lm_head = SpeculativeHead.load(
            config, prefix="transformer.wte", weights=weights
        )
        self.logit_scale = None
        if config.logit_scale is not None:
            logit_scale = config.logit_scale
            if isinstance(logit_scale, str):
                if logit_scale == "inv_sqrt_d_model":
                    logit_scale = 1 / math.sqrt(config.d_model)
                else:
                    raise ValueError(
                        f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
                    )
            self.logit_scale = logit_scale

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
            return_dict=return_dict,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
        )
        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(
                    f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
                )
            logits *= self.logit_scale
        loss = None
        if labels is not None:
            labels = torch.roll(labels, shifts=-1)
            labels[:, -1] = -100
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
            )
        return (
            CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            ),
            speculative_logits,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        if inputs_embeds is not None:
            raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
        attention_mask = kwargs["attention_mask"].bool()
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError(
                "MPT does not support generation with right padding."
            )
        if self.transformer.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])
        else:
            sequence_id = None
        if past_key_values is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
        if self.transformer.prefix_lm:
            prefix_mask = torch.ones_like(attention_mask)
            if kwargs.get("use_cache") == False:
                raise NotImplementedError(
                    "MPT with prefix_lm=True does not support use_cache=False."
                )
        else:
            prefix_mask = None
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prefix_mask": prefix_mask,
            "sequence_id": sequence_id,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Used by HuggingFace generate when using beam search with kv-caching.

        See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
        for an example in transformers.
        """
        reordered_past = []
        for layer_past in past_key_values:
            reordered_past += [
                tuple(
                    (past_state.index_select(0, beam_idx) for past_state in layer_past)
                )
            ]
        return reordered_past
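
# _reorder_cache above simply re-indexes every cached key/value tensor along the
# batch (beam) dimension so each surviving beam keeps the cache of the hypothesis it
# was forked from. Sketch with hypothetical shapes: if a layer's cached key has shape
# (num_beams, ..., seq_len) and beam_idx = tensor([2, 2, 0]), the reordered cache
# stacks rows 2, 2 and 0 of the old cache, in that order.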