Doesn't affect LM_Head.

Nicolas Patry 2023-07-11 12:51:13 +00:00
parent 906027ae58
commit 2e76727910


@@ -180,7 +180,7 @@ class TensorParallelHead(SuperLayer):
     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
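For context, here is a minimal sketch of the distinction the diff relies on, under the assumption that get_sharded enforces that the sharded dimension divides evenly across ranks while get_partial_sharded tolerates a ragged last shard. The function bodies below are illustrative only and are not TGI's actual Weights implementation.

import torch


def get_partial_sharded(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    """Slice this rank's block along `dim`; an uneven trailing block is allowed."""
    size = tensor.shape[dim]
    block = (size + world_size - 1) // world_size  # ceil division
    start, stop = rank * block, min((rank + 1) * block, size)
    return tensor.narrow(dim, start, stop - start)


def get_sharded(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    """Same slice, but require the dimension to divide evenly across all ranks."""
    size = tensor.shape[dim]
    assert size % world_size == 0, f"dim {dim} ({size}) not divisible by {world_size} shards"
    return get_partial_sharded(tensor, dim, rank, world_size)


# Example: a 10-row "vocab" weight split across 3 ranks.
weight = torch.arange(10).unsqueeze(1).float()
print(get_partial_sharded(weight, dim=0, rank=2, world_size=3).shape)  # torch.Size([2, 1])
# get_sharded(weight, dim=0, rank=2, world_size=3) would raise AssertionError.

Under that reading, the commit keeps the strict, evenly divisible shard for TensorParallelHead, which is consistent with the message that the change doesn't affect LM_Head.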