Fixing OOM on non-sharded.

Nicolas Patry 2023-07-12 12:46:02 +00:00
parent 6193512c4b
commit f764bc1b52


@@ -181,12 +181,16 @@ class TensorParallelHead(SuperLayer):
     @staticmethod
     def load(config, prefix: str, weights):
-        try:
-            weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-            should_gather = True
-        except AssertionError:
-            # If the vocab size is not divisible by number of shards
-            # just load the entire thing.
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
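
For context, the apparent issue is that even with a single process the old code took the get_sharded path and set should_gather = True, so the forward pass would later allocate an extra full-vocab buffer for a gather it did not need; the new else branch loads the whole tensor and skips the gather. Below is a minimal sketch of how a should_gather flag like this is typically consumed downstream, assuming a torch.distributed setup (head_forward and its arguments are illustrative, not the repository's exact code):

import torch
import torch.distributed as dist


def head_forward(hidden, weight, process_group, should_gather):
    # Project hidden states against this rank's (possibly sharded) vocab slice.
    local_logits = torch.nn.functional.linear(hidden, weight)
    if not should_gather:
        # Non-sharded path: `weight` already covers the full vocabulary,
        # so we return directly and never allocate the gather buffers below.
        return local_logits
    # Sharded path: collect every rank's slice to rebuild full-vocab logits.
    world_size = process_group.size()
    gathered = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(gathered, local_logits, group=process_group)
    return torch.cat(gathered, dim=-1)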