Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)

Commit 073c1a884d: Merge branch 'huggingface:main' into main
@@ -174,13 +174,25 @@ class SuperLayer(nn.Module):
 
 
 class TensorParallelHead(SuperLayer):
-    def __init__(self, linear, process_group):
+    def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
         self.process_group = process_group
+        self.should_gather = should_gather
 
     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
+            weight = weights.get_tensor(f"{prefix}.weight")
+            should_gather = False
 
         # GPTQ doesn't quantize heads (nor embeddings)
         if config.quantize == "gptq":
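The new load() above is the heart of the fix: when the vocabulary's row count is not divisible by the number of shards, get_sharded raises an AssertionError and the head falls back to loading the full, replicated weight with should_gather=False. A minimal, self-contained sketch of that decision (the shard_rows helper and function names are illustrative, not part of the repository):

import torch

def shard_rows(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Slice this rank's contiguous block of rows; refuse uneven splits,
    # mirroring the AssertionError raised by Weights.get_sharded.
    size = weight.shape[0]
    assert size % world_size == 0, f"{size} rows not divisible by {world_size} shards"
    block = size // world_size
    return weight[rank * block : (rank + 1) * block]

def load_head_weight(weight: torch.Tensor, rank: int, world_size: int):
    # Returns (weight, should_gather), following the branching in load() above.
    if world_size > 1:
        try:
            return shard_rows(weight, rank, world_size), True
        except AssertionError:
            # Uneven vocab: keep the whole tensor on every rank, skip the gather.
            return weight, False
    return weight, False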
@@ -190,13 +202,14 @@ class TensorParallelHead(SuperLayer):
         return TensorParallelHead(
             get_linear(weight, bias=None, quantize=quantize),
             process_group=weights.process_group,
+            should_gather=should_gather,
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        world_size = self.process_group.size()
-        if world_size == 1:
+        if not self.should_gather:
             return super().forward(input)
 
+        world_size = self.process_group.size()
         if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
             out_dim = self.linear.weight.shape[0]
 
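Storing should_gather lets forward() short-circuit to a plain local matmul whenever the head was replicated; world_size is only read on the gather path. When the head is sharded, each rank produces logits for its slice of the vocabulary and the slices must be reassembled. A hedged sketch of what that gather computes (shapes are illustrative; the file's actual fast path for 2-D inputs and FastLinear differs):

import torch
import torch.distributed as dist

def gathered_logits(local_out: torch.Tensor, world_size: int) -> torch.Tensor:
    # local_out: [batch, vocab // world_size] on each rank.
    chunks = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(chunks, local_out)
    # Concatenate along the vocabulary dimension -> [batch, vocab].
    return torch.cat(chunks, dim=-1)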
@@ -277,7 +290,7 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
         num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
 
         process_group = weights.process_group
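TensorParallelEmbedding switches to get_partial_sharded, so an uneven vocabulary no longer aborts loading: each rank simply owns its own block of rows. For intuition, here is a sketch of how a row-sharded lookup can stay correct when a rank holds only rows [min_id, max_id): out-of-range ids are masked, their outputs zeroed, and an all_reduce sum across ranks (the reduce=True path) restores the full result. This is an assumption-laden illustration, not the class's verbatim code:

import torch
import torch.nn.functional as F

def partial_embedding_lookup(weight, input_ids, min_id, max_id):
    # Keep only the ids this rank owns; remap them to local row indices.
    mask = (input_ids >= min_id) & (input_ids < max_id)
    local_ids = torch.where(mask, input_ids - min_id, torch.zeros_like(input_ids))
    out = F.embedding(local_ids, weight)
    # Zero contributions for ids owned by other ranks; summing the per-rank
    # outputs (all_reduce) then yields the complete embedding for every id.
    return out * mask.unsqueeze(-1)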
@@ -69,7 +69,7 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
@@ -81,10 +81,6 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size
 
-        assert (
-            size % world_size == 0
-        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-
         if dim == 0:
             tensor = slice_[start:stop]
         elif dim == 1:
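Lifting the divisibility assert out of the slicer separates two concerns: cutting out a rank's block, and enforcing that the split is exact. For example, a 32000-row LLaMA vocabulary splits evenly across 4 shards (8000 rows each) but not across 3, and only callers that need the strict guarantee should fail in the latter case. A toy contrast of the two behaviors after this change (stub predicate, not the repository's Weights class; the assert message is kept verbatim from the diff):

def strict_shard_ok(size: int, world_size: int) -> None:
    # The assert that moved out of get_partial_sharded.
    assert (
        size % world_size == 0
    ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"

strict_shard_ok(32000, 4)      # fine: 32000 % 4 == 0
try:
    strict_shard_ok(32000, 3)  # uneven: raises AssertionError
except AssertionError:
    pass                       # ...which TensorParallelHead.load() catches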
@@ -98,6 +94,17 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
+    def get_sharded(self, tensor_name: str, dim: int):
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+        world_size = self.process_group.size()
+        size = slice_.get_shape()[dim]
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        return self.get_partial_sharded(tensor_name, dim)
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
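get_sharded is now a thin, strict wrapper: it checks divisibility up front (reading only the tensor's shape through the safetensors slice, not the data) and delegates the actual slicing to get_partial_sharded. This is the AssertionError that TensorParallelHead.load() catches above. A minimal stand-in showing the wrapper relationship (an assumption-level sketch, not the repository's Weights class; the ceil-sized blocking is one plausible scheme, under which the last rank's block may be smaller for uneven sizes):

import torch

class ShardedWeights:
    def __init__(self, tensors: dict, rank: int, world_size: int):
        self.tensors, self.rank, self.world_size = tensors, rank, world_size

    def get_partial_sharded(self, name: str, dim: int) -> torch.Tensor:
        # Permissive: always returns this rank's block, even for uneven sizes.
        t = self.tensors[name]
        size = t.shape[dim]
        block = (size + self.world_size - 1) // self.world_size  # assumed scheme
        start = min(size, self.rank * block)
        stop = min(size, (self.rank + 1) * block)
        return t.narrow(dim, start, stop - start)

    def get_sharded(self, name: str, dim: int) -> torch.Tensor:
        # Strict: assert first, then reuse the permissive slicer.
        size = self.tensors[name].shape[dim]
        assert size % self.world_size == 0, (
            f"size {size} not compatible with {self.world_size} shards"
        )
        return self.get_partial_sharded(name, dim)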