From 63f03b4b7d93fe5e875f59159ef2c9a3be60cd04 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 12 Jul 2023 09:03:16 +0000
Subject: [PATCH] Just don't shard LMHead if not divisible.

---
 server/text_generation_server/utils/layers.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 699f4664..ee1a86de 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -174,13 +174,21 @@ class SuperLayer(nn.Module):
 
 
 class TensorParallelHead(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
         self.process_group = process_group
+        self.should_gather = should_gather
 
     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        try:
+            weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+            should_gather = True
+        except AssertionError:
+            # If the vocab size is not divisible by number of shards
+            # just load the entire thing.
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False
 
         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
@@ -190,13 +198,14 @@ class TensorParallelHead(SuperLayer):
         return TensorParallelHead(
             get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
+            should_gather=should_gather,
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        world_size = self.process_group.size()
-        if world_size == 1:
+        if not self.should_gather:
             return super().forward(input)
 
+        world_size = self.process_group.size()
         if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]
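
The idea behind the patch, in isolation: the LM head weight is only row-sharded across ranks when the vocabulary size divides evenly by the world size; otherwise every rank keeps the full weight and the logits all-gather in `forward` is skipped. The sketch below is a minimal, self-contained illustration of that decision, not the repository's code; `load_lm_head` is a hypothetical helper standing in for `weights.get_sharded` / `weights.get_tensor`, and no `torch.distributed` setup is required to run it.

```python
# Minimal sketch (assumption: NOT the TGI implementation) of the fallback
# this patch introduces: shard the LM head only when vocab_size is divisible
# by world_size, otherwise keep the full weight and skip the gather.
import torch


def load_lm_head(weight: torch.Tensor, rank: int, world_size: int):
    """Return (local_weight, should_gather) for a given rank.

    `weight` has shape [vocab_size, hidden_size]. When vocab_size is not
    divisible by world_size, mimic the patch's AssertionError path and
    fall back to the unsharded tensor.
    """
    vocab_size = weight.shape[0]
    if world_size > 1 and vocab_size % world_size == 0:
        shard = vocab_size // world_size
        # Row-shard along dim=0, analogous to weights.get_sharded(..., dim=0).
        return weight[rank * shard : (rank + 1) * shard], True
    # Not divisible (or single rank): keep the entire tensor, no gather needed.
    return weight, False


if __name__ == "__main__":
    hidden, world_size = 16, 4
    for vocab in (32000, 32003):  # divisible vs. not divisible by 4
        full = torch.randn(vocab, hidden)
        local, should_gather = load_lm_head(full, rank=0, world_size=world_size)
        print(vocab, tuple(local.shape), should_gather)
    # 32000 -> (8000, 16) True   (sharded; logits must be gathered later)
    # 32003 -> (32003, 16) False (full head; plain super().forward path)
```

The trade-off this models: an undivisible vocabulary costs some duplicated memory for the head on every rank, but avoids padding the weight or failing at load time, and the forward pass degenerates to an ordinary local matmul.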