Fixing OOM on non-sharded.

Nicolas Patry 2023-07-12 12:46:02 +00:00
parent 6193512c4b
commit f764bc1b52


@@ -181,6 +181,7 @@ class TensorParallelHead(SuperLayer):
     @staticmethod
     def load(config, prefix: str, weights):
+        if weights.process_group.size() > 1:
             try:
                 weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                 should_gather = True
@@ -189,6 +190,9 @@ class TensorParallelHead(SuperLayer):
                 # just load the entire thing.
                 weight = weights.get_tensor(f"{prefix}.weight")
                 should_gather = False
+        else:
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False
         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":