Fix GQA llama + AWQ

Nicolas Patry 2023-09-26 06:26:23 +00:00
parent c5de7cd886
commit 1ab173a260

@@ -179,7 +179,7 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize != "gptq":
+    if config.quantize not in ["gptq", "awq"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
         head_size = config.hidden_size // config.num_attention_heads
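For context on why the dtype/device cast is skipped for both GPTQ and AWQ: with these quantization schemes the loaded weight is not a plain floating-point tensor but packed integer data (plus scales/zeros), so casting it to the model dtype would corrupt it. A minimal sketch of the guarded cast, using a hypothetical helper name and a simplified weight object rather than the repository's actual code:

def _maybe_cast_weight(weight, config, weights):
    # GPTQ/AWQ weights arrive as packed integer tensors (qweight/qzeros/scales),
    # so they must not be cast to the floating-point model dtype.
    if config.quantize not in ["gptq", "awq"]:
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
    return weight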