From 3fb8979a6d3c7077be422bdbf998c07105094d24 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 13 Jun 2023 15:37:52 +0000 Subject: [PATCH] Re-enabling dim=dim in TensorParallelColumn because llama. --- server/text_generation_server/utils/layers.py | 4 ++-- server/text_generation_server/utils/weights.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index b866f091..6aed8885 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -221,11 +221,11 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): - weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize) + weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize, dim=dim) if bias: b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] - bias = torch.cat(b, dim=0) + bias = torch.cat(b, dim=dim) else: bias = None linear = get_linear(weight, bias, config.quantize) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index be47b15b..6d0d7c37 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -83,7 +83,7 @@ class Weights: tensor = tensor.to(device=self.device) return tensor - def get_multi_weights_col(self, prefixes: List[str], quantize: str): + def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): if quantize == "gptq": try: qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1) @@ -102,7 +102,7 @@ class Weights: weight = (qweight, qzeros, scales, g_idx, bits, groupsize) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - weight = torch.cat(w, dim=1) + weight = torch.cat(w, dim=dim) return weight def get_multi_weights_row(self, prefix: str, quantize: str):