Fixed MPT sharding.

Nicolas Patry 2023-06-30 21:46:44 +00:00
parent f33ad7ed98
commit c62527a542
2 changed files with 68 additions and 36 deletions
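
The core of the fix is the new load_col helper below: instead of the generic column split, the fused Wqkv weight is sliced per section, so every rank gets matching row blocks of Q, K and V. A minimal standalone sketch of that slicing (the helper name, the example shapes and the final print are illustrative assumptions, not part of the commit):

import torch

def shard_fused_qkv(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Illustrative helper (not in the commit): the fused weight stacks Q, K and V
    # along dim 0 with shape (3 * h, h); each rank keeps its own row block of each.
    three_h, h = weight.shape
    assert three_h == 3 * h and h % world_size == 0
    block = h // world_size
    q = weight[rank * block : (rank + 1) * block]
    k = weight[h + rank * block : h + (rank + 1) * block]
    v = weight[2 * h + rank * block : 2 * h + (rank + 1) * block]
    return torch.cat([q, k, v], dim=0)

# With h = 8 and 2 ranks, each rank ends up with a (12, 8) shard.
print(shard_fused_qkv(torch.randn(24, 8), rank=0, world_size=2).shape)  # torch.Size([12, 8])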

View File

@@ -3,8 +3,9 @@
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""
import math
import os
import warnings
from typing import List, Optional, Tuple, Union, Dict
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -20,11 +21,35 @@ from text_generation_server.utils.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelHead,
    get_linear,
)
EPS = 1e-5
def load_col(config, prefix, weights, bias):
    assert bias == False, NotImplementedError
    assert config.quantize != "gptq", NotImplementedError
    slice_ = weights._get_slice(f"{prefix}.weight")
    rank = weights.process_group.rank()
    size = weights.process_group.size()
    h3, h = slice_.get_shape()
    block_size = h // size
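    # The fused Wqkv weight stacks Q, K and V along dim 0 (shape (3 * h, h));
    # each rank takes its own contiguous row block from each of the three sections.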
    q_part = slice_[rank * block_size : (rank + 1) * block_size]
    k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size]
    v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size]
    weight = torch.cat([q_part, k_part, v_part], dim=0)
    if weight.dtype != torch.int32:
        weight = weight.to(dtype=weights.dtype)
    weight = weight.to(device=weights.device)
    bias = None
    linear = get_linear(weight, bias, config.quantize)
    return TensorParallelColumnLinear(linear)
def _reset_is_causal(
    num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
):
@@ -64,8 +89,6 @@ def scaled_multihead_dot_product_attention(
    past_key_value = (k, v)
    (b, _, s_q, d) = q.shape
    s_k = k.size(-1)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)
    attn_weight = q.matmul(k) * softmax_scale
    if attn_bias is not None:
        _s_q = max(0, attn_bias.size(2) - s_q)
@@ -296,11 +319,12 @@ class MultiheadAttention(nn.Module):
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
        self.attn_dropout_p = config.attn_config["attn_pdrop"]
        self.Wqkv = TensorParallelColumnLinear.load(
        self.n_heads = self.n_heads // weights.process_group.size()
        self.Wqkv = load_col(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
        )
        fuse_splits = (d_model, 2 * d_model)
        self.Wqkv._fused = (0, fuse_splits)
        # fuse_splits = (d_model, 2 * d_model)
        # self.Wqkv._fused = (0, fuse_splits)
        # if self.qk_ln:
        # layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
        # self.q_ln = layernorm_class(self.d_model, device=device)
@@ -333,6 +357,7 @@ class MultiheadAttention(nn.Module):
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        (query, key, value) = qkv.chunk(3, dim=2)
        key_padding_mask = attention_mask
        if self.qk_ln:
            dtype = query.dtype
@@ -352,7 +377,8 @@ class MultiheadAttention(nn.Module):
            training=self.training,
            needs_weights=needs_weights,
        )
        return (self.out_proj(context), attn_weights, past_key_value)
        out = self.out_proj(context)
        return (out, attn_weights, past_key_value)
class MultiQueryAttention(nn.Module):
@@ -362,37 +388,29 @@ class MultiQueryAttention(nn.Module):
    additive bias.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_impl: str = "triton",
        clip_qkv: Optional[float] = None,
        qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        low_precision_layernorm: bool = False,
        verbose: int = 0,
        device: Optional[str] = None,
    ):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.softmax_scale = softmax_scale
        attn_impl = config.attn_config["attn_impl"]
        self.attn_impl = config.attn_config["attn_impl"]
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.d_model = config.d_model
        d_model = config.d_model
        self.n_heads = config.n_heads
        self.softmax_scale = config.attn_config["softmax_scale"]
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = attn_pdrop
        self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
        self.attn_dropout_p = config.attn_config["attn_pdrop"]
        # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
        self.Wqkv = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
        )
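        # Unlike MultiheadAttention above, this path still uses the generic column
        # split of the fused weight rather than load_col.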
        fuse_splits = (d_model, d_model + self.head_dim)
        self.Wqkv._fused = (0, fuse_splits)
        if self.qk_ln:
            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = layernorm_class(d_model, device=device)
            self.k_ln = layernorm_class(self.head_dim, device=device)
        # self.Wqkv._fused = (0, fuse_splits)
        # if self.qk_ln:
        # layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
        # self.q_ln = layernorm_class(d_model, device=device)
        # self.k_ln = layernorm_class(self.head_dim, device=device)
        if self.attn_impl == "flash":
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == "triton":
@@ -414,7 +432,12 @@ class MultiQueryAttention(nn.Module):
                )
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=not config.no_bias,
        )
        # self.out_proj._is_residual = True

    def forward(
@@ -553,6 +576,7 @@ class MPTMLP(nn.Module):
class MPTBlock(nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.prefix = prefix
        # norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        # attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
        if config.attn_config["attn_type"] != "multihead_attention":
@@ -726,6 +750,9 @@ class MPTModel(MPTPreTrainedModel):
    def __init__(self, config, weights):
        # config._validate_config()
        super().__init__(config)
        self.world_size = weights.process_group.size()
        self.rank = weights.process_group.rank()
        self.n_heads = config.n_heads
        self.attn_impl = config.attn_config["attn_impl"]
        self.prefix_lm = config.attn_config["prefix_lm"]
        self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
@@ -810,6 +837,11 @@ class MPTModel(MPTPreTrainedModel):
                    alibi=self.alibi,
                    alibi_bias_max=self.alibi_bias_max,
                )
                assert self.n_heads % self.world_size == 0
                block_size = self.n_heads // self.world_size
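                # Each rank's attention layers only hold their share of the heads,
                # so keep the matching head slice of the precomputed attention bias.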
                self.attn_bias = self.attn_bias[
                    :, self.rank * block_size : (self.rank + 1) * block_size
                ]
            self._attn_bias_initialized = True
        if self.attn_impl == "flash":
            return (self.attn_bias, attention_mask)

View File

@@ -66,7 +66,7 @@ class MPTSharded(CausalLM):
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,