Fixing sealion support.

Nicolas Patry 2024-01-26 11:04:18 +00:00
parent 10fce5bffd
commit 29a4baea59


@@ -348,15 +348,12 @@ class MultiheadAttention(nn.Module):
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
         if self.qk_ln:
-            if weights.process_group.size() > 1:
-                raise NotImplementedError("qk_ln is not supported for number of shards > 1")
             bias = not config.no_bias
             hidden_size = config.d_model
             head_dim = hidden_size // self.n_heads
-            norm_class = LPLayerNorm
-            self.q_ln = norm_class(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
-            self.k_ln = norm_class(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
+            self.q_ln = LPLayerNorm(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
+            self.k_ln = LPLayerNorm(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
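Dropping the "shards > 1" guard works because the q_ln/k_ln affine parameters are 1-D over the hidden dimension and can be split per rank the same way as the fused Wqkv projection output. A minimal sketch of the assumed dim-0 split (shard_dim0 is an illustrative helper, not part of the repository):

import torch

def shard_dim0(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Return this rank's contiguous block of a tensor split along dim 0,
    # mirroring what a get_sharded(..., dim=0)-style loader is assumed to do.
    block = tensor.shape[0] // world_size
    return tensor[rank * block : (rank + 1) * block]

hidden_size, world_size = 4096, 2
q_ln_weight = torch.ones(hidden_size)
shards = [shard_dim0(q_ln_weight, r, world_size) for r in range(world_size)]
assert all(s.shape[0] == hidden_size // world_size for s in shards)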
@@ -650,6 +647,7 @@ class MPTBlock(nn.Module):
 def _cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
         if tensor.device.type == "cuda":
             dtype = torch.get_autocast_gpu_dtype()
         elif tensor.device.type == "cpu":
@@ -657,6 +655,7 @@ def _cast_if_autocast_enabled(tensor):
         else:
             raise NotImplementedError()
         return tensor.to(dtype=dtype)
+    return tensor


 class LPLayerNorm(torch.nn.LayerNorm):
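With the new guard and fall-through in place, the helper is assumed to read roughly as follows end to end (the CPU autocast dtype line sits between the two hunks and is not shown above, so it is an assumption here):

import torch

def _cast_if_autocast_enabled(tensor):
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()  # assumed; not visible in the diff
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    # Fall-through added by this commit: no autocast, no cast.
    return tensor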
@@ -680,8 +679,11 @@ class LPLayerNorm(torch.nn.LayerNorm):
             bias=bias,
         )
         if weights is not None:
-            self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"))
-            self.bias = nn.Parameter(weights.get_tensor(f"{prefix}.bias"))
+            self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0))
+            if bias:
+                self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
+            self.normalized_shape = self.weight.shape

     def forward(self, x):
         module_device = x.device
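Resetting normalized_shape matters because LayerNorm's forward normalizes over the trailing dimensions named by that attribute and checks it against the weight shape; once the weight is a per-rank slice, the attribute has to shrink with it. A small sketch with plain nn.LayerNorm standing in for LPLayerNorm (sizes are illustrative):

import torch
from torch import nn

full, shards = 8, 2
ln = nn.LayerNorm(full)
# Pretend a sharded load: each rank keeps full // shards of the affine params.
ln.weight = nn.Parameter(torch.ones(full // shards))
ln.bias = nn.Parameter(torch.zeros(full // shards))
ln.normalized_shape = ln.weight.shape  # mirrors the line added in the diff
x = torch.randn(3, full // shards)
out = ln(x)  # accepted because normalized_shape matches the sharded weight
assert out.shape == x.shape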
@@ -837,7 +839,6 @@ class MPTModel(MPTPreTrainedModel):
         if self.config.init_config["verbose"] > 1:
             init_fn_name = self.config.init_config["name"]
             warnings.warn(f"Using {init_fn_name} initialization.")
-        self.embedding_fraction = config.embedding_fraction

     @torch.no_grad()
     def _attn_bias(
@@ -1027,11 +1028,6 @@ class MPTModel(MPTPreTrainedModel):
                 )
                 pos_emb = self.wpe(pos)
                 x = tok_emb + pos_emb
-        if self.embedding_fraction != 1:
-            x = (
-                x * self.embedding_fraction
-                + x.detach() * (1 - self.embedding_fraction)
-            )
         (attn_bias, attention_mask) = self._attn_bias(
             device=x.device,
             dtype=torch.float32,
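For reference, the deleted embedding_fraction lines kept the forward activations unchanged while scaling the gradient that reaches the embedding output (a gradient-shrink trick: the detached term contributes nothing to the backward pass). A minimal sketch of that identity, with illustrative values:

import torch

f = 0.1
x = torch.randn(4, requires_grad=True)
y = x * f + x.detach() * (1 - f)
assert torch.allclose(y, x)  # forward value is unchanged
y.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, f))  # gradient scaled by f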