feat(server): Implements sharding for non divisible vocab_size. (#583)

- The code is relatively easy (just disable the checks on Embedding and
  Head); the fallback this produces is sketched below.

This cannot be done in the same easy fashion for hidden_dim/head_dim: it
would be relatively easy for some models (classic MHA), but it would make
the other models (MQA) much more complex and turn GPTQ quantization into
another quite hairy piece of code.
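Concretely, the disabled check turns into a load-time fallback: try the strict sharded load, and if the divisibility assertion fires, keep the full tensor on every rank and remember not to gather. A minimal, self-contained sketch of that pattern, with toy tensors and hypothetical helper names standing in for the real Weights methods:

import torch

def try_shard(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Stand-in for the strict sharded load: refuse non-divisible vocab sizes.
    vocab = weight.shape[0]
    assert vocab % world_size == 0, f"{vocab} is not divisible by {world_size}"
    block = vocab // world_size
    return weight[rank * block : (rank + 1) * block]

def load_head_weight(weight: torch.Tensor, rank: int, world_size: int):
    # Mirrors the fallback described above: shard if possible, otherwise keep
    # the full tensor on every rank and skip the gather in forward.
    if world_size > 1:
        try:
            return try_shard(weight, rank, world_size), True    # sharded, gather later
        except AssertionError:
            return weight, False                                # full copy, no gather
    return weight, False

w = torch.randn(32003, 16)                                      # 32003 % 2 != 0
shard, should_gather = load_head_weight(w, rank=0, world_size=2)
print(shard.shape, should_gather)                               # torch.Size([32003, 16]) False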
Nicolas Patry 2023-07-12 16:43:31 +02:00 committed by GitHub
parent 2c4bf88268
commit 67347950b7
2 changed files with 30 additions and 10 deletions

server/text_generation_server/utils/layers.py

@@ -174,13 +174,25 @@ class SuperLayer(nn.Module):


 class TensorParallelHead(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
         self.process_group = process_group
+        self.should_gather = should_gather

     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False

         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
@@ -190,13 +202,14 @@ class TensorParallelHead(SuperLayer):
         return TensorParallelHead(
             get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
+            should_gather=should_gather,
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        world_size = self.process_group.size()
-        if world_size == 1:
+        if not self.should_gather:
             return super().forward(input)

+        world_size = self.process_group.size()
         if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]
@@ -277,7 +290,7 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
         num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
         process_group = weights.process_group
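Skipping the gather is correct once the full weight is replicated: with a vocab-sharded head each rank computes logits for its slice and an all-gather concatenates them along the vocab dimension, while a replicated head produces the same full-vocab logits with a plain local matmul. A single-process sketch under assumed toy shapes, with tensor chunks standing in for the distributed shards:

import torch

torch.manual_seed(0)
world_size, vocab_size, hidden = 4, 32000, 64   # divisible case: the head stays sharded
weight = torch.randn(vocab_size, hidden)        # full lm_head weight [vocab, hidden]
x = torch.randn(3, hidden)                      # a small batch of hidden states

full_logits = x @ weight.T                      # what a replicated head computes locally

shards = weight.chunk(world_size, dim=0)        # one vocab slice per simulated rank
gathered = torch.cat([x @ w.T for w in shards], dim=1)  # what the all-gather reassembles

print(torch.allclose(full_logits, gathered))    # True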

server/text_generation_server/utils/weights.py

@@ -69,7 +69,7 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
@@ -81,10 +81,6 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size

-        assert (
-            size % world_size == 0
-        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-
         if dim == 0:
             tensor = slice_[start:stop]
         elif dim == 1:
@@ -98,6 +94,17 @@ class Weights:
             tensor = tensor.to(device=self.device)
         return tensor

+    def get_sharded(self, tensor_name: str, dim: int):
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        world_size = self.process_group.size()
+        size = slice_.get_shape()[dim]
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        return self.get_partial_sharded(tensor_name, dim)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
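The Weights side of the change splits the old method in two: get_partial_sharded slices without any divisibility requirement, and get_sharded keeps the assertion and delegates to it, so callers such as the head above can catch the AssertionError and fall back. A rough standalone mock of that split (plain tensors instead of safetensors slices; the shard arithmetic here is illustrative, not the exact implementation):

import torch

def get_partial_sharded(weight: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # No check: just return this rank's slice along `dim`.
    block = weight.shape[dim] // world_size
    return weight.narrow(dim, rank * block, block)

def get_sharded(weight: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Same slicing, but refuse sizes that do not divide evenly across shards.
    size = weight.shape[dim]
    assert size % world_size == 0, (
        f"The chosen size {size} is not compatible with sharding on {world_size} shards"
    )
    return get_partial_sharded(weight, dim, rank, world_size)

w = torch.randn(32003, 16)
print(get_partial_sharded(w, 0, rank=0, world_size=4).shape)    # torch.Size([8000, 16])
try:
    get_sharded(w, 0, rank=0, world_size=4)
except AssertionError as e:
    print("refused:", e)                                        # callers like the head catch this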