diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
index dfd112f6..1a9aef74 100644
--- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py
@@ -348,15 +348,12 @@ class MultiheadAttention(nn.Module):
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
         if self.qk_ln:
-            if weights.process_group.size() > 1:
-                raise NotImplementedError("qk_ln is not supported for number of shards > 1")
             bias = not config.no_bias
             hidden_size = config.d_model
             head_dim = hidden_size // self.n_heads
 
-            norm_class = LPLayerNorm
-            self.q_ln = norm_class(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
-            self.k_ln = norm_class(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
+            self.q_ln = LPLayerNorm(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
+            self.k_ln = LPLayerNorm(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
@@ -650,13 +647,15 @@ class MPTBlock(nn.Module):
 
 
 def _cast_if_autocast_enabled(tensor):
-    if tensor.device.type == "cuda":
-        dtype = torch.get_autocast_gpu_dtype()
-    elif tensor.device.type == "cpu":
-        dtype = torch.get_autocast_cpu_dtype()
-    else:
-        raise NotImplementedError()
-    return tensor.to(dtype=dtype)
+    if torch.is_autocast_enabled():
+        if tensor.device.type == "cuda":
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == "cpu":
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor
 
 
 class LPLayerNorm(torch.nn.LayerNorm):
@@ -680,8 +679,11 @@ class LPLayerNorm(torch.nn.LayerNorm):
             bias=bias,
         )
         if weights is not None:
-            self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"))
-            self.bias = nn.Parameter(weights.get_tensor(f"{prefix}.bias"))
+            self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0))
+            if bias:
+                self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
+            self.normalized_shape = self.weight.shape
+
 
     def forward(self, x):
         module_device = x.device
@@ -837,7 +839,6 @@ class MPTModel(MPTPreTrainedModel):
         if self.config.init_config["verbose"] > 1:
            init_fn_name = self.config.init_config["name"]
            warnings.warn(f"Using {init_fn_name} initialization.")
-        self.embedding_fraction = config.embedding_fraction
 
     @torch.no_grad()
     def _attn_bias(
@@ -1027,11 +1028,6 @@ class MPTModel(MPTPreTrainedModel):
                 )
             pos_emb = self.wpe(pos)
             x = tok_emb + pos_emb
-        if self.embedding_fraction != 1:
-            x = (
-                x * self.embedding_fraction
-                + x.detach() * (1 - self.embedding_fraction)
-            )
         (attn_bias, attention_mask) = self._attn_bias(
             device=x.device,
             dtype=torch.float32,
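
Minimal standalone sketch of the `_cast_if_autocast_enabled` behaviour restored by the second hunk, for checking outside the repo. It assumes a recent PyTorch build; the function body is copied from the hunk above, while the tiny usage check at the bottom is illustrative only and not part of the patch.

import torch


def _cast_if_autocast_enabled(tensor):
    # Only downcast while autocast is actually active; otherwise return the
    # tensor untouched instead of unconditionally casting to the autocast dtype.
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor


x = torch.randn(2, 4)  # float32 CPU tensor
assert _cast_if_autocast_enabled(x).dtype == torch.float32  # autocast off: dtype preserved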