Enabling non-divisible vocab_size.

Nicolas Patry 2023-07-11 12:37:25 +00:00
parent db4efbf4bc
commit 906027ae58
2 changed files with 15 additions and 7 deletions
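For context, a minimal standalone sketch of the idea (hypothetical partial_shard helper, not code from this commit): each rank takes a contiguous size // world_size block along the sharded dimension, so that dimension no longer has to divide evenly across ranks.

import torch

def partial_shard(tensor: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    # Take this rank's contiguous block along `dim`; no requirement that the
    # dimension be divisible by `world_size` (any remainder is left to the caller).
    size = tensor.shape[dim]
    block_size = size // world_size
    start = rank * block_size
    return tensor.narrow(dim, start, block_size)

# Example: a 32003-row embedding matrix on 4 ranks -> four 8000-row blocks,
# with 3 remainder rows not covered by any rank in this simple scheme.
vocab = torch.zeros(32003, 8)
print([partial_shard(vocab, r, 4).shape[0] for r in range(4)])  # [8000, 8000, 8000, 8000]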


@@ -180,7 +180,7 @@ class TensorParallelHead(SuperLayer):
     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
@@ -277,7 +277,8 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        # weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
         num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
         process_group = weights.process_group
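Note on the embedding hunk above: with a partial shard, the local weight's row count no longer matches the full vocabulary size reported by get_shape. One common way a row-sharded embedding handles lookups (sketched below with made-up names, not taken from this diff) is to zero out ids that fall outside the local block and sum the partial results across ranks.

import torch

def sharded_embedding_lookup(input_ids: torch.Tensor, local_weight: torch.Tensor,
                             min_id: int, max_id: int) -> torch.Tensor:
    # `local_weight` holds rows [min_id, max_id) of the full embedding table.
    in_range = (input_ids >= min_id) & (input_ids < max_id)
    local_ids = torch.where(in_range, input_ids - min_id, torch.zeros_like(input_ids))
    out = torch.nn.functional.embedding(local_ids, local_weight)
    out[~in_range] = 0.0  # out-of-range ids contribute nothing on this rank
    # In a real tensor-parallel setup the partial results would then be summed:
    # torch.distributed.all_reduce(out, group=process_group)
    return out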


@@ -69,7 +69,7 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
@@ -81,10 +81,6 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size
-        assert (
-            size % world_size == 0
-        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         if dim == 0:
             tensor = slice_[start:stop]
         elif dim == 1:
@@ -98,6 +94,17 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor

+    def get_sharded(self, tensor_name: str, dim: int):
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        world_size = self.process_group.size()
+        size = slice_.get_shape()[dim]
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        return self.get_partial_sharded(tensor_name, dim)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
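Taken together, the second file's change makes get_sharded a thin wrapper: it keeps the divisibility assert and delegates the actual slicing to get_partial_sharded. A standalone sketch of the same split, using plain functions over sizes rather than the Weights class:

def get_partial_sharded(size: int, rank: int, world_size: int) -> tuple[int, int]:
    # This rank's [start, stop) block; no divisibility requirement.
    block_size = size // world_size
    return rank * block_size, (rank + 1) * block_size

def get_sharded(size: int, rank: int, world_size: int) -> tuple[int, int]:
    # Strict variant: still insists on an even split, then delegates.
    assert size % world_size == 0, (
        f"The chosen size {size} is not compatible with sharding on {world_size} shards"
    )
    return get_partial_sharded(size, rank, world_size)

print(get_partial_sharded(32003, 3, 4))  # (24000, 32000): the last 3 rows fall outside every block
print(get_sharded(32000, 3, 4))          # (24000, 32000): even split, assert passes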