diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 950744f6..699f4664 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -180,7 +180,7 @@ class TensorParallelHead(SuperLayer): @staticmethod def load(config, prefix: str, weights): - weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0) + weight = weights.get_sharded(f"{prefix}.weight", dim=0) # GPTQ doesn't quantize heads (nor embeddings) if config.quantize == "gptq":