Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Removing commented things (raising proper errors instead).
parent b591527a6c
commit ed0c5bd1ed
@@ -323,12 +323,8 @@ class MultiheadAttention(nn.Module):
         self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
-        # fuse_splits = (d_model, 2 * d_model)
-        # self.Wqkv._fused = (0, fuse_splits)
-        # if self.qk_ln:
-        #     layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
-        #     self.q_ln = layernorm_class(self.d_model, device=device)
-        #     self.k_ln = layernorm_class(self.d_model, device=device)
+        if self.qk_ln:
+            raise NotImplementedError("qk_ln is not supported")
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
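
This hunk shows the fail-fast pattern the commit message describes: the commented-out qk_ln branch is deleted and the constructor rejects the option outright. A minimal sketch of that pattern, using an illustrative stub class rather than the real module:

import torch.nn as nn

class AttentionStub(nn.Module):
    # Illustrative stand-in for MultiheadAttention: an unsupported config
    # option is rejected at construction time instead of being carried
    # around as dead commented-out code.
    def __init__(self, qk_ln: bool = False):
        super().__init__()
        if qk_ln:
            raise NotImplementedError("qk_ln is not supported")

AttentionStub(qk_ln=False)   # constructs fine
# AttentionStub(qk_ln=True)  # would raise NotImplementedError at load time
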
@@ -406,11 +402,8 @@ class MultiQueryAttention(nn.Module):
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
         )
         fuse_splits = (d_model, d_model + self.head_dim)
-        # self.Wqkv._fused = (0, fuse_splits)
-        # if self.qk_ln:
-        #     layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
-        #     self.q_ln = layernorm_class(d_model, device=device)
-        #     self.k_ln = layernorm_class(self.head_dim, device=device)
+        if self.qk_ln:
+            raise NotImplementedError("qk_ln not supported")
         if self.attn_impl == "flash":
             self.attn_fn = flash_attn_fn
         elif self.attn_impl == "triton":
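
The surviving fuse_splits line records the split points of the fused Wqkv output for multi-query attention: d_model query channels, then a single head_dim of keys and one of values. A sketch of what those split points imply for the projection output; the shapes below are inferred from the split points in the diff, not taken from the real module:

import torch

d_model, n_heads = 8, 4
head_dim = d_model // n_heads

# Output of a hypothetical fused Wqkv projection for multi-query attention:
# all query heads share one key head and one value head.
fused = torch.randn(2, d_model + 2 * head_dim)

# fuse_splits = (d_model, d_model + head_dim) are boundaries; torch.split
# takes section sizes, so convert accordingly.
q, k, v = torch.split(fused, [d_model, head_dim, head_dim], dim=-1)
assert q.shape[-1] == d_model and k.shape[-1] == head_dim and v.shape[-1] == head_dim
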
@@ -577,29 +570,11 @@ class MPTBlock(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.prefix = prefix
-        # norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
-        # attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
         if config.attn_config["attn_type"] != "multihead_attention":
             raise NotImplementedError(
                 f"""Not implemented attn {config.attn_config["attn_type"]}"""
             )
         resid_pdrop = config.resid_pdrop
-        # self.norm_1 = norm_class(d_model, device=device)
-        # self.attn = attn_class(
-        #     attn_impl=attn_config["attn_impl"],
-        #     clip_qkv=attn_config["clip_qkv"],
-        #     qk_ln=attn_config["qk_ln"],
-        #     softmax_scale=attn_config["softmax_scale"],
-        #     attn_pdrop=attn_config["attn_pdrop"],
-        #     d_model=d_model,
-        #     n_heads=n_heads,
-        #     verbose=verbose,
-        #     device=device,
-        # )
-        # self.norm_2 = norm_class(d_model, device=device)
-        # self.ffn = MPTMLP(
-        #     d_model=d_model, expansion_ratio=expansion_ratio, device=device
-        # )
         self.norm_1 = nn.LayerNorm.load_no_bias(
             prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
         )
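
MPTBlock now validates attn_type up front and builds its norms through a bias-free LayerNorm loader. The loader's internals are not shown in this diff; what follows is a rough sketch of what a load_no_bias-style helper could do, under the assumption of a weights handle exposing get_tensor (every name below is illustrative, not TGI's actual code):

import torch
import torch.nn as nn

def layer_norm_load_no_bias(prefix: str, weights, eps: float) -> nn.LayerNorm:
    # Assumption: the checkpoint stores only a weight tensor, because MPT
    # models trained with no_bias ship no bias parameters.
    weight = weights.get_tensor(f"{prefix}.weight")
    ln = nn.LayerNorm(weight.shape[0], eps=eps)
    with torch.no_grad():
        ln.weight.copy_(weight)
        ln.bias.zero_()  # bias fixed at zero rather than loaded
    return ln
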
@@ -65,10 +65,6 @@ class MPTSharded(CausalLM):
             config = json.load(f)
         config = PretrainedConfig(**config)
         config.quantize = quantize
-        # config = AutoConfig.from_pretrained(
-        #     # model_id, revision=revision, trust_remote_code=trust_remote_code
-        #     model_id, revision=revision, trust_remote_code=False
-        # )
 
         torch.distributed.barrier(group=self.process_group)
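
Here the sharded loader keeps the local-file path: config.json is read from disk and turned into a bare PretrainedConfig, and the AutoConfig.from_pretrained route is deleted rather than left commented out. A sketch of that surviving path, with the filename and quantize value as assumptions:

import json
from transformers import PretrainedConfig

with open("config.json", "r") as f:
    raw = json.load(f)

# PretrainedConfig accepts arbitrary kwargs, so the raw dict maps straight in.
config = PretrainedConfig(**raw)
config.quantize = None  # TGI attaches its quantize setting as an extra attribute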