Add Sealion MPT Support

This commit is contained in:
Choon Meng Tan 2024-01-21 12:18:20 +08:00
parent 3ccb3bb0b5
commit df04b28bfc

View File

@ -28,7 +28,6 @@ EPS = 1e-5
def load_col(config, prefix, weights, bias):
assert bias == False, NotImplementedError
assert config.quantize != "gptq", NotImplementedError
slice_ = weights._get_slice(f"{prefix}.weight")
rank = weights.process_group.rank()
@ -45,7 +44,26 @@ def load_col(config, prefix, weights, bias):
if weight.dtype != torch.int32:
weight = weight.to(dtype=weights.dtype)
weight = weight.to(device=weights.device)
bias = None
if bias:
bias_slice_ = weights._get_slice(f"{prefix}.bias")
bias_rank = weights.process_group.rank()
bias_size = weights.process_group.size()
bias_h = bias_slice_.get_shape()
bias_h = bias_h[0]
bias_block_size = bias_h // bias_size
bias_q_part = bias_slice_[bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size]
bias_k_part = bias_slice_[bias_h + bias_rank * bias_block_size : bias_h + (bias_rank + 1) * bias_block_size]
bias_v_part = bias_slice_[2 * bias_h + bias_rank * bias_block_size : 2 * bias_h + (bias_rank + 1) * bias_block_size]
bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)
if bias.dtype != torch.int32:
bias = bias.to(dtype=weights.dtype)
bias = bias.to(device=weights.device)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
return TensorParallelColumnLinear(linear)
@ -330,7 +348,15 @@ class MultiheadAttention(nn.Module):
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
)
if self.qk_ln:
raise NotImplementedError("qk_ln is not supported")
if weights.process_group.size() > 1:
raise NotImplementedError("qk_ln is not supported for number of shards > 1")
bias = not config.no_bias
hidden_size = config.d_model
head_dim = hidden_size // self.n_heads
norm_class = LPLayerNorm
self.q_ln = norm_class(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
self.k_ln = norm_class(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
if self.attn_impl == "flash":
self.attn_fn = flash_attn_fn
elif self.attn_impl == "triton":
@ -581,12 +607,20 @@ class MPTBlock(nn.Module):
f"""Not implemented attn {config.attn_config["attn_type"]}"""
)
resid_pdrop = config.resid_pdrop
self.norm_1 = nn.LayerNorm.load_no_bias(
prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
)
self.norm_2 = nn.LayerNorm.load_no_bias(
prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
)
if config.no_bias:
self.norm_1 = nn.LayerNorm.load_no_bias(
prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
)
self.norm_2 = nn.LayerNorm.load_no_bias(
prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
)
else:
self.norm_1 = nn.LayerNorm.load(
prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
)
self.norm_2 = nn.LayerNorm.load(
prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
)
self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights)
self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights)
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
@ -616,15 +650,13 @@ class MPTBlock(nn.Module):
def _cast_if_autocast_enabled(tensor):
if torch.is_autocast_enabled():
if tensor.device.type == "cuda":
dtype = torch.get_autocast_gpu_dtype()
elif tensor.device.type == "cpu":
dtype = torch.get_autocast_cpu_dtype()
else:
raise NotImplementedError()
return tensor.to(dtype=dtype)
return tensor
if tensor.device.type == "cuda":
dtype = torch.get_autocast_gpu_dtype()
elif tensor.device.type == "cpu":
dtype = torch.get_autocast_cpu_dtype()
else:
raise NotImplementedError()
return tensor.to(dtype=dtype)
class LPLayerNorm(torch.nn.LayerNorm):
@ -635,6 +667,9 @@ class LPLayerNorm(torch.nn.LayerNorm):
elementwise_affine=True,
device=None,
dtype=None,
bias: Optional[bool] = True,
prefix=None,
weights=None,
):
super().__init__(
normalized_shape=normalized_shape,
@ -642,7 +677,11 @@ class LPLayerNorm(torch.nn.LayerNorm):
elementwise_affine=elementwise_affine,
device=device,
dtype=dtype,
bias=bias,
)
if weights is not None:
self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"))
self.bias = nn.Parameter(weights.get_tensor(f"{prefix}.bias"))
def forward(self, x):
module_device = x.device
@ -755,20 +794,23 @@ class MPTModel(MPTPreTrainedModel):
)
self.wte = TensorParallelEmbedding("transformer.wte", weights)
if not self.alibi:
# self.wpe = torch.nn.Embedding(
# config.max_seq_len, config.d_model, device=config.init_device
# )
raise RuntimeError("no alibi no supported")
self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
self.blocks = nn.ModuleList(
[
MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
for i in range(config.n_layers)
]
)
self.norm_f = nn.LayerNorm.load_no_bias(
prefix="transformer.norm_f", weights=weights, eps=EPS
)
if config.no_bias:
self.norm_f = nn.LayerNorm.load_no_bias(
prefix="transformer.norm_f", weights=weights, eps=EPS
)
else:
self.norm_f = nn.LayerNorm.load(
prefix="transformer.norm_f", weights=weights, eps=EPS
)
self.is_causal = not self.prefix_lm
self._attn_bias_initialized = False
self.attn_bias = None
@ -787,13 +829,15 @@ class MPTModel(MPTPreTrainedModel):
if config.verbose:
warnings.warn(f"Removing bias ({module.bias}) from {module}.")
module.register_parameter("bias", None)
if config.verbose and config.verbose > 2:
print(self)
if hasattr(self.config, "verbose"):
if config.verbose and config.verbose > 2:
print(self)
if "verbose" not in self.config.init_config:
self.config.init_config["verbose"] = self.config.verbose
if self.config.init_config["verbose"] > 1:
init_fn_name = self.config.init_config["name"]
warnings.warn(f"Using {init_fn_name} initialization.")
self.embedding_fraction = config.embedding_fraction
@torch.no_grad()
def _attn_bias(
@ -983,6 +1027,11 @@ class MPTModel(MPTPreTrainedModel):
)
pos_emb = self.wpe(pos)
x = tok_emb + pos_emb
if self.embedding_fraction != 1:
x = (
x * self.embedding_fraction
+ x.detach() * (1 - self.embedding_fraction)
)
(attn_bias, attention_mask) = self._attn_bias(
device=x.device,
dtype=torch.float32,