From 2c4bf882689fda969c84dfc86c4158032de10e1b Mon Sep 17 00:00:00 2001
From: ssmi153 <129111316+ssmi153@users.noreply.github.com>
Date: Wed, 12 Jul 2023 20:17:35 +0800
Subject: [PATCH 1/2] fix(server): Bug fixes for GPTQ_BITS environment variable passthrough (#590)

# What does this PR do?

This fixes a typo and extends the GPTQ_BITS / GPTQ_GROUPSIZE environment
variable fallback through to the second method, which requires the same logic.
Please let me know if there's anything I've misunderstood in this change.
Thanks @Narsil for the original fix. (A usage sketch of the fallback is
appended after this patch.)
---
 server/text_generation_server/utils/weights.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 39f66862..4f300fe7 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -127,8 +127,8 @@ class Weights:
                 try:
                     import os

-                    bits = int(os.getenv("GTPQ_BITS"))
-                    groupsize = int(os.getenv("GTPQ_GROUPSIZE"))
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
                 except Exception:
                     raise e
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
@@ -149,8 +149,17 @@ class Weights:
             scales = self.get_tensor(f"{prefix}.scales")
             g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)

-            bits = self.get_tensor("gptq_bits").item()
-            groupsize = self.get_tensor("gptq_groupsize").item()
+            try:
+                bits = self.get_tensor("gptq_bits").item()
+                groupsize = self.get_tensor("gptq_groupsize").item()
+            except SafetensorError as e:
+                try:
+                    import os
+
+                    bits = int(os.getenv("GPTQ_BITS"))
+                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+                except Exception:
+                    raise e

             weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
         else:
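
For reference, a minimal usage sketch of the fallback this patch completes (not
part of the patch itself): when a quantized checkpoint does not ship the
gptq_bits / gptq_groupsize tensors, the loader falls back to the GPTQ_BITS and
GPTQ_GROUPSIZE environment variables in both code paths. The helper name and
the example values (4 bits, group size 128) below are illustrative assumptions,
not code from the repository.

    import os

    def resolve_gptq_params(get_tensor):
        """Return (bits, groupsize), preferring checkpoint metadata over env vars."""
        try:
            # Preferred source: metadata tensors stored in the checkpoint.
            bits = get_tensor("gptq_bits").item()
            groupsize = get_tensor("gptq_groupsize").item()
        except Exception as original_error:
            try:
                # Fallback source: environment variables, e.g. a deployment
                # launched with GPTQ_BITS=4 GPTQ_GROUPSIZE=128 set.
                bits = int(os.environ["GPTQ_BITS"])
                groupsize = int(os.environ["GPTQ_GROUPSIZE"])
            except (KeyError, ValueError):
                # Neither source is usable: surface the original failure.
                raise original_error
        return bits, groupsize

Setting GPTQ_BITS and GPTQ_GROUPSIZE in the server environment therefore covers
checkpoints that were quantized without writing those metadata tensors.
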
From 67347950b7518efeb64c7f99ee360af685b53934 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 12 Jul 2023 16:43:31 +0200
Subject: [PATCH 2/2] feat(server): Implements sharding for non divisible `vocab_size`. (#583)

- The code is relatively easy (just disable the checks on Embedding and Head).

This cannot be done in the same easy fashion for hidden_dim/head_dim: it is
relatively easy on some models (classic MHA), but it would make the other
models (MQA) much more complex, and GPTQ quantization would add another quite
hairy piece of code. (An illustrative sketch of the sharding decision is
appended after the diff.)
---
 server/text_generation_server/utils/layers.py | 23 +++++++++++++++----
 .../text_generation_server/utils/weights.py   | 17 ++++++++++----
 2 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 8e0362b8..4f65446e 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -174,13 +174,25 @@ class SuperLayer(nn.Module):


 class TensorParallelHead(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
         self.process_group = process_group
+        self.should_gather = should_gather

     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False

         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
@@ -190,13 +202,14 @@ class TensorParallelHead(SuperLayer):
         return TensorParallelHead(
             get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
+            should_gather=should_gather,
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        world_size = self.process_group.size()
-        if world_size == 1:
+        if not self.should_gather:
             return super().forward(input)

+        world_size = self.process_group.size()
         if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]

@@ -277,7 +290,7 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
         num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

         process_group = weights.process_group
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 4f300fe7..afcbb9c3 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -69,7 +69,7 @@ class Weights:
             tensor = tensor.to(device=self.device)
         return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
@@ -81,10 +81,6 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size

-        assert (
-            size % world_size == 0
-        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-
         if dim == 0:
             tensor = slice_[start:stop]
         elif dim == 1:
@@ -98,6 +94,17 @@ class Weights:
             tensor = tensor.to(device=self.device)
         return tensor

+    def get_sharded(self, tensor_name: str, dim: int):
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        world_size = self.process_group.size()
+        size = slice_.get_shape()[dim]
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        return self.get_partial_sharded(tensor_name, dim)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
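
An illustrative sketch (not part of the patch) of the sharding decision the
second commit introduces for the lm-head weight of shape [vocab_size, hidden]:
when the vocab size divides evenly across the tensor-parallel shards, each rank
keeps a contiguous block of rows and the partial logits are gathered in
forward(); otherwise every rank simply loads the full matrix and skips the
gather. The function below and its example numbers are assumptions for
illustration only.

    def head_shard_plan(vocab_size: int, world_size: int, rank: int):
        """Return ((row_start, row_stop), should_gather) for the lm-head on `rank`."""
        if world_size > 1 and vocab_size % world_size == 0:
            block = vocab_size // world_size
            # Evenly divisible: this rank holds rows [rank * block, (rank + 1) * block)
            # and gathers the partial logits after the matmul.
            return (rank * block, (rank + 1) * block), True
        # Not divisible (or a single shard): load the whole weight and skip the
        # gather, mirroring the `except AssertionError` branch in TensorParallelHead.load.
        return (0, vocab_size), False

    # A 32000-token vocab cannot be split over 3 shards, so each rank loads all
    # rows; over 4 shards each rank keeps an 8000-row block and gathers.
    assert head_shard_plan(32000, 3, 0) == ((0, 32000), False)
    assert head_shard_plan(32000, 4, 1) == ((8000, 16000), True)

This is also why forward() now branches on self.should_gather instead of on
world_size alone.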