Just don't shard LMHead if not divisible.

Nicolas Patry 2023-07-12 09:03:16 +00:00
parent 2e76727910
commit 63f03b4b7d


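The change below boils down to one decision made at load time: shard the LM head across ranks only when that split is exact, and otherwise load the full weight on every rank and skip the gather in forward(). A minimal standalone sketch of that decision (the helper name is hypothetical, not part of the commit; it assumes the AssertionError raised by get_sharded corresponds to a vocab size that does not divide evenly across shards):

    # Hypothetical helper, only illustrating the divisibility rule that the
    # try/except in load() enforces implicitly.
    def plan_lm_head_load(vocab_size: int, world_size: int) -> tuple[bool, int]:
        """Return (should_gather, rows held per rank)."""
        if world_size > 1 and vocab_size % world_size == 0:
            # Even split: each rank keeps a vocab shard and the partial
            # logits are gathered in forward().
            return True, vocab_size // world_size
        # Uneven split (or a single rank): load the entire head on every
        # rank and skip the gather.
        return False, vocab_size

    print(plan_lm_head_load(32_000, 4))  # (True, 8000)   -> sharded head
    print(plan_lm_head_load(32_000, 3))  # (False, 32000) -> full head everywhere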
@@ -174,13 +174,21 @@ class SuperLayer(nn.Module):


 class TensorParallelHead(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
         self.process_group = process_group
+        self.should_gather = should_gather

     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        try:
+            weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+            should_gather = True
+        except AssertionError:
+            # If the vocab size is not divisible by number of shards
+            # just load the entire thing.
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False

         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
@@ -190,13 +198,14 @@ class TensorParallelHead(SuperLayer):
         return TensorParallelHead(
             get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
+            should_gather=should_gather,
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        world_size = self.process_group.size()
-        if world_size == 1:
+        if not self.should_gather:
             return super().forward(input)

+        world_size = self.process_group.size()
         if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]
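For context on why should_gather matters at all, here is a minimal, non-distributed sketch (shapes invented for the example; torch.cat stands in for the all-gather the real forward() performs over the process group): a sharded head produces partial logits that must be concatenated along the vocabulary dimension, while a fully loaded head already yields complete logits locally.

    import torch

    hidden, vocab, world_size = 16, 12, 3
    x = torch.randn(2, hidden)              # a batch of hidden states
    full_head = torch.randn(vocab, hidden)  # the whole LM head weight

    # should_gather == False: every rank holds full_head, so the local
    # matmul already covers the full vocabulary.
    logits_local = x @ full_head.T          # shape (2, 12)

    # should_gather == True: each rank holds vocab // world_size rows;
    # per-rank partial logits are combined (simulated here with torch.cat).
    shards = full_head.chunk(world_size, dim=0)
    partial_logits = [x @ w.T for w in shards]          # each (2, 4)
    logits_gathered = torch.cat(partial_logits, dim=1)  # (2, 12)

    assert torch.allclose(logits_local, logits_gathered, atol=1e-5)

Under this view, the commit's fallback simply routes a non-divisible vocabulary onto the first path, which is exactly the plain SuperLayer.forward() call taken when should_gather is False.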