Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
Fixed MPT sharding.
commit c62527a542 (parent f33ad7ed98)
@@ -3,8 +3,9 @@
 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
 import math
+import os
 import warnings
-from typing import List, Optional, Tuple, Union, Dict
+from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -20,11 +21,35 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
     TensorParallelHead,
+    get_linear,
 )

 EPS = 1e-5


+def load_col(config, prefix, weights, bias):
+    assert bias == False, NotImplementedError
+    assert config.quantize != "gptq", NotImplementedError
+    slice_ = weights._get_slice(f"{prefix}.weight")
+    rank = weights.process_group.rank()
+    size = weights.process_group.size()
+
+    h3, h = slice_.get_shape()
+    block_size = h // size
+
+    q_part = slice_[rank * block_size : (rank + 1) * block_size]
+    k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size]
+    v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size]
+
+    weight = torch.cat([q_part, k_part, v_part], dim=0)
+    if weight.dtype != torch.int32:
+        weight = weight.to(dtype=weights.dtype)
+    weight = weight.to(device=weights.device)
+    bias = None
+    linear = get_linear(weight, bias, config.quantize)
+    return TensorParallelColumnLinear(linear)
+
+
 def _reset_is_causal(
     num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
 ):
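For reference, a minimal standalone sketch of the column-sharding scheme that load_col implements, using a plain tensor in place of the server's weight slice. The sizes, world_size and the helper name shard_qkv below are illustrative assumptions, not part of the commit:

import torch

# Hypothetical sizes: hidden size h, fused QKV weight of shape (3 * h, h),
# split across world_size tensor-parallel ranks.
h, world_size = 8, 2
wqkv = torch.arange(3 * h * h, dtype=torch.float32).reshape(3 * h, h)

def shard_qkv(weight: torch.Tensor, rank: int, size: int) -> torch.Tensor:
    # Same row-block slicing as load_col: take this rank's block from each of
    # the q, k and v sections of the fused weight, then re-concatenate them.
    h = weight.shape[1]
    block_size = h // size
    q = weight[rank * block_size : (rank + 1) * block_size]
    k = weight[h + rank * block_size : h + (rank + 1) * block_size]
    v = weight[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size]
    return torch.cat([q, k, v], dim=0)

shards = [shard_qkv(wqkv, r, world_size) for r in range(world_size)]
# Each shard holds 3 * h / world_size rows: its slice of q, k and v.
assert all(s.shape == (3 * h // world_size, h) for s in shards)
# Re-interleaving the per-rank q/k/v blocks reconstructs the full weight.
q_full = torch.cat([s[: h // world_size] for s in shards])
k_full = torch.cat([s[h // world_size : 2 * h // world_size] for s in shards])
v_full = torch.cat([s[2 * h // world_size :] for s in shards])
assert torch.equal(torch.cat([q_full, k_full, v_full]), wqkv)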
@@ -64,8 +89,6 @@ def scaled_multihead_dot_product_attention(
         past_key_value = (k, v)
     (b, _, s_q, d) = q.shape
     s_k = k.size(-1)
-    if softmax_scale is None:
-        softmax_scale = 1 / math.sqrt(d)
     attn_weight = q.matmul(k) * softmax_scale
     if attn_bias is not None:
         _s_q = max(0, attn_bias.size(2) - s_q)
@@ -296,11 +319,12 @@ class MultiheadAttention(nn.Module):
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = config.attn_config["attn_pdrop"]
-        self.Wqkv = TensorParallelColumnLinear.load(
+        self.n_heads = self.n_heads // weights.process_group.size()
+        self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
-        fuse_splits = (d_model, 2 * d_model)
-        self.Wqkv._fused = (0, fuse_splits)
+        # fuse_splits = (d_model, 2 * d_model)
+        # self.Wqkv._fused = (0, fuse_splits)
         # if self.qk_ln:
         #     layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
         #     self.q_ln = layernorm_class(self.d_model, device=device)
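A small sketch of why dividing self.n_heads by the process-group size keeps the forward pass consistent: each rank's sharded Wqkv now emits 3 * d_model / world_size features, and because load_col keeps this rank's q, k and v blocks contiguous, qkv.chunk(3, dim=2) still separates them. The sizes below are illustrative assumptions:

import torch

# Illustrative numbers (assumptions): d_model=16, n_heads=4, world_size=2.
d_model, n_heads, world_size = 16, 4, 2
n_heads_local = n_heads // world_size          # what the commit stores in self.n_heads
head_dim = d_model // n_heads

# Output of the sharded Wqkv on one rank: 3 * d_model / world_size features.
batch, seq = 2, 5
qkv = torch.randn(batch, seq, 3 * d_model // world_size)

# chunk(3, dim=2) still separates q, k and v because load_col keeps this
# rank's q, k and v blocks contiguous in the output dimension.
query, key, value = qkv.chunk(3, dim=2)
assert query.shape == (batch, seq, n_heads_local * head_dim)

# The per-rank head count keeps the usual multi-head reshape consistent.
query = query.view(batch, seq, n_heads_local, head_dim)
assert query.shape == (2, 5, 2, 4)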
@@ -333,6 +357,7 @@ class MultiheadAttention(nn.Module):
         if self.clip_qkv:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         (query, key, value) = qkv.chunk(3, dim=2)
+
         key_padding_mask = attention_mask
         if self.qk_ln:
             dtype = query.dtype
@@ -352,7 +377,8 @@ class MultiheadAttention(nn.Module):
             training=self.training,
             needs_weights=needs_weights,
         )
-        return (self.out_proj(context), attn_weights, past_key_value)
+        out = self.out_proj(context)
+        return (out, attn_weights, past_key_value)


 class MultiQueryAttention(nn.Module):
@@ -362,37 +388,29 @@ class MultiQueryAttention(nn.Module):
     additive bias.
     """

-    def __init__(
-        self,
-        d_model: int,
-        n_heads: int,
-        attn_impl: str = "triton",
-        clip_qkv: Optional[float] = None,
-        qk_ln: bool = False,
-        softmax_scale: Optional[float] = None,
-        attn_pdrop: float = 0.0,
-        low_precision_layernorm: bool = False,
-        verbose: int = 0,
-        device: Optional[str] = None,
-    ):
+    def __init__(self, config, prefix, weights):
         super().__init__()
-        self.attn_impl = attn_impl
-        self.clip_qkv = clip_qkv
-        self.qk_ln = qk_ln
-        self.d_model = d_model
-        self.n_heads = n_heads
-        self.head_dim = d_model // n_heads
-        self.softmax_scale = softmax_scale
+        attn_impl = config.attn_config["attn_impl"]
+        self.attn_impl = config.attn_config["attn_impl"]
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.d_model = config.d_model
+        d_model = config.d_model
+        self.n_heads = config.n_heads
+        self.softmax_scale = config.attn_config["softmax_scale"]
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.head_dim)
-        self.attn_dropout_p = attn_pdrop
-        self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
+        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
+        self.Wqkv = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
+        )
         fuse_splits = (d_model, d_model + self.head_dim)
-        self.Wqkv._fused = (0, fuse_splits)
-        if self.qk_ln:
-            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
-            self.q_ln = layernorm_class(d_model, device=device)
-            self.k_ln = layernorm_class(self.head_dim, device=device)
+        # self.Wqkv._fused = (0, fuse_splits)
+        # if self.qk_ln:
+        #     layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
+        #     self.q_ln = layernorm_class(d_model, device=device)
+        #     self.k_ln = layernorm_class(self.head_dim, device=device)
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
@@ -414,7 +432,12 @@ class MultiQueryAttention(nn.Module):
             )
         else:
             raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
-        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
+        self.out_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.out_proj",
+            weights=weights,
+            bias=not config.no_bias,
+        )
         # self.out_proj._is_residual = True

     def forward(
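The output projection presumably becomes a TensorParallelRowLinear because its input, the attention context, is already sharded along the feature dimension, so each rank contributes a partial product that is summed across ranks (an all-reduce in the real server). A minimal sketch of that identity, with made-up shapes (not taken from the commit):

import torch

# Illustrative shapes (assumptions): full feature dim d, two ranks.
d, world_size = 8, 2
x = torch.randn(3, d)            # full attention context
w = torch.randn(d, d)            # full out_proj weight, applied as x @ w

# Row-parallel layout: each rank holds a slice of the input features and the
# matching rows of the weight, computes a partial output, then the partials
# are summed (done with an all-reduce across ranks in the real server).
block = d // world_size
partials = [
    x[:, r * block : (r + 1) * block] @ w[r * block : (r + 1) * block]
    for r in range(world_size)
]
out = sum(partials)
assert torch.allclose(out, x @ w, atol=1e-5)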
@@ -553,6 +576,7 @@ class MPTMLP(nn.Module):
 class MPTBlock(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
+        self.prefix = prefix
         # norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
         # attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
         if config.attn_config["attn_type"] != "multihead_attention":
@@ -726,6 +750,9 @@ class MPTModel(MPTPreTrainedModel):
     def __init__(self, config, weights):
         # config._validate_config()
         super().__init__(config)
+        self.world_size = weights.process_group.size()
+        self.rank = weights.process_group.rank()
+        self.n_heads = config.n_heads
         self.attn_impl = config.attn_config["attn_impl"]
         self.prefix_lm = config.attn_config["prefix_lm"]
         self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
@@ -810,6 +837,11 @@ class MPTModel(MPTPreTrainedModel):
                 alibi=self.alibi,
                 alibi_bias_max=self.alibi_bias_max,
             )
+            assert self.n_heads % self.world_size == 0
+            block_size = self.n_heads // self.world_size
+            self.attn_bias = self.attn_bias[
+                :, self.rank * block_size : (self.rank + 1) * block_size
+            ]
             self._attn_bias_initialized = True
         if self.attn_impl == "flash":
             return (self.attn_bias, attention_mask)
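A minimal sketch of the attention-bias sharding added above: the alibi bias is built for all heads, and each rank keeps only the head slice that matches its shard of the attention weights. The sizes below are illustrative assumptions:

import torch

# Illustrative sizes (assumptions): n_heads=4, world_size=2, seq_len=6.
n_heads, world_size, seq_len = 4, 2, 6

# Stand-in for the full alibi attention bias of shape (1, n_heads, 1, seq_len).
attn_bias = torch.arange(n_heads * seq_len, dtype=torch.float32).reshape(
    1, n_heads, 1, seq_len
)

assert n_heads % world_size == 0
block_size = n_heads // world_size

# Each rank keeps only the head slice matching its shard of the attention weights.
shards = [
    attn_bias[:, rank * block_size : (rank + 1) * block_size]
    for rank in range(world_size)
]
assert all(s.shape == (1, block_size, 1, seq_len) for s in shards)
# Concatenating the per-rank slices along the head dim restores the full bias.
assert torch.equal(torch.cat(shards, dim=1), attn_bias)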
@@ -66,7 +66,7 @@ class MPTSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            requires_padding=False,
+            requires_padding=True,
             dtype=dtype,
             device=device,
             rank=rank,