Fixing OOM on non-sharded.

Nicolas Patry 2023-07-12 12:46:02 +00:00
parent 6193512c4b
commit f764bc1b52


@@ -181,12 +181,16 @@ class TensorParallelHead(SuperLayer):
     @staticmethod
     def load(config, prefix: str, weights):
-        try:
-            weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-            should_gather = True
-        except AssertionError:
-            # If the vocab size is not divisible by number of shards
-            # just load the entire thing.
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
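
After this change, a single-process (non-sharded) run loads the full lm_head weight and sets should_gather = False, so the forward pass can skip the cross-shard gather; presumably that gather path allocated extra buffers even with a single shard, which is the OOM this commit addresses. Below is a minimal, runnable sketch of the resulting control flow. FakeProcessGroup, FakeWeights, and load_head_weight are illustrative stand-ins, not the real text-generation-inference classes; only the members used in the hunk above are modelled.

import torch


class FakeProcessGroup:
    # Hypothetical stand-in for a torch.distributed process group.
    def __init__(self, world_size: int, rank: int = 0):
        self._world_size = world_size
        self._rank = rank

    def size(self) -> int:
        return self._world_size

    def rank(self) -> int:
        return self._rank


class FakeWeights:
    # Hypothetical stand-in for the weights loader used in the diff.
    def __init__(self, tensors: dict, process_group: FakeProcessGroup):
        self._tensors = tensors
        self.process_group = process_group

    def get_tensor(self, name: str) -> torch.Tensor:
        return self._tensors[name]

    def get_sharded(self, name: str, dim: int) -> torch.Tensor:
        tensor = self._tensors[name]
        world_size = self.process_group.size()
        size = tensor.shape[dim]
        # Refuse to shard a dimension that is not evenly divisible
        # by the number of shards, mirroring the AssertionError path above.
        assert size % world_size == 0, f"{name} not divisible by {world_size}"
        block = size // world_size
        return tensor.narrow(dim, self.process_group.rank() * block, block)


def load_head_weight(weights, prefix: str):
    # Same control flow as TensorParallelHead.load after this commit.
    if weights.process_group.size() > 1:
        try:
            weight = weights.get_sharded(f"{prefix}.weight", dim=0)
            should_gather = True
        except AssertionError:
            # Vocab size not divisible by the number of shards:
            # every rank keeps the full tensor and no gather is needed.
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False
    else:
        # Non-sharded run: load the full tensor and never gather.
        weight = weights.get_tensor(f"{prefix}.weight")
        should_gather = False
    return weight, should_gather


# Non-sharded example: the full weight is returned and gathering is skipped.
weights = FakeWeights(
    {"lm_head.weight": torch.zeros(1000, 64)},
    FakeProcessGroup(world_size=1),
)
weight, should_gather = load_head_weight(weights, "lm_head")
assert weight.shape == (1000, 64) and should_gather is False

The fallback get_tensor path in the except branch is unchanged by the commit; the new outer if only short-circuits the world-size-1 case before any sharding is attempted.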