From 6af0b94046ccd9b40cf568d18e0bd5edb56cc176 Mon Sep 17 00:00:00 2001
From: Pragaash
Date: Sat, 20 Jan 2024 15:21:43 -0800
Subject: [PATCH] Enable padding before sharding in TP embedding for
 non-divisible embedding tables.

---
 server/text_generation_server/utils/layers.py |  9 ++--
 .../text_generation_server/utils/weights.py   | 53 +++++++++++++++++++
 2 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index d4fa2559..246e34fd 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -499,14 +499,17 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
-        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
-
         process_group = weights.process_group
 
         world_size = process_group.size()
         rank = process_group.rank()
 
+        weight, margin = weights.get_padded_sharded(
+            f"{prefix}.weight", dim=0, pad_multiple=world_size
+        )
+
+        num_embeddings = weights.get_shape(f"{prefix}.weight")[0] + margin
+
         block_size = num_embeddings // world_size
         self.min_id = rank * block_size
         self.max_id = min(num_embeddings, (rank + 1) * block_size)
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index c4e82a6d..ab39c547 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -109,6 +109,59 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
+    def get_padded_sharded(self, tensor_name: str, dim: int, pad_multiple: Optional[int] = None):
+
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+
+        expanded_size = initial_size = slice_.get_shape()[dim]
+        pad_margin = 0
+
+        # Pad the tensor at the given `dim` prior to sharding across `world_size`.
+        if pad_multiple is not None and pad_multiple > 0:
+
+            expanded_size = ((initial_size + pad_multiple - 1) // pad_multiple) * pad_multiple
+            pad_margin = expanded_size - initial_size
+
+        block_size = expanded_size // world_size
+
+        # Prevent excessive padding leading to suboptimal sharding.
+        if pad_margin >= block_size:
+
+            raise ValueError(
+                f"The chosen pad multiple of {pad_multiple} requires padding that meets or "
+                f"exceeds the block size when sharding on {world_size} shards."
+            )
+
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+
+        if dim == 0:
+            tensor = slice_[start:stop]
+        elif dim == 1:
+            tensor = slice_[:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+
+        # Special case for gptq which shouldn't convert
+        # u4 which are disguised as int32
+        if tensor.dtype != torch.int32:
+            tensor = tensor.to(dtype=self.dtype)
+
+        # Only the last shard comes back short of `block_size`; pad it up to size.
+        if pad_margin > 0 and rank == world_size - 1:
+
+            pad_direction = (0, 0, 0, pad_margin) if dim == 0 else (0, pad_margin)
+            tensor = torch.nn.functional.pad(tensor, pad_direction)
+
+        tensor = tensor.to(device=self.device)
+
+        return tensor, pad_margin
+
     def get_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
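
For illustration only, outside the patch itself: a minimal standalone sketch of the round-up-and-pad arithmetic that get_padded_sharded performs before sharding. The vocabulary size, hidden size, and world size below are arbitrary example values, and the loop over ranks stands in for the per-rank slicing done by the Weights helper.

import torch
import torch.nn.functional as F

vocab_size, hidden, world_size = 10, 4, 3
weight = torch.arange(vocab_size * hidden, dtype=torch.float32).reshape(vocab_size, hidden)

# Round the sharded dimension up to the next multiple of world_size.
expanded_size = ((vocab_size + world_size - 1) // world_size) * world_size  # 12
pad_margin = expanded_size - vocab_size                                     # 2
block_size = expanded_size // world_size                                    # 4

shards = []
for rank in range(world_size):
    start, stop = rank * block_size, (rank + 1) * block_size
    shard = weight[start:stop]
    # Only the last rank's slice comes back short; zero-pad it up to block_size.
    missing = block_size - shard.shape[0]
    if missing > 0:
        shard = F.pad(shard, (0, 0, 0, missing))
    shards.append(shard)

# Every rank now holds an equally sized block, as the embedding layer expects.
assert all(s.shape == (block_size, hidden) for s in shards)

Padding the table up to a multiple of the world size keeps every rank's shard the same shape, which is what TensorParallelEmbedding's block-size bookkeeping assumes for non-divisible tables.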