Mirror of https://github.com/huggingface/text-generation-inference.git
Fixed MPT sharding.
parent f33ad7ed98
commit c62527a542
@@ -3,8 +3,9 @@
 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
 import math
+import os
 import warnings
-from typing import List, Optional, Tuple, Union, Dict
+from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -20,11 +21,35 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
     TensorParallelHead,
+    get_linear,
 )

 EPS = 1e-5


+def load_col(config, prefix, weights, bias):
+    assert bias == False, NotImplementedError
+    assert config.quantize != "gptq", NotImplementedError
+    slice_ = weights._get_slice(f"{prefix}.weight")
+    rank = weights.process_group.rank()
+    size = weights.process_group.size()
+
+    h3, h = slice_.get_shape()
+    block_size = h // size
+
+    q_part = slice_[rank * block_size : (rank + 1) * block_size]
+    k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size]
+    v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size]
+
+    weight = torch.cat([q_part, k_part, v_part], dim=0)
+    if weight.dtype != torch.int32:
+        weight = weight.to(dtype=weights.dtype)
+    weight = weight.to(device=weights.device)
+    bias = None
+    linear = get_linear(weight, bias, config.quantize)
+    return TensorParallelColumnLinear(linear)
+
+
 def _reset_is_causal(
     num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
 ):
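The new load_col helper exists because MPT stores the Q, K and V projections fused in a single Wqkv weight of shape (3 * d_model, d_model). Splitting that first dimension into world_size contiguous chunks would hand some ranks only Q rows and others only V rows, so each of the three sub-matrices is sliced per rank and re-concatenated instead. Below is a minimal, self-contained sketch of that slicing on plain tensors; the shard_fused_qkv helper and the re-assembly check are illustrative only, whereas the commit additionally routes the result through get_linear and wraps it in TensorParallelColumnLinear.

import torch


def shard_fused_qkv(wqkv: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Slice a fused (3*h, h) QKV weight so each rank keeps its own rows of
    # Q, K and V, mirroring what load_col does with its lazy weight slices.
    three_h, h = wqkv.shape
    assert three_h == 3 * h and h % world_size == 0
    block_size = h // world_size
    q = wqkv[rank * block_size : (rank + 1) * block_size]
    k = wqkv[h + rank * block_size : h + (rank + 1) * block_size]
    v = wqkv[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size]
    return torch.cat([q, k, v], dim=0)


h, world_size = 8, 2
wqkv = torch.randn(3 * h, h)
shards = [shard_fused_qkv(wqkv, r, world_size) for r in range(world_size)]
# Re-interleaving the per-rank Q/K/V blocks recovers the original weight.
block_size = h // world_size
q = torch.cat([s[:block_size] for s in shards])
k = torch.cat([s[block_size : 2 * block_size] for s in shards])
v = torch.cat([s[2 * block_size :] for s in shards])
assert torch.equal(torch.cat([q, k, v]), wqkv)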
@@ -64,8 +89,6 @@ def scaled_multihead_dot_product_attention(
         past_key_value = (k, v)
     (b, _, s_q, d) = q.shape
     s_k = k.size(-1)
-    if softmax_scale is None:
-        softmax_scale = 1 / math.sqrt(d)
     attn_weight = q.matmul(k) * softmax_scale
     if attn_bias is not None:
         _s_q = max(0, attn_bias.size(2) - s_q)
@@ -296,11 +319,12 @@ class MultiheadAttention(nn.Module):
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = config.attn_config["attn_pdrop"]
-        self.Wqkv = TensorParallelColumnLinear.load(
+        self.n_heads = self.n_heads // weights.process_group.size()
+        self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
-        fuse_splits = (d_model, 2 * d_model)
-        self.Wqkv._fused = (0, fuse_splits)
+        # fuse_splits = (d_model, 2 * d_model)
+        # self.Wqkv._fused = (0, fuse_splits)
         # if self.qk_ln:
         # layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
         # self.q_ln = layernorm_class(self.d_model, device=device)
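Alongside the sharded Wqkv, MultiheadAttention now divides self.n_heads by the process-group size, so the downstream per-head reshape works on the per-rank head count rather than the global one: each rank's QKV projection only produces n_heads / world_size heads' worth of features. A shape-only illustration of why the local head count must shrink together with the weight shard (single process, no torch.distributed; all names below are illustrative):

import torch

d_model, n_heads, world_size = 16, 4, 2
head_dim = d_model // n_heads          # unchanged by sharding
local_heads = n_heads // world_size    # what the commit now stores in self.n_heads
x = torch.randn(1, 3, d_model)         # (batch, seq, d_model)

# A per-rank Wqkv like the one load_col builds has 3 * (d_model // world_size) output rows.
wqkv_shard = torch.randn(3 * d_model // world_size, d_model)
qkv = x @ wqkv_shard.T                   # (1, 3, 3 * d_model // world_size)
query, key, value = qkv.chunk(3, dim=2)  # each (1, 3, d_model // world_size)

# d_model // world_size == local_heads * head_dim, so the per-head reshape
# only works with the local head count, not the global one.
query = query.reshape(1, 3, local_heads, head_dim)
assert query.shape == (1, 3, 2, 4)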
@@ -333,6 +357,7 @@ class MultiheadAttention(nn.Module):
         if self.clip_qkv:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         (query, key, value) = qkv.chunk(3, dim=2)
+
         key_padding_mask = attention_mask
         if self.qk_ln:
             dtype = query.dtype
@@ -352,7 +377,8 @@ class MultiheadAttention(nn.Module):
             training=self.training,
             needs_weights=needs_weights,
         )
-        return (self.out_proj(context), attn_weights, past_key_value)
+        out = self.out_proj(context)
+        return (out, attn_weights, past_key_value)


 class MultiQueryAttention(nn.Module):
@@ -362,37 +388,29 @@ class MultiQueryAttention(nn.Module):
     additive bias.
     """

-    def __init__(
-        self,
-        d_model: int,
-        n_heads: int,
-        attn_impl: str = "triton",
-        clip_qkv: Optional[float] = None,
-        qk_ln: bool = False,
-        softmax_scale: Optional[float] = None,
-        attn_pdrop: float = 0.0,
-        low_precision_layernorm: bool = False,
-        verbose: int = 0,
-        device: Optional[str] = None,
-    ):
+    def __init__(self, config, prefix, weights):
         super().__init__()
-        self.attn_impl = attn_impl
-        self.clip_qkv = clip_qkv
-        self.qk_ln = qk_ln
-        self.d_model = d_model
-        self.n_heads = n_heads
-        self.head_dim = d_model // n_heads
-        self.softmax_scale = softmax_scale
+        attn_impl = config.attn_config["attn_impl"]
+        self.attn_impl = config.attn_config["attn_impl"]
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.d_model = config.d_model
+        d_model = config.d_model
+        self.n_heads = config.n_heads
+        self.softmax_scale = config.attn_config["softmax_scale"]
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.head_dim)
-        self.attn_dropout_p = attn_pdrop
-        self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
+        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
+        self.Wqkv = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
+        )
         fuse_splits = (d_model, d_model + self.head_dim)
-        self.Wqkv._fused = (0, fuse_splits)
-        if self.qk_ln:
-            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
-            self.q_ln = layernorm_class(d_model, device=device)
-            self.k_ln = layernorm_class(self.head_dim, device=device)
+        # self.Wqkv._fused = (0, fuse_splits)
+        # if self.qk_ln:
+        # layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
+        # self.q_ln = layernorm_class(d_model, device=device)
+        # self.k_ln = layernorm_class(self.head_dim, device=device)
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
@@ -414,7 +432,12 @@ class MultiQueryAttention(nn.Module):
                 )
         else:
             raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
-        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
+        self.out_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.out_proj",
+            weights=weights,
+            bias=not config.no_bias,
+        )
         # self.out_proj._is_residual = True

     def forward(
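MultiQueryAttention's out_proj switches to TensorParallelRowLinear: the Wqkv projection is split along its output dimension (column-parallel), so each rank's attention output covers only its own heads, and out_proj is split along its input dimension (row-parallel), so the per-rank partial products can simply be summed across ranks — the all-reduce a row-parallel linear performs after its matmul. A single-process sketch checking that the column-then-row split reproduces the unsharded result; the explicit sum standing in for the all-reduce is illustrative, not how the library runs it.

import torch

torch.manual_seed(0)
d_model, world_size = 8, 2
x = torch.randn(2, d_model)            # (tokens, d_model)
w_col = torch.randn(d_model, d_model)  # a column-parallel projection weight
w_row = torch.randn(d_model, d_model)  # the row-parallel out_proj weight

# Unsharded reference: two back-to-back linear layers (no bias).
ref = (x @ w_col.T) @ w_row.T

# Column-parallel: each rank keeps a slice of w_col's output rows.
col_shards = w_col.chunk(world_size, dim=0)
# Row-parallel: each rank keeps the matching slice of w_row's input columns.
row_shards = w_row.chunk(world_size, dim=1)

# Each rank computes a partial product; the sum plays the role of the
# all-reduce that the row-parallel layer issues after its matmul.
partials = [(x @ c.T) @ r.T for c, r in zip(col_shards, row_shards)]
out = sum(partials)

assert torch.allclose(out, ref, atol=1e-5)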
@@ -553,6 +576,7 @@ class MPTMLP(nn.Module):
 class MPTBlock(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
+        self.prefix = prefix
         # norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
         # attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
         if config.attn_config["attn_type"] != "multihead_attention":
@@ -726,6 +750,9 @@ class MPTModel(MPTPreTrainedModel):
     def __init__(self, config, weights):
         # config._validate_config()
         super().__init__(config)
+        self.world_size = weights.process_group.size()
+        self.rank = weights.process_group.rank()
+        self.n_heads = config.n_heads
         self.attn_impl = config.attn_config["attn_impl"]
         self.prefix_lm = config.attn_config["prefix_lm"]
         self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
@@ -810,6 +837,11 @@ class MPTModel(MPTPreTrainedModel):
                     alibi=self.alibi,
                     alibi_bias_max=self.alibi_bias_max,
                 )
+                assert self.n_heads % self.world_size == 0
+                block_size = self.n_heads // self.world_size
+                self.attn_bias = self.attn_bias[
+                    :, self.rank * block_size : (self.rank + 1) * block_size
+                ]
             self._attn_bias_initialized = True
         if self.attn_impl == "flash":
             return (self.attn_bias, attention_mask)
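MPT builds its alibi attention bias once per model with one slope per head, so after the heads are split across ranks each rank must also keep only its own heads' slice of the bias; that is what the new assert/block_size/slice lines do, using the world_size, rank and n_heads stored in MPTModel.__init__ above. A small sketch of the same per-rank slicing on a hand-built, alibi-style bias; build_alibi_bias here is a toy stand-in, not the repository's build_attn_bias, and the exact slope formula is simplified.

import torch


def build_alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    # Toy alibi-style bias: one slope per head, shape (1, n_heads, 1, seq_len).
    slopes = 2.0 ** (-torch.arange(1, n_heads + 1, dtype=torch.float32))
    positions = torch.arange(seq_len, dtype=torch.float32)
    return (slopes[:, None] * positions[None, :]).view(1, n_heads, 1, seq_len)


n_heads, world_size, seq_len = 8, 2, 16
attn_bias = build_alibi_bias(n_heads, seq_len)

assert n_heads % world_size == 0
block_size = n_heads // world_size
per_rank = [
    attn_bias[:, rank * block_size : (rank + 1) * block_size]
    for rank in range(world_size)
]
# Each rank keeps the bias for its own heads; together the slices tile the full bias.
assert torch.equal(torch.cat(per_rank, dim=1), attn_bias)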
@@ -66,7 +66,7 @@ class MPTSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            requires_padding=False,
+            requires_padding=True,
             dtype=dtype,
             device=device,
             rank=rank,