mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
Merge branch 'huggingface:main' into main
This commit is contained in:
commit
073c1a884d
@ -174,13 +174,25 @@ class SuperLayer(nn.Module):
|
||||
|
||||
|
||||
class TensorParallelHead(SuperLayer):
|
||||
def __init__(self, linear, process_group):
|
||||
def __init__(self, linear, process_group, should_gather: bool):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
self.should_gather = should_gather
|
||||
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
if weights.process_group.size() > 1:
|
||||
try:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
should_gather = True
|
||||
except AssertionError:
|
||||
# If the vocab size is not divisible by number of shards
|
||||
# just load the entire thing.
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
else:
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
|
||||
# GPTQ doesn't quantize heads (nor embeddings)
|
||||
if config.quantize == "gptq":
|
||||
@ -190,13 +202,14 @@ class TensorParallelHead(SuperLayer):
|
||||
return TensorParallelHead(
|
||||
get_linear(weight, bias=None, quantize=quantize),
|
||||
process_group=weights.process_group,
|
||||
should_gather=should_gather,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
world_size = self.process_group.size()
|
||||
if world_size == 1:
|
||||
if not self.should_gather:
|
||||
return super().forward(input)
|
||||
|
||||
world_size = self.process_group.size()
|
||||
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
|
||||
out_dim = self.linear.weight.shape[0]
|
||||
|
||||
@ -277,7 +290,7 @@ class TensorParallelRowLinear(SuperLayer):
|
||||
class TensorParallelEmbedding(nn.Module):
|
||||
def __init__(self, prefix: str, weights, reduce=True):
|
||||
super().__init__()
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
|
||||
num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
|
||||
|
||||
process_group = weights.process_group
|
||||
|
@ -69,7 +69,7 @@ class Weights:
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_sharded(self, tensor_name: str, dim: int):
|
||||
def get_partial_sharded(self, tensor_name: str, dim: int):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
@ -81,10 +81,6 @@ class Weights:
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
|
||||
assert (
|
||||
size % world_size == 0
|
||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||
|
||||
if dim == 0:
|
||||
tensor = slice_[start:stop]
|
||||
elif dim == 1:
|
||||
@ -98,6 +94,17 @@ class Weights:
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_sharded(self, tensor_name: str, dim: int):
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
world_size = self.process_group.size()
|
||||
size = slice_.get_shape()[dim]
|
||||
assert (
|
||||
size % world_size == 0
|
||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||
return self.get_partial_sharded(tensor_name, dim)
|
||||
|
||||
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
|
||||
if quantize == "gptq":
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user