Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Enable padding before sharding for TP embeddings with non-divisible embedding tables.

commit 6af0b94046
parent 3ccb3bb0b5
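The change rounds the embedding table's row count up to the next multiple of the tensor-parallel world size before splitting it, so every rank receives an equally sized block and only the last rank carries padding rows. A minimal sketch of that round-up arithmetic, using an assumed vocabulary of 32_001 rows split over 4 shards (illustrative numbers, not taken from the commit):

    # Illustrative numbers only; the commit does this inside Weights.get_padded_sharded.
    initial_size = 32_001   # embedding rows, not divisible by the world size
    world_size = 4          # tensor-parallel shards

    # Round up to the next multiple of world_size (ceil division, as in the commit).
    expanded_size = ((initial_size + world_size - 1) // world_size) * world_size
    pad_margin = expanded_size - initial_size   # rows of padding added to the table
    block_size = expanded_size // world_size    # rows owned by each rank

    print(expanded_size, pad_margin, block_size)  # 32004 3 8001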
@@ -499,14 +499,17 @@ class TensorParallelRowLinear(SuperLayer):
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
-        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
-
         process_group = weights.process_group

         world_size = process_group.size()
         rank = process_group.rank()

+        weight, margin = weights.get_padded_sharded(
+            f"{prefix}.weight", dim=0, pad_multiple=world_size
+        )
+
+        num_embeddings = weights.get_shape(f"{prefix}.weight")[0] + margin
+
         block_size = num_embeddings // world_size
         self.min_id = rank * block_size
         self.max_id = min(num_embeddings, (rank + 1) * block_size)
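With the padded weight, num_embeddings (original rows plus margin) divides evenly by world_size, so each rank's [min_id, max_id) range has the same width and the min(...) clamp no longer truncates the last rank. A short sketch of the resulting per-rank id ranges, continuing the assumed numbers above (not from the commit):

    num_embeddings = 32_004   # 32_001 original rows + 3 rows of padding
    world_size = 4
    block_size = num_embeddings // world_size   # 8001, no remainder left over

    for rank in range(world_size):
        min_id = rank * block_size
        max_id = min(num_embeddings, (rank + 1) * block_size)
        print(rank, min_id, max_id)
    # 0 0 8001 / 1 8001 16002 / 2 16002 24003 / 3 24003 32004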
@@ -109,6 +109,59 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor

+    def get_padded_sharded(self, tensor_name: str, dim: int, pad_multiple: int | None = None):
+
+        filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
+
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+
+        expanded_size = initial_size = slice_.get_shape()[dim]
+        pad_margin = 0
+
+        # Pad the tensor at the given `dim` prior to sharding across `world_size`.
+        if pad_multiple is not None and pad_multiple > 0:
+
+            expanded_size = ((initial_size + pad_multiple - 1) // pad_multiple) * pad_multiple
+            pad_margin = expanded_size - initial_size
+
+        block_size = expanded_size // world_size
+
+        # Prevent excessive padding leading to suboptimal sharding.
+        if pad_margin >= block_size:
+
+            raise ValueError(
+                f"The chosen pad multiple of {pad_multiple} results in a padded tensor that "
+                f"exceeds/fills the block boundary when sharding on {world_size} shards."
+            )
+
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+
+        if dim == 0:
+            tensor = slice_[start:stop]
+        elif dim == 1:
+            tensor = slice_[:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+
+        # Special case for gptq which shouldn't convert
+        # u4 which are disguised as int32
+        if tensor.dtype != torch.int32:
+            tensor = tensor.to(dtype=self.dtype)
+
+        # Padding is applied only to the last sharded block.
+        if pad_margin > 0 and rank == world_size - 1:
+
+            pad_direction = (0, 0, 0, pad_margin) if dim == 0 else (0, pad_margin)
+            tensor = torch.nn.functional.pad(tensor, pad_direction)
+
+        tensor = tensor.to(device=self.device)
+
+        return tensor, pad_margin
+
     def get_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
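For reference, here is a self-contained sketch of the same pad-then-shard behaviour on an in-memory tensor, so the resulting shapes can be checked without the Weights class or safetensors handles. The helper name pad_and_shard and the toy sizes are made up for this illustration; only the arithmetic mirrors get_padded_sharded:

    import torch
    import torch.nn.functional as F

    def pad_and_shard(weight, dim, rank, world_size, pad_multiple):
        # Same round-up arithmetic as get_padded_sharded, applied to a plain tensor.
        initial_size = weight.shape[dim]
        expanded_size = ((initial_size + pad_multiple - 1) // pad_multiple) * pad_multiple
        pad_margin = expanded_size - initial_size
        block_size = expanded_size // world_size

        start = rank * block_size
        stop = (rank + 1) * block_size
        # Slicing clamps `stop` to the real size, so only the last rank comes up short.
        shard = weight[start:stop] if dim == 0 else weight[:, start:stop]

        # Pad the short (last) shard back up to block_size.
        missing = block_size - shard.shape[dim]
        if missing > 0:
            shard = F.pad(shard, (0, 0, 0, missing) if dim == 0 else (0, missing))
        return shard, pad_margin

    # A 10 x 4 toy "embedding table" split over 4 ranks: 10 rows pad up to 12, 3 rows per rank.
    weight = torch.arange(40, dtype=torch.float32).reshape(10, 4)
    for rank in range(4):
        shard, margin = pad_and_shard(weight, dim=0, rank=rank, world_size=4, pad_multiple=4)
        print(rank, shard.shape, margin)  # every rank: torch.Size([3, 4]), margin 2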