diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
index 52de0428..1c4cd210 100644
--- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
@@ -3,8 +3,9 @@
 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
 import math
+import os
 import warnings
-from typing import List, Optional, Tuple, Union, Dict
+from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -20,11 +21,35 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
     TensorParallelHead,
+    get_linear,
 )

 EPS = 1e-5


+def load_col(config, prefix, weights, bias):
+    assert bias == False, NotImplementedError
+    assert config.quantize != "gptq", NotImplementedError
+    slice_ = weights._get_slice(f"{prefix}.weight")
+    rank = weights.process_group.rank()
+    size = weights.process_group.size()
+
+    h3, h = slice_.get_shape()
+    block_size = h // size
+
+    q_part = slice_[rank * block_size : (rank + 1) * block_size]
+    k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size]
+    v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size]
+
+    weight = torch.cat([q_part, k_part, v_part], dim=0)
+    if weight.dtype != torch.int32:
+        weight = weight.to(dtype=weights.dtype)
+    weight = weight.to(device=weights.device)
+    bias = None
+    linear = get_linear(weight, bias, config.quantize)
+    return TensorParallelColumnLinear(linear)
+
+
 def _reset_is_causal(
     num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
 ):
@@ -64,8 +89,6 @@ def scaled_multihead_dot_product_attention(
     past_key_value = (k, v)
     (b, _, s_q, d) = q.shape
     s_k = k.size(-1)
-    if softmax_scale is None:
-        softmax_scale = 1 / math.sqrt(d)
     attn_weight = q.matmul(k) * softmax_scale
     if attn_bias is not None:
         _s_q = max(0, attn_bias.size(2) - s_q)
@@ -296,11 +319,12 @@ class MultiheadAttention(nn.Module):
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = config.attn_config["attn_pdrop"]
-        self.Wqkv = TensorParallelColumnLinear.load(
+        self.n_heads = self.n_heads // weights.process_group.size()
+        self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
-        fuse_splits = (d_model, 2 * d_model)
-        self.Wqkv._fused = (0, fuse_splits)
+        # fuse_splits = (d_model, 2 * d_model)
+        # self.Wqkv._fused = (0, fuse_splits)
         # if self.qk_ln:
         #     layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
         #     self.q_ln = layernorm_class(self.d_model, device=device)
@@ -333,6 +357,7 @@ class MultiheadAttention(nn.Module):
         if self.clip_qkv:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         (query, key, value) = qkv.chunk(3, dim=2)
+
         key_padding_mask = attention_mask
         if self.qk_ln:
             dtype = query.dtype
@@ -352,7 +377,8 @@ class MultiheadAttention(nn.Module):
             training=self.training,
             needs_weights=needs_weights,
         )
-        return (self.out_proj(context), attn_weights, past_key_value)
+        out = self.out_proj(context)
+        return (out, attn_weights, past_key_value)


 class MultiQueryAttention(nn.Module):
@@ -362,37 +388,29 @@ class MultiQueryAttention(nn.Module):
     additive bias.
     """

-    def __init__(
-        self,
-        d_model: int,
-        n_heads: int,
-        attn_impl: str = "triton",
-        clip_qkv: Optional[float] = None,
-        qk_ln: bool = False,
-        softmax_scale: Optional[float] = None,
-        attn_pdrop: float = 0.0,
-        low_precision_layernorm: bool = False,
-        verbose: int = 0,
-        device: Optional[str] = None,
-    ):
+    def __init__(self, config, prefix, weights):
         super().__init__()
-        self.attn_impl = attn_impl
-        self.clip_qkv = clip_qkv
-        self.qk_ln = qk_ln
-        self.d_model = d_model
-        self.n_heads = n_heads
-        self.head_dim = d_model // n_heads
-        self.softmax_scale = softmax_scale
+        attn_impl = config.attn_config["attn_impl"]
+        self.attn_impl = config.attn_config["attn_impl"]
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.d_model = config.d_model
+        d_model = config.d_model
+        self.n_heads = config.n_heads
+        self.softmax_scale = config.attn_config["softmax_scale"]
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.head_dim)
-        self.attn_dropout_p = attn_pdrop
-        self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
+        self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
+        self.Wqkv = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
+        )
         fuse_splits = (d_model, d_model + self.head_dim)
-        self.Wqkv._fused = (0, fuse_splits)
-        if self.qk_ln:
-            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
-            self.q_ln = layernorm_class(d_model, device=device)
-            self.k_ln = layernorm_class(self.head_dim, device=device)
+        # self.Wqkv._fused = (0, fuse_splits)
+        # if self.qk_ln:
+        #     layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
+        #     self.q_ln = layernorm_class(d_model, device=device)
+        #     self.k_ln = layernorm_class(self.head_dim, device=device)
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
@@ -414,7 +432,12 @@ class MultiQueryAttention(nn.Module):
             )
         else:
             raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
-        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
+        self.out_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.out_proj",
+            weights=weights,
+            bias=not config.no_bias,
+        )
         # self.out_proj._is_residual = True

     def forward(
@@ -553,6 +576,7 @@ class MPTMLP(nn.Module):
 class MPTBlock(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
+        self.prefix = prefix
         # norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
         # attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
         if config.attn_config["attn_type"] != "multihead_attention":
@@ -726,6 +750,9 @@ class MPTModel(MPTPreTrainedModel):
     def __init__(self, config, weights):
         # config._validate_config()
         super().__init__(config)
+        self.world_size = weights.process_group.size()
+        self.rank = weights.process_group.rank()
+        self.n_heads = config.n_heads
         self.attn_impl = config.attn_config["attn_impl"]
         self.prefix_lm = config.attn_config["prefix_lm"]
         self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
@@ -810,6 +837,11 @@ class MPTModel(MPTPreTrainedModel):
                     alibi=self.alibi,
                     alibi_bias_max=self.alibi_bias_max,
                 )
+                assert self.n_heads % self.world_size == 0
+                block_size = self.n_heads // self.world_size
+                self.attn_bias = self.attn_bias[
+                    :, self.rank * block_size : (self.rank + 1) * block_size
+                ]
             self._attn_bias_initialized = True
         if self.attn_impl == "flash":
             return (self.attn_bias, attention_mask)
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index b38f6218..889b3c95 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -66,7 +66,7 @@ class MPTSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            requires_padding=False,
+            requires_padding=True,
             dtype=dtype,
             device=device,
             rank=rank,