Merge branch 'huggingface:main' into main

ssmi153 2023-07-12 23:06:19 +08:00 committed by GitHub
commit 073c1a884d
2 changed files with 30 additions and 10 deletions

View File

@@ -174,13 +174,25 @@ class SuperLayer(nn.Module):
 class TensorParallelHead(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
         self.process_group = process_group
+        self.should_gather = should_gather

     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False

         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
@@ -190,13 +202,14 @@ class TensorParallelHead(SuperLayer):
         return TensorParallelHead(
             get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
+            should_gather=should_gather,
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        world_size = self.process_group.size()
-        if world_size == 1:
+        if not self.should_gather:
             return super().forward(input)

+        world_size = self.process_group.size()
         if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]
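With the flag in place, forward() no longer re-derives the decision from world_size: a single rank and an unsharded fallback head both take the plain linear path, and only genuinely sharded heads pay for the cross-rank gather. Conceptually, the gather path reassembles per-rank logit slices along the vocab dimension, roughly as sketched below (this assumes an initialized torch.distributed group and omits the FastLinear fast path that the real forward keeps):

    import torch
    import torch.distributed as dist

    def head_forward(hidden: torch.Tensor, weight_shard: torch.Tensor,
                     should_gather: bool, world_size: int) -> torch.Tensor:
        # Local projection against this rank's slice of the lm_head weight.
        local_logits = torch.nn.functional.linear(hidden, weight_shard)
        if not should_gather:
            # Single rank, or the full head was loaded as a fallback: done.
            return local_logits
        # Each rank holds logits for its own vocab slice; gather all slices
        # and concatenate them along the vocab dimension.
        pieces = [torch.empty_like(local_logits) for _ in range(world_size)]
        dist.all_gather(pieces, local_logits)
        return torch.cat(pieces, dim=-1)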
@@ -277,7 +290,7 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
         num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
         process_group = weights.process_group
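TensorParallelEmbedding now calls get_partial_sharded, i.e. it keeps the old non-asserting slicing behaviour while get_sharded becomes strict (see the second file below). A tensor-parallel embedding can live with such a plain slice because ids outside a rank's row range can be redirected to a zero row and the per-rank lookups summed with an all-reduce; a generic sketch of that pattern under assumed names (not the implementation in this file, whose body lies outside this hunk):

    import torch
    import torch.distributed as dist

    def sharded_embedding_lookup(input_ids: torch.Tensor, shard: torch.Tensor,
                                 start_id: int, reduce: bool = True) -> torch.Tensor:
        # This rank owns embedding rows [start_id, start_id + shard.shape[0]).
        stop_id = start_id + shard.shape[0]
        in_range = (input_ids >= start_id) & (input_ids < stop_id)
        # Append one zero row to act as the "not owned here" bucket.
        padded = torch.cat([shard, torch.zeros(1, shard.shape[1], dtype=shard.dtype)])
        local_ids = torch.where(in_range, input_ids - start_id,
                                torch.full_like(input_ids, shard.shape[0]))
        out = torch.nn.functional.embedding(local_ids, padded)
        if reduce and dist.is_initialized():
            # Zero rows on the other ranks make the sum equal the full lookup.
            dist.all_reduce(out)
        return out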

View File

@@ -69,7 +69,7 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
@@ -81,10 +81,6 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size

-        assert (
-            size % world_size == 0
-        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-
         if dim == 0:
             tensor = slice_[start:stop]
         elif dim == 1:
@@ -98,6 +94,17 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor

+    def get_sharded(self, tensor_name: str, dim: int):
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        world_size = self.process_group.size()
+        size = slice_.get_shape()[dim]
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        return self.get_partial_sharded(tensor_name, dim)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
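After this change the two methods split the old contract: get_sharded still asserts that the dimension divides evenly and then delegates, while get_partial_sharded slices a floor-sized block with no check. That AssertionError is exactly what TensorParallelHead.load catches in the first file to fall back to the full tensor. A toy illustration of the two contracts (plain functions, not the Weights class), using a 32000-token vocab on 3 shards:

    def partial_shard_rows(size: int, world_size: int, rank: int) -> range:
        # get_partial_sharded's behaviour: a floor-sized block, no divisibility check.
        block = size // world_size
        return range(rank * block, (rank + 1) * block)

    def shard_rows(size: int, world_size: int, rank: int) -> range:
        # get_sharded's behaviour: the same block, but only when the split is exact.
        assert size % world_size == 0, (
            f"size {size} is not divisible into {world_size} shards"
        )
        return partial_shard_rows(size, world_size, rank)

    # 32000 rows over 3 shards: the strict variant raises, the partial variant
    # silently covers only 3 * 10666 = 31998 rows, which is why the lm_head
    # prefers loading the full tensor over losing the last 2 rows.
    try:
        shard_rows(32000, 3, 0)
    except AssertionError:
        assert len(partial_shard_rows(32000, 3, 0)) == 10666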