From d14eaacacab9ca3056a9d001d0ca2dc0a36edfde Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Tue, 4 Jun 2024 19:37:49 +0200
Subject: [PATCH] Support GPTQ models with column-packed up/gate tensor (#2006)

# What does this PR do?

The GPTQ code path for column-packed tensors assumed that such a tensor is always a QKV matrix. However, models (e.g. Phi-3) can also have column-packed MLP up/gate matrices.

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 .../text_generation_server/utils/weights.py | 28 +++++++++++--------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 2dfd80bf..71d67d82 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -121,24 +121,30 @@ class Weights:
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)
 
-    def _get_qweight(self, name: str):
+    def _get_qweight(self, name: str, blocks: int):
         slice_ = self._get_slice(name)
         total_size = slice_.get_shape()[1]
-        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
-        single_size = total_size // 3
+        assert (
+            total_size % blocks == 0
+        ), f"Prepacked quantized matrix is not divisible by {blocks}"
+        single_size = total_size // blocks
         world_size = self.process_group.size()
         rank = self.process_group.rank()
 
         assert (
             single_size % world_size == 0
-        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
+        ), f"Prepacked quantized matrix cannot be sharded across {world_size} shards"
         block_size = single_size // world_size
         start = rank * block_size
         stop = (rank + 1) * block_size
-        q = slice_[:, start:stop]
-        k = slice_[:, start + single_size : stop + single_size]
-        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
-        weight = torch.cat([q, k, v], dim=1)
+
+        weights = []
+        for block in range(blocks):
+            weights.append(
+                slice_[:, start + block * single_size : stop + block * single_size]
+            )
+
+        weight = torch.cat(weights, dim=1)
         weight = weight.to(device=self.device)
         return weight
 
@@ -157,7 +163,7 @@ class Weights:
         from text_generation_server.layers.gptq import GPTQWeight
 
         try:
-            qweight = self._get_qweight(f"{prefix}.qweight")
+            qweight = self._get_qweight(f"{prefix}.qweight", blocks)
         except RuntimeError:
             raise RuntimeError(
                 f"Cannot load `{quantize}` weight, make sure the model is already quantized."
@@ -165,8 +171,8 @@ class Weights:
         bits, groupsize, _, quant_method = self._get_gptq_params()
 
-        qzeros = self._get_qweight(f"{prefix}.qzeros")
-        scales = self._get_qweight(f"{prefix}.scales")
+        qzeros = self._get_qweight(f"{prefix}.qzeros", blocks)
+        scales = self._get_qweight(f"{prefix}.scales", blocks)
         scales = scales.to(dtype=self.dtype)
 
         if quantize == "gptq" and quant_method == "gptq":
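
For illustration, here is a minimal standalone sketch of the generalized column slicing that `_get_qweight` performs after this change. The `shard_col_packed` helper and the plain-tensor setup are hypothetical, for demonstration only; the real code operates on lazy safetensors slices via `self._get_slice` rather than materialized tensors.

```python
import torch


def shard_col_packed(
    packed: torch.Tensor, blocks: int, rank: int, world_size: int
) -> torch.Tensor:
    """Take this rank's column shard from each of `blocks` packed sub-matrices."""
    total_size = packed.shape[1]
    assert total_size % blocks == 0, f"Packed width not divisible by {blocks}"
    single_size = total_size // blocks
    assert single_size % world_size == 0, f"Cannot shard across {world_size} shards"
    block_size = single_size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    # Slice the same [start, stop) column range out of every packed block
    # (q/k/v for attention, gate/up for the MLP) and re-pack the shards.
    shards = [
        packed[:, start + block * single_size : stop + block * single_size]
        for block in range(blocks)
    ]
    return torch.cat(shards, dim=1)


# A [gate|up] matrix packed along dim 1 (blocks=2), sharded over 2 ranks:
packed = torch.arange(16).reshape(2, 8)  # columns 0-3 = gate, 4-7 = up
shard0 = shard_col_packed(packed, blocks=2, rank=0, world_size=2)
assert shard0.shape == (2, 4)  # gate columns 0-1 followed by up columns 4-5
```

Passing `blocks` as a parameter instead of hardcoding `3` lets the same sharding math serve both the QKV case (`blocks=3`) and the up/gate case (`blocks=2`).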