diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 8e0362b8..950744f6 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -180,7 +180,7 @@ class TensorParallelHead(SuperLayer):
 
     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
 
         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
@@ -277,7 +277,8 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        # weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
         num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
 
         process_group = weights.process_group
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 39f66862..fc0b95f4 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -69,7 +69,7 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
@@ -81,10 +81,6 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size
 
-        assert (
-            size % world_size == 0
-        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-
         if dim == 0:
             tensor = slice_[start:stop]
         elif dim == 1:
@@ -98,6 +94,17 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
+    def get_sharded(self, tensor_name: str, dim: int):
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        world_size = self.process_group.size()
+        size = slice_.get_shape()[dim]
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        return self.get_partial_sharded(tensor_name, dim)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
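
A minimal standalone sketch of the split this diff introduces, for readers who want the two behaviours side by side: `get_partial_sharded` slices a tensor along a dimension without requiring the size to divide evenly across ranks, while `get_sharded` keeps the old divisibility assertion and simply delegates to the partial variant. The helper names below (`partial_shard`, `shard`) and the floor-division block size are illustrative assumptions, not the library's API.

```python
import torch


def partial_shard(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    """Return this rank's slice along `dim`, with no divisibility check
    (mirrors what get_partial_sharded now allows)."""
    # Assumption: block size is computed by floor division, so any remainder
    # rows/columns fall outside every rank's [start, stop) range.
    block_size = tensor.shape[dim] // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    return tensor.narrow(dim, start, stop - start)


def shard(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    """Checked variant, analogous to the new get_sharded: assert divisibility,
    then delegate to the unchecked partial version."""
    size = tensor.shape[dim]
    assert (
        size % world_size == 0
    ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
    return partial_shard(tensor, dim, rank, world_size)


if __name__ == "__main__":
    # An embedding / LM-head weight whose first dimension does not divide evenly:
    weight = torch.randn(32001, 64)
    piece = partial_shard(weight, dim=0, rank=1, world_size=2)
    print(piece.shape)  # torch.Size([16000, 64])
    # shard(weight, dim=0, rank=1, world_size=2) would raise an AssertionError,
    # which is why TensorParallelHead and TensorParallelEmbedding switch to the
    # partial variant in layers.py.
```

The apparent intent of keeping both entry points is that callers which cannot tolerate an uneven split still fail loudly through `get_sharded`, while the head and embedding layers opt in to the relaxed behaviour.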