From df04b28bfc2e44a5f14cf96c42275e9e5f4f4401 Mon Sep 17 00:00:00 2001
From: Choon Meng Tan
Date: Sun, 21 Jan 2024 12:18:20 +0800
Subject: [PATCH] Add Sealion MPT Support

---
 .../models/custom_modeling/mpt_modeling.py    | 103 +++++++++++++-----
 1 file changed, 76 insertions(+), 27 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
index 5ccf796d..dfd112f6 100644
--- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
@@ -28,7 +28,6 @@ EPS = 1e-5
 
 
 def load_col(config, prefix, weights, bias):
-    assert bias == False, NotImplementedError
     assert config.quantize != "gptq", NotImplementedError
     slice_ = weights._get_slice(f"{prefix}.weight")
     rank = weights.process_group.rank()
@@ -45,7 +44,26 @@ def load_col(config, prefix, weights, bias):
     if weight.dtype != torch.int32:
         weight = weight.to(dtype=weights.dtype)
     weight = weight.to(device=weights.device)
-    bias = None
+
+    if bias:
+        bias_slice_ = weights._get_slice(f"{prefix}.bias")
+        bias_rank = weights.process_group.rank()
+        bias_size = weights.process_group.size()
+
+        bias_h = bias_slice_.get_shape()
+        bias_h = bias_h[0]
+        bias_block_size = bias_h // bias_size
+
+        bias_q_part = bias_slice_[bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size]
+        bias_k_part = bias_slice_[bias_h + bias_rank * bias_block_size : bias_h + (bias_rank + 1) * bias_block_size]
+        bias_v_part = bias_slice_[2 * bias_h + bias_rank * bias_block_size : 2 * bias_h + (bias_rank + 1) * bias_block_size]
+
+        bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)
+        if bias.dtype != torch.int32:
+            bias = bias.to(dtype=weights.dtype)
+        bias = bias.to(device=weights.device)
+    else:
+        bias = None
     linear = get_linear(weight, bias, config.quantize)
     return TensorParallelColumnLinear(linear)
 
@@ -330,7 +348,15 @@ class MultiheadAttention(nn.Module):
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
         if self.qk_ln:
-            raise NotImplementedError("qk_ln is not supported")
+            if weights.process_group.size() > 1:
+                raise NotImplementedError("qk_ln is not supported for number of shards > 1")
+            bias = not config.no_bias
+            hidden_size = config.d_model
+            head_dim = hidden_size // self.n_heads
+
+            norm_class = LPLayerNorm
+            self.q_ln = norm_class(hidden_size, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
+            self.k_ln = norm_class(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
@@ -581,12 +607,20 @@ class MPTBlock(nn.Module):
                 f"""Not implemented attn {config.attn_config["attn_type"]}"""
             )
         resid_pdrop = config.resid_pdrop
-        self.norm_1 = nn.LayerNorm.load_no_bias(
-            prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
-        )
-        self.norm_2 = nn.LayerNorm.load_no_bias(
-            prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
-        )
+        if config.no_bias:
+            self.norm_1 = nn.LayerNorm.load_no_bias(
+                prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
+            )
+            self.norm_2 = nn.LayerNorm.load_no_bias(
+                prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
+            )
+        else:
+            self.norm_1 = nn.LayerNorm.load(
+                prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
+            )
+            self.norm_2 = nn.LayerNorm.load(
+                prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
+            )
         self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights)
         self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights)
prefix=f"{prefix}.ffn", weights=weights) self.resid_attn_dropout = nn.Dropout(resid_pdrop) @@ -616,15 +650,13 @@ class MPTBlock(nn.Module): def _cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == "cuda": - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == "cpu": - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor + if tensor.device.type == "cuda": + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == "cpu": + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) class LPLayerNorm(torch.nn.LayerNorm): @@ -635,6 +667,9 @@ class LPLayerNorm(torch.nn.LayerNorm): elementwise_affine=True, device=None, dtype=None, + bias: Optional[bool] = True, + prefix=None, + weights=None, ): super().__init__( normalized_shape=normalized_shape, @@ -642,7 +677,11 @@ class LPLayerNorm(torch.nn.LayerNorm): elementwise_affine=elementwise_affine, device=device, dtype=dtype, + bias=bias, ) + if weights is not None: + self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight")) + self.bias = nn.Parameter(weights.get_tensor(f"{prefix}.bias")) def forward(self, x): module_device = x.device @@ -755,20 +794,23 @@ class MPTModel(MPTPreTrainedModel): ) self.wte = TensorParallelEmbedding("transformer.wte", weights) + if not self.alibi: - # self.wpe = torch.nn.Embedding( - # config.max_seq_len, config.d_model, device=config.init_device - # ) - raise RuntimeError("no alibi no supported") + self.wpe = TensorParallelEmbedding("transformer.wpe", weights) self.blocks = nn.ModuleList( [ MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) for i in range(config.n_layers) ] ) - self.norm_f = nn.LayerNorm.load_no_bias( - prefix="transformer.norm_f", weights=weights, eps=EPS - ) + if config.no_bias: + self.norm_f = nn.LayerNorm.load_no_bias( + prefix="transformer.norm_f", weights=weights, eps=EPS + ) + else: + self.norm_f = nn.LayerNorm.load( + prefix="transformer.norm_f", weights=weights, eps=EPS + ) self.is_causal = not self.prefix_lm self._attn_bias_initialized = False self.attn_bias = None @@ -787,13 +829,15 @@ class MPTModel(MPTPreTrainedModel): if config.verbose: warnings.warn(f"Removing bias ({module.bias}) from {module}.") module.register_parameter("bias", None) - if config.verbose and config.verbose > 2: - print(self) + if hasattr(self.config, "verbose"): + if config.verbose and config.verbose > 2: + print(self) if "verbose" not in self.config.init_config: self.config.init_config["verbose"] = self.config.verbose if self.config.init_config["verbose"] > 1: init_fn_name = self.config.init_config["name"] warnings.warn(f"Using {init_fn_name} initialization.") + self.embedding_fraction = config.embedding_fraction @torch.no_grad() def _attn_bias( @@ -983,6 +1027,11 @@ class MPTModel(MPTPreTrainedModel): ) pos_emb = self.wpe(pos) x = tok_emb + pos_emb + if self.embedding_fraction != 1: + x = ( + x * self.embedding_fraction + + x.detach() * (1 - self.embedding_fraction) + ) (attn_bias, attention_mask) = self._attn_bias( device=x.device, dtype=torch.float32,