Fixed MPT sharding.

Nicolas Patry 2023-06-30 21:46:44 +00:00
parent f33ad7ed98
commit c62527a542
2 changed files with 68 additions and 36 deletions

View File

@@ -3,8 +3,9 @@
 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
 import math
+import os
 import warnings
-from typing import List, Optional, Tuple, Union, Dict
+from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -20,11 +21,35 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
     TensorParallelHead,
+    get_linear,
 )
 
 EPS = 1e-5
 
+
+def load_col(config, prefix, weights, bias):
+    assert bias == False, NotImplementedError
+    assert config.quantize != "gptq", NotImplementedError
+
+    slice_ = weights._get_slice(f"{prefix}.weight")
+    rank = weights.process_group.rank()
+    size = weights.process_group.size()
+    h3, h = slice_.get_shape()
+    block_size = h // size
+
+    q_part = slice_[rank * block_size : (rank + 1) * block_size]
+    k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size]
+    v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size]
+
+    weight = torch.cat([q_part, k_part, v_part], dim=0)
+    if weight.dtype != torch.int32:
+        weight = weight.to(dtype=weights.dtype)
+    weight = weight.to(device=weights.device)
+    bias = None
+    linear = get_linear(weight, bias, config.quantize)
+    return TensorParallelColumnLinear(linear)
+
 
 def _reset_is_causal(
     num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
 ):
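Note on the new loader: load_col shards the fused Wqkv weight, whose rows are the Q, K and V projections stacked along dim 0 (h3 == 3 * h), by taking the same block of rows from each of the three sections and re-concatenating them, so every rank ends up with its own slice of Q, K and V instead of one contiguous chunk of the stacked tensor. Below is a minimal sketch of that row selection on a toy tensor, with made-up sizes and plain torch slicing standing in for the lazy safetensors slice:

import torch

# Toy stand-in for a serialized MPT Wqkv weight: Q, K, V stacked along dim 0.
h, size = 8, 2                       # hidden size and tensor-parallel world size
wqkv = torch.arange(3 * h * h, dtype=torch.float32).reshape(3 * h, h)
block_size = h // size

def shard_fused_qkv(weight, rank):
    # Take this rank's rows from each of the Q, K and V sections,
    # mirroring the three slices in load_col above.
    q_part = weight[rank * block_size : (rank + 1) * block_size]
    k_part = weight[h + rank * block_size : h + (rank + 1) * block_size]
    v_part = weight[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size]
    return torch.cat([q_part, k_part, v_part], dim=0)

shards = [shard_fused_qkv(wqkv, r) for r in range(size)]
# Each rank holds h // size rows of Q, of K and of V; a plain contiguous split
# of the 3h rows would instead give rank 0 all of Q plus part of K.
assert all(s.shape == (3 * block_size, h) for s in shards)

Keeping the per-rank layout as [q_part; k_part; v_part] is what lets the attention forward pass keep splitting its qkv output into three equal chunks after sharding.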
@@ -64,8 +89,6 @@ def scaled_multihead_dot_product_attention(
         past_key_value = (k, v)
     (b, _, s_q, d) = q.shape
     s_k = k.size(-1)
-    if softmax_scale is None:
-        softmax_scale = 1 / math.sqrt(d)
     attn_weight = q.matmul(k) * softmax_scale
     if attn_bias is not None:
         _s_q = max(0, attn_bias.size(2) - s_q)
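With this in-function fallback gone, softmax_scale always comes from the attention modules, which set it to 1 / sqrt(d_model / n_heads) before n_heads is divided by the world size, so it matches the 1 / sqrt(d) value the deleted lines computed from q's head dimension. A quick check of that equality, using illustrative MPT-7B-like sizes:

import math

d_model, n_heads = 4096, 32                          # illustrative MPT-7B-like sizes
head_dim = d_model // n_heads

scale_from_init = 1 / math.sqrt(d_model / n_heads)   # set in the module __init__, pre-sharding
scale_from_q = 1 / math.sqrt(head_dim)               # what the removed fallback computed

assert math.isclose(scale_from_init, scale_from_q)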
@@ -296,11 +319,12 @@ class MultiheadAttention(nn.Module):
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = config.attn_config["attn_pdrop"]
-        self.Wqkv = TensorParallelColumnLinear.load(
+        self.n_heads = self.n_heads // weights.process_group.size()
+        self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
-        fuse_splits = (d_model, 2 * d_model)
-        self.Wqkv._fused = (0, fuse_splits)
+        # fuse_splits = (d_model, 2 * d_model)
+        # self.Wqkv._fused = (0, fuse_splits)
         # if self.qk_ln:
         #     layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
         #     self.q_ln = layernorm_class(self.d_model, device=device)
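Dividing n_heads by the process-group size and loading Wqkv through load_col makes the QKV projection column-parallel: each rank computes queries, keys and values only for its own heads. The matching output projection (out_proj, loaded as a TensorParallelRowLinear in the MultiQueryAttention changes further down) is split along its input dimension and sums the partial results with an all-reduce. A small numerical sketch of why that column/row pairing reproduces the unsharded result, using plain tensors in place of a real process group:

import torch

torch.manual_seed(0)
d_model, world_size = 8, 2
x = torch.randn(3, d_model)            # a few token embeddings
w_col = torch.randn(d_model, d_model)  # column-parallel weight (split on output dim)
w_row = torch.randn(d_model, d_model)  # row-parallel weight (split on input dim)

full = x @ w_col @ w_row               # what a single device would compute

col_shards = w_col.chunk(world_size, dim=1)
row_shards = w_row.chunk(world_size, dim=0)
partials = [(x @ c) @ r for c, r in zip(col_shards, row_shards)]
sharded = sum(partials)                # the all-reduce done by the row-parallel layer

assert torch.allclose(full, sharded, atol=1e-5)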
@@ -333,6 +357,7 @@ class MultiheadAttention(nn.Module):
         if self.clip_qkv:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         (query, key, value) = qkv.chunk(3, dim=2)
+
         key_padding_mask = attention_mask
         if self.qk_ln:
             dtype = query.dtype
@@ -352,7 +377,8 @@ class MultiheadAttention(nn.Module):
             training=self.training,
             needs_weights=needs_weights,
         )
-        return (self.out_proj(context), attn_weights, past_key_value)
+        out = self.out_proj(context)
+        return (out, attn_weights, past_key_value)
 
 
 class MultiQueryAttention(nn.Module):
@@ -362,37 +388,29 @@ class MultiQueryAttention(nn.Module):
     additive bias.
     """
 
-    def __init__(
-        self,
-        d_model: int,
-        n_heads: int,
-        attn_impl: str = "triton",
-        clip_qkv: Optional[float] = None,
-        qk_ln: bool = False,
-        softmax_scale: Optional[float] = None,
-        attn_pdrop: float = 0.0,
-        low_precision_layernorm: bool = False,
-        verbose: int = 0,
-        device: Optional[str] = None,
-    ):
+    def __init__(self, config, prefix, weights):
         super().__init__()
-        self.attn_impl = attn_impl
-        self.clip_qkv = clip_qkv
-        self.qk_ln = qk_ln
-        self.d_model = d_model
-        self.n_heads = n_heads
-        self.head_dim = d_model // n_heads
-        self.softmax_scale = softmax_scale
+        attn_impl = config.attn_config["attn_impl"]
+        self.attn_impl = config.attn_config["attn_impl"]
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.d_model = config.d_model
+        d_model = config.d_model
+        self.n_heads = config.n_heads
+        self.softmax_scale = config.attn_config["softmax_scale"]
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.head_dim)
-        self.attn_dropout_p = attn_pdrop
-        self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
+        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
+        self.Wqkv = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
+        )
         fuse_splits = (d_model, d_model + self.head_dim)
-        self.Wqkv._fused = (0, fuse_splits)
-        if self.qk_ln:
-            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
-            self.q_ln = layernorm_class(d_model, device=device)
-            self.k_ln = layernorm_class(self.head_dim, device=device)
+        # self.Wqkv._fused = (0, fuse_splits)
+        # if self.qk_ln:
+        #     layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
+        #     self.q_ln = layernorm_class(d_model, device=device)
+        #     self.k_ln = layernorm_class(self.head_dim, device=device)
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
@@ -414,7 +432,12 @@ class MultiQueryAttention(nn.Module):
                 )
         else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
-        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
+        self.out_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.out_proj",
+            weights=weights,
+            bias=not config.no_bias,
+        )
         # self.out_proj._is_residual = True
 
     def forward(
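Unlike the multi-head case, the multi-query Wqkv output is not three equal blocks: it is d_model query features followed by one shared key head and one shared value head of head_dim each, which is what fuse_splits = (d_model, d_model + self.head_dim) above describes. A small sketch of how such an output splits, with made-up sizes and a plain nn.Linear standing in for the loaded projection:

import torch

d_model, n_heads = 8, 4                  # made-up sizes
head_dim = d_model // n_heads

x = torch.randn(2, 5, d_model)           # (batch, seq, hidden)
wqkv = torch.nn.Linear(d_model, d_model + 2 * head_dim, bias=False)

qkv = wqkv(x)
# Split points match fuse_splits = (d_model, d_model + head_dim):
query, key, value = qkv.split([d_model, head_dim, head_dim], dim=2)
assert query.shape[-1] == d_model
assert key.shape[-1] == head_dim and value.shape[-1] == head_dim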
@@ -553,6 +576,7 @@ class MPTMLP(nn.Module):
 class MPTBlock(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
+        self.prefix = prefix
         # norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
         # attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
         if config.attn_config["attn_type"] != "multihead_attention":
@@ -726,6 +750,9 @@ class MPTModel(MPTPreTrainedModel):
     def __init__(self, config, weights):
         # config._validate_config()
         super().__init__(config)
+        self.world_size = weights.process_group.size()
+        self.rank = weights.process_group.rank()
+        self.n_heads = config.n_heads
         self.attn_impl = config.attn_config["attn_impl"]
         self.prefix_lm = config.attn_config["prefix_lm"]
         self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
@@ -810,6 +837,11 @@ class MPTModel(MPTPreTrainedModel):
                     alibi=self.alibi,
                     alibi_bias_max=self.alibi_bias_max,
                 )
+                assert self.n_heads % self.world_size == 0
+                block_size = self.n_heads // self.world_size
+                self.attn_bias = self.attn_bias[
+                    :, self.rank * block_size : (self.rank + 1) * block_size
+                ]
             self._attn_bias_initialized = True
         if self.attn_impl == "flash":
             return (self.attn_bias, attention_mask)
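The attention bias built above has one entry per head along dim 1 (one alibi slope per head), so once each rank owns only n_heads // world_size heads it also needs only the matching slice of the bias, which is what the new assert-and-slice block keeps. A toy sketch of that slicing, with a simplified bias builder standing in for build_attn_bias:

import torch

n_heads, world_size, seq_len = 8, 2, 16   # made-up sizes
assert n_heads % world_size == 0
block_size = n_heads // world_size

def toy_alibi_bias(n_heads, seq_len):
    # Simplified stand-in for build_attn_bias: one linear slope per head.
    slopes = 1.0 / 2 ** torch.arange(1, n_heads + 1, dtype=torch.float32)
    positions = torch.arange(seq_len, dtype=torch.float32)
    return -(slopes[:, None] * positions[None, :])[None]  # (1, n_heads, seq_len)

full_bias = toy_alibi_bias(n_heads, seq_len)
for rank in range(world_size):
    shard = full_bias[:, rank * block_size : (rank + 1) * block_size]
    # Each rank keeps exactly the bias rows for the heads it computes locally.
    assert shard.shape == (1, block_size, seq_len)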

View File

@@ -66,7 +66,7 @@ class MPTSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            requires_padding=False,
+            requires_padding=True,
             dtype=dtype,
             device=device,
             rank=rank,