Removing commented-out code (raising proper errors instead).

Nicolas Patry 2023-07-03 08:42:26 +00:00
parent b591527a6c
commit ed0c5bd1ed
2 changed files with 4 additions and 33 deletions
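
The change replaces dead, commented-out configuration branches with explicit guards, so an unsupported option fails loudly at load time instead of being silently ignored. A minimal sketch of the pattern (class and option names here are illustrative, not the exact TGI code):

import torch.nn as nn

class AttentionSketch(nn.Module):
    """Illustrative only: guard an unported config option up front."""

    def __init__(self, config):
        super().__init__()
        # Options that were never ported get an explicit error instead of
        # lingering as commented-out code.
        self.qk_ln = getattr(config, "qk_ln", False)
        if self.qk_ln:
            raise NotImplementedError("qk_ln is not supported")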

@@ -323,12 +323,8 @@ class MultiheadAttention(nn.Module):
         self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
-        # fuse_splits = (d_model, 2 * d_model)
-        # self.Wqkv._fused = (0, fuse_splits)
-        # if self.qk_ln:
-        #     layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
-        #     self.q_ln = layernorm_class(self.d_model, device=device)
-        #     self.k_ln = layernorm_class(self.d_model, device=device)
+        if self.qk_ln:
+            raise NotImplementedError("qk_ln is not supported")
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
@@ -406,11 +402,8 @@ class MultiQueryAttention(nn.Module):
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
         fuse_splits = (d_model, d_model + self.head_dim)
-        # self.Wqkv._fused = (0, fuse_splits)
-        # if self.qk_ln:
-        #     layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
-        #     self.q_ln = layernorm_class(d_model, device=device)
-        #     self.k_ln = layernorm_class(self.head_dim, device=device)
+        if self.qk_ln:
+            raise NotImplementedError("qk_ln not supported")
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
@@ -577,29 +570,11 @@ class MPTBlock(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.prefix = prefix
-        # norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
-        # attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
         if config.attn_config["attn_type"] != "multihead_attention":
             raise NotImplementedError(
                 f"""Not implemented attn {config.attn_config["attn_type"]}"""
             )
         resid_pdrop = config.resid_pdrop
-        # self.norm_1 = norm_class(d_model, device=device)
-        # self.attn = attn_class(
-        #     attn_impl=attn_config["attn_impl"],
-        #     clip_qkv=attn_config["clip_qkv"],
-        #     qk_ln=attn_config["qk_ln"],
-        #     softmax_scale=attn_config["softmax_scale"],
-        #     attn_pdrop=attn_config["attn_pdrop"],
-        #     d_model=d_model,
-        #     n_heads=n_heads,
-        #     verbose=verbose,
-        #     device=device,
-        # )
-        # self.norm_2 = norm_class(d_model, device=device)
-        # self.ffn = MPTMLP(
-        #     d_model=d_model, expansion_ratio=expansion_ratio, device=device
-        # )
         self.norm_1 = nn.LayerNorm.load_no_bias(
             prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
         )

@@ -65,10 +65,6 @@ class MPTSharded(CausalLM):
             config = json.load(f)
         config = PretrainedConfig(**config)
         config.quantize = quantize
-        # config = AutoConfig.from_pretrained(
-        #     # model_id, revision=revision, trust_remote_code=trust_remote_code
-        #     model_id, revision=revision, trust_remote_code=False
-        # )
         torch.distributed.barrier(group=self.process_group)
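
In the second file, only the local path is kept: the checkpoint's config.json is read and passed straight to PretrainedConfig, and the commented-out AutoConfig.from_pretrained alternative (which would have gone through trust_remote_code) is dropped. A minimal sketch of that pattern, with an illustrative file path and attribute:

import json
from transformers import PretrainedConfig

# Build the config directly from the checkpoint's config.json; PretrainedConfig
# stores unknown keys as attributes, so no remote modeling code is needed.
with open("config.json", "r") as f:
    config_dict = json.load(f)
config = PretrainedConfig(**config_dict)
config.quantize = None  # runtime attribute set by the server, as in the diff context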