# coding=utf-8
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLOOM model."""

import math
import os
import warnings
from typing import Optional, Tuple, Union

import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import LayerNorm
from torch.nn import functional as F

from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)
from transformers import BloomConfig, PreTrainedModel

from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    SpeculativeHead,
)

CUSTOM_KERNELS_ENABLED = False
if (
    torch.cuda.is_available()
    and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
):
    try:
        from custom_kernels import fused_bloom_attention_cuda

        CUSTOM_KERNELS_ENABLED = True
    except ImportError:
        pass

_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
_CONFIG_FOR_DOC = "BloomConfig"

BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bigscience/bigscience-small-testing",
    "bigscience/bloom-560m",
    "bigscience/bloom-1b1",
    "bigscience/bloom-1b7",
    "bigscience/bloom-3b",
    "bigscience/bloom-7b1",
    "bigscience/bloom",
]


def _make_causal_mask(
    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
    """
    Make causal mask used for self-attention. `True` marks positions that must *not* be attended to.
    """
    batch_size, target_length = input_ids_shape
    mask = torch.ones(
        (target_length, target_length + past_key_values_length),
        dtype=torch.bool,
        device=device,
    )
    mask = mask.triu(1 + past_key_values_length)

    expanded_mask = mask.unsqueeze(0).expand(
        batch_size, target_length, target_length + past_key_values_length
    )
    return expanded_mask

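
# A quick illustration of `_make_causal_mask` (hedged sketch, not part of the original
# file): with one sequence of length 4 and a past of length 2, every query token may
# attend to itself, the cached past, and earlier tokens, so only strictly-future
# positions come out `True`:
#
#   mask = _make_causal_mask(
#       torch.Size([1, 4]), device=torch.device("cpu"), past_key_values_length=2
#   )
#   mask.shape           # torch.Size([1, 4, 6])
#   mask[0, 0].tolist()  # [False, False, False, True, True, True]
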
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
    """
    Expands attention_mask from `[batch_size, src_length]` to `[batch_size, tgt_length, src_length]`,
    inverting it so that `True` marks padding positions that must not be attended to.
    """
    batch_size, src_length = mask.shape
    tgt_length = tgt_length if tgt_length is not None else src_length

    expanded_mask = ~(mask[:, None, :].to(torch.bool))
    return expanded_mask.expand(batch_size, tgt_length, src_length)


def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    Build the ALiBi bias tensor (paper: https://arxiv.org/abs/2108.12409). The tensor is not causal as in the
    original paper; instead the implementation relies on the translation invariance of softmax for speed: for a
    tensor `l` and a fixed value `a`, `softmax(l + a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

    Args:
        attention_mask (`torch.Tensor`):
            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
        num_heads (`int`, *required*):
            number of heads

    Returns:
        Tensor shaped (batch_size, num_heads, max_seq_len); the caller reshapes it to
        (batch_size * num_heads, 1, max_seq_len).
    """
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
        device=attention_mask.device,
        dtype=torch.float32,
    )
    powers = torch.arange(
        1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
    )
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
            device=attention_mask.device,
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(
            1,
            1 + 2 * num_remaining_heads,
            2,
            device=attention_mask.device,
            dtype=torch.int32,
        )
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # => the query_length dimension will then be broadcasted correctly
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    return alibi

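
# A hedged sketch of the slope schedule above (illustrative, not part of the original
# file): for `num_heads = 8`, `closest_power_of_2 = 8` and `base = 2 ** -1`, so the
# per-head slopes form the geometric sequence 1/2, 1/4, ..., 1/256:
#
#   attention_mask = torch.ones(1, 5)
#   alibi = build_alibi_tensor(attention_mask, num_heads=8)
#   alibi.shape           # torch.Size([1, 8, 5])
#   alibi[0, 0].tolist()  # [0.0, 0.5, 1.0, 1.5, 2.0]  (slope 1/2 times positions 0..4)
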
# @torch.jit.script
def dropout_add(
    x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool
) -> torch.Tensor:
    """
    Dropout add function

    Args:
        x (`torch.tensor`, *required*):
            input tensor
        residual (`torch.tensor`, *required*):
            residual tensor
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            training mode
    """
    out = F.dropout(x, p=prob, training=training)
    out = residual + out
    return out


# @torch.jit.script  # disabled: the scripted version underperforms for unknown reasons
def _split_heads(
    fused_qkv: torch.Tensor, num_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
    storage as `fused_qkv`

    Args:
        fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

    Returns:
        query: [batch_size * num_heads, seq_length, head_dim]
        key: [batch_size * num_heads, head_dim, seq_length] (pre-transposed for the query-key product)
        value: [batch_size * num_heads, seq_length, head_dim]
    """
    batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
    fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
    query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)

    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * num_heads, seq_length, head_dim
    )
    key_layer = key_layer.permute(0, 2, 3, 1).reshape(
        batch_size * num_heads, head_dim, seq_length
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * num_heads, seq_length, head_dim
    )

    return query_layer, key_layer, value_layer

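
# Shape walk-through for `_split_heads` (hedged, illustrative only): with
# batch_size=2, seq_length=3, num_heads=4, head_dim=8, the fused projection has
# shape [2, 3, 96], and:
#
#   q, k, v = _split_heads(torch.randn(2, 3, 96), num_heads=4, head_dim=8)
#   q.shape  # torch.Size([8, 3, 8])
#   k.shape  # torch.Size([8, 8, 3])  -- transposed so q @ k gives [8, 3, 3]
#   v.shape  # torch.Size([8, 3, 8])
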
# @torch.jit.script
def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
    """
    Merge heads together over the last dimension

    Args:
        x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]

    Returns:
        torch.tensor: [batch_size, seq_length, num_heads * head_dim]
    """
    # What we want to achieve is:
    # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
    batch_size_and_num_heads, seq_length, _ = x.shape
    batch_size = batch_size_and_num_heads // num_heads

    # First view to decompose the batch size
    # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
    x = x.view(batch_size, num_heads, seq_length, head_dim)

    # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
    x = x.permute(0, 2, 1, 3)

    # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
    return x.reshape(batch_size, seq_length, num_heads * head_dim)

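
# Hedged round-trip check (illustrative only): `_merge_heads` undoes the query/value
# layout produced by `_split_heads`:
#
#   q, _, v = _split_heads(torch.randn(2, 3, 96), num_heads=4, head_dim=8)
#   _merge_heads(v, num_heads=4, head_dim=8).shape  # torch.Size([2, 3, 32])
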
class BloomAttention(nn.Module):
    def __init__(self, prefix, config: BloomConfig, weights):
        super().__init__()

        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact

        self.process_group = weights.process_group

        self.hidden_size = config.hidden_size
        self.num_heads = config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.hidden_dropout = config.hidden_dropout

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
                f" {self.num_heads})."
            )

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = 1.0

        process_group = weights.process_group
        if self.num_heads % process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {process_group.size()})"
            )
        self.num_heads = self.num_heads // process_group.size()
        self.query_key_value = TensorParallelColumnLinear.load(
            config=config,
            prefix=f"{prefix}.query_key_value",
            weights=weights,
            bias=True,
        )
        self.dense = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.dense", weights=weights, bias=True
        )
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    @staticmethod
    def compute_attention(
        fused_qkv: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        head_mask: Optional[torch.Tensor],
        beta: float,
        inv_norm_factor: float,
        num_heads: int,
        use_cache: bool,
    ):
        batch_size, q_length, three_times_hidden_size = fused_qkv.shape
        head_dim = three_times_hidden_size // (3 * num_heads)

        ### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
        # 3 x [batch_size * num_heads, ...] (see `_split_heads` for the exact layouts)
        (query_layer, key_layer, value_layer) = _split_heads(
            fused_qkv, num_heads=num_heads, head_dim=head_dim
        )

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            # - key: [batch_size * self.num_heads, head_dim, kv_length]
            # - value: [batch_size * self.num_heads, kv_length, head_dim]
            past_key = past_key.view(-1, *past_key.shape[-2:])
            key_layer = torch.cat((past_key, key_layer), dim=2)
            past_value = past_value.view(-1, *past_value.shape[-2:])
            value_layer = torch.cat((past_value, value_layer), dim=1)

        _, _, kv_length = key_layer.shape

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None
        ###

        # [batch_size * num_heads, q_length, kv_length]
        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
        attention_scores = alibi.baddbmm(
            batch1=query_layer,
            batch2=key_layer,
            beta=beta,
            alpha=inv_norm_factor,
        )

        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16:
            attention_scores = attention_scores.to(torch.float)
        # mask out non-attendable positions with the smallest representable value before the softmax
        attn_weights = attention_scores.masked_fill_(
            attention_mask, torch.finfo(attention_scores.dtype).min
        )
        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            input_dtype
        )

        # # [batch_size, num_heads, q_length, kv_length]
        # attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # matmul: [batch_size * num_heads, q_length, head_dim]
        # `out=query_layer` reuses the query buffer to avoid an extra allocation
        context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)

        # change view to [batch_size, q_length, num_heads * head_dim]
        context_layer = _merge_heads(
            context_layer, num_heads=num_heads, head_dim=head_dim
        )

        return context_layer, present, attention_probs

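
    # The score computation above is, in equation form (hedged paraphrase):
    #   attention_scores = beta * alibi + inv_norm_factor * (Q @ K)
    # with Q: [B*H, q_len, head_dim], K: [B*H, head_dim, kv_len], and alibi
    # broadcast from [B*H, 1, kv_len] over the query dimension. A minimal
    # shape check (illustrative only):
    #
    #   q = torch.randn(8, 3, 16); k = torch.randn(8, 16, 3)
    #   alibi = torch.zeros(8, 1, 3)
    #   alibi.baddbmm(batch1=q, batch2=k, beta=1.0, alpha=0.25).shape  # torch.Size([8, 3, 3])
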
    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        fused_qkv = self.query_key_value(
            hidden_states
        )  # [batch_size, seq_length, 3 x hidden_size]
        batch_size, q_length, _ = fused_qkv.shape

        if layer_past is not None:
            past_key, past_value = layer_past
            layer_past = (
                past_key.view(-1, *past_key.shape[-2:]),
                past_value.view(-1, *past_value.shape[-2:]),
            )

        if CUSTOM_KERNELS_ENABLED:
            assert self.training is False, "Only the forward pass is implemented"
            assert (
                attention_mask.shape[-1] < 4096
            ), "Custom kernel supports only up to 4096 tokens"
            (
                context_layer,
                present,
                attention_probs,
            ) = fused_bloom_attention_cuda.forward(
                fused_qkv,
                layer_past,
                alibi,
                attention_mask,
                head_mask,
                self.beta,
                self.inv_norm_factor,
                self.num_heads,
                use_cache,
            )
        else:
            context_layer, present, attention_probs = self.compute_attention(
                fused_qkv=fused_qkv,
                layer_past=layer_past,
                alibi=alibi,
                attention_mask=attention_mask,
                head_mask=head_mask,
                beta=self.beta,
                inv_norm_factor=self.inv_norm_factor,
                num_heads=self.num_heads,
                use_cache=use_cache,
            )

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)

        # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
        output_tensor += residual

        outputs = (output_tensor, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs


class BloomMLP(nn.Module):
    def __init__(self, prefix, config: BloomConfig, weights):
        super().__init__()

        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact
        self.dense_h_to_4h = TensorParallelColumnLinear.load(
            config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
        )
        self.dense_4h_to_h = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
        )
        self.gelu_impl = torch.nn.GELU(approximate="tanh")
        self.hidden_dropout = config.hidden_dropout

    def forward(
        self, hidden_states: torch.Tensor, residual: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))

        if self.pretraining_tp > 1 and self.slow_but_exact:
            intermediate_output = torch.zeros_like(residual)
            slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
            for i in range(self.pretraining_tp):
                intermediate_output = intermediate_output + F.linear(
                    hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense_4h_to_h.weight[
                        :, int(i * slices) : int((i + 1) * slices)
                    ],
                )
        else:
            intermediate_output = self.dense_4h_to_h(hidden_states)

        # output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
        intermediate_output += residual

        return intermediate_output

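
# Hedged sketch of the `slow_but_exact` slicing above (illustrative only): with a
# reduction dimension of 8 and `pretraining_tp = 2`, the input is split into two
# halves and each half multiplies the matching weight columns; the partial sums
# reproduce one full matmul:
#
#   w = torch.randn(4, 8); x = torch.randn(1, 3, 8)
#   full = F.linear(x, w)
#   parts = sum(
#       F.linear(x[:, :, i * 4 : (i + 1) * 4], w[:, i * 4 : (i + 1) * 4])
#       for i in range(2)
#   )
#   torch.allclose(full, parts, atol=1e-6)  # True
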
class BloomBlock(nn.Module):
    def __init__(self, layer_id: int, config: BloomConfig, weights):
        super().__init__()

        prefix = f"h.{layer_id}"
        self.input_layernorm = LayerNorm.load(
            prefix=f"{prefix}.input_layernorm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.num_heads = config.n_head
        self.self_attention = BloomAttention(
            prefix=f"{prefix}.self_attention", config=config, weights=weights
        )
        self.post_attention_layernorm = LayerNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )

        self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )
        self.hidden_dropout = config.hidden_dropout

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        # hidden_states: [batch_size, seq_length, hidden_size]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Layer norm post the self attention.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attn_outputs = self.self_attention(
            layernorm_output,
            residual,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attn_outputs[0]

        outputs = attn_outputs[1:]

        layernorm_output = self.post_attention_layernorm(attention_output)

        # Get residual
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output, residual)

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions

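
# Residual wiring above, in brief (hedged summary): with
# `apply_residual_connection_post_layernorm=True` each residual branch starts from
# the layernorm output; otherwise it starts from the un-normalized input of the
# sub-layer, i.e.
#   hidden -> ln -> attn -> (+ residual) -> ln -> mlp -> (+ residual)
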
class BloomPreTrainedModel(PreTrainedModel):
    config_class = BloomConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["BloomBlock"]

    @staticmethod
    def _convert_to_standard_cache(
        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
        num_heads, ...]))
        """
        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
        num_heads = batch_size_times_num_heads // batch_size
        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )

    @staticmethod
    def _convert_to_bloom_cache(
        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
        """
        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
        batch_size_times_num_heads = batch_size * num_heads
        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )

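
# Hedged round-trip check for the two cache converters above (illustrative only),
# with batch_size=2, num_heads=4, head_dim=8, seq_length=5:
#
#   bloom_cache = ((torch.randn(2 * 4, 8, 5), torch.randn(2 * 4, 5, 8)),)
#   std = BloomPreTrainedModel._convert_to_standard_cache(bloom_cache, batch_size=2)
#   std[0][0].shape   # torch.Size([2, 4, 8, 5])
#   back = BloomPreTrainedModel._convert_to_bloom_cache(std)
#   back[0][0].shape  # torch.Size([8, 8, 5])
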
class BloomModel(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig, weights):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        self.num_heads = config.n_head

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()

        self.word_embeddings = TensorParallelEmbedding(
            prefix="word_embeddings", weights=weights
        )

        self.word_embeddings_layernorm = LayerNorm.load(
            prefix="word_embeddings_layernorm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )

        # Transformer blocks
        self.h = nn.ModuleList(
            [
                BloomBlock(layer_id=layer_id, config=config, weights=weights)
                for layer_id in range(config.num_hidden_layers)
            ]
        )

        # Final Layer Norm
        self.ln_f = LayerNorm.load(
            prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon
        )

    def _prepare_attn_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int, int],
        past_key_values_length: int,
    ) -> torch.BoolTensor:
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                device=device,
                past_key_values_length=past_key_values_length,
            )

        # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask
            if combined_attention_mask is None
            else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask

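
    # Hedged sketch of the mask combination above (illustrative only): the causal
    # mask and the inverted padding mask are OR-ed, so a position is masked if it is
    # in the future *or* padding. E.g. a batch of 1, length 3, first token padded:
    #
    #   pad = torch.tensor([[0, 1, 1]])
    #   causal = _make_causal_mask(torch.Size([1, 3]), pad.device, past_key_values_length=0)
    #   combined = _expand_mask(pad, tgt_length=3) | causal
    #   combined[0].tolist()
    #   # [[True, True, True], [True, False, True], [True, False, False]]
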
    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape batch_size x num_heads x N x N
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Compute alibi tensor: check build_alibi_tensor documentation
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[-1]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), device=hidden_states.device
            )
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = build_alibi_tensor(attention_mask, self.num_heads)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        if hasattr(self, "tp_rank"):
            assert self.num_heads % self.tp_world_size == 0
            block_size = self.num_heads // self.tp_world_size
            # each tensor-parallel rank keeps only the slice of heads it owns
            alibi = alibi[
                :, self.tp_rank * block_size : (self.tp_rank + 1) * block_size
            ]
            alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
            causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
        else:
            alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
            causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)

        alibi = alibi.to(hidden_states.dtype)

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=causal_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
                alibi=alibi,
            )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (
                    outputs[2 if use_cache else 1],
                )

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class BloomForCausalLM(BloomPreTrainedModel):
    def __init__(self, config, weights):
        super().__init__(config)
        self.transformer = BloomModel(config, weights)

        self.lm_head = SpeculativeHead.load(
            config,
            prefix="word_embeddings",
            weights=weights,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        # only last token for input_ids if past is not None
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)

            # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                past_key_values = self._convert_to_bloom_cache(past_key_values)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

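
    # Why the shape check above works (hedged explanation): in Bloom's cache format
    # the leading key dimension is batch_size * num_heads, while in the standard
    # format it is batch_size. With num_heads > 1 the two coincide only for a
    # standard-format cache, so a leading dimension equal to input_ids.shape[0]
    # signals that the cache still needs converting.
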
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        logits, speculative_logits = self.lm_head(hidden_states)
        loss = None

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return (
            CausalLMOutputWithCrossAttentions(
                loss=loss,
                logits=logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            ),
            speculative_logits,
        )
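

# Note on the return convention above (hedged summary): unlike the stock transformers
# BloomForCausalLM, this server-side variant returns a
# (CausalLMOutputWithCrossAttentions, speculative_logits) pair, where
# speculative_logits comes from the SpeculativeHead and may be None when no
# speculative head is configured.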