diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 8e0362b8..4f65446e 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -174,13 +174,25 @@ class SuperLayer(nn.Module):


 class TensorParallelHead(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
         self.process_group = process_group
+        self.should_gather = should_gather

     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False

         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
@@ -190,13 +202,14 @@ class TensorParallelHead(SuperLayer):
         return TensorParallelHead(
             get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
+            should_gather=should_gather,
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        world_size = self.process_group.size()
-        if world_size == 1:
+        if not self.should_gather:
             return super().forward(input)

+        world_size = self.process_group.size()
         if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]

@@ -277,7 +290,7 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
         num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

         process_group = weights.process_group
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 4f300fe7..afcbb9c3 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -69,7 +69,7 @@ class Weights:
             tensor = tensor.to(device=self.device)
         return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
@@ -81,10 +81,6 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size

-        assert (
-            size % world_size == 0
-        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-
         if dim == 0:
             tensor = slice_[start:stop]
         elif dim == 1:
@@ -98,6 +94,17 @@ class Weights:
             tensor = tensor.to(device=self.device)
         return tensor

+    def get_sharded(self, tensor_name: str, dim: int):
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        world_size = self.process_group.size()
+        size = slice_.get_shape()[dim]
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        return self.get_partial_sharded(tensor_name, dim)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
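
For illustration, a minimal standalone sketch of the sharding decision this patch introduces. `plan_head_load` is a hypothetical helper, not part of the repository; the real logic lives in `TensorParallelHead.load` and `Weights.get_sharded` above. The lm_head is sharded only when the vocab size divides evenly across the world size; otherwise every rank loads the full tensor and `TensorParallelHead.forward` skips the all-gather.

# Hypothetical sketch of the decision TensorParallelHead.load now makes;
# not part of the repository.
def plan_head_load(vocab_size: int, world_size: int, rank: int):
    """Return (start, stop, should_gather) for this rank's slice of lm_head."""
    if world_size > 1 and vocab_size % world_size == 0:
        # Evenly divisible: each rank keeps its own block of rows and the
        # partial outputs are all-gathered in forward().
        block_size = vocab_size // world_size
        return rank * block_size, (rank + 1) * block_size, True
    # Not divisible (or a single shard): load the entire weight on every rank
    # and skip the gather, mirroring the AssertionError fallback above.
    return 0, vocab_size, False

# 32000 % 2 == 0 -> sharded load, gather needed.
assert plan_head_load(32000, world_size=2, rank=1) == (16000, 32000, True)
# 32003 % 2 != 0 -> full load, no gather.
assert plan_head_load(32003, world_size=2, rank=0) == (0, 32003, False)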