diff --git a/server/tests/utils/test_layers.py b/server/tests/utils/test_layers.py
new file mode 100644
index 00000000..0a9fecd1
--- /dev/null
+++ b/server/tests/utils/test_layers.py
@@ -0,0 +1,64 @@
+import torch
+from text_generation_server.utils.layers import (
+    TensorParallelEmbedding,
+)
+
+class ProcessGroup:
+    def __init__(self, rank: int, world_size: int):
+        self._rank = rank
+        self.world_size = world_size
+
+    def size(self) -> int:
+        return self.world_size
+
+    def rank(self) -> int:
+        return self._rank
+
+class Weights:
+    def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
+        self.weight = torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
+        self.process_group = ProcessGroup(rank, world_size)
+
+
+    def get_partial_sharded(self, name: str, dim: int):
+        assert dim == 0
+
+        rank = self.process_group.rank()
+        world_size = self.process_group.size()
+        size = self.weight.shape[dim]
+
+        block_size = (size + world_size - 1) // world_size  # ceil division for uneven sizes
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        return self.weight[start:stop]
+
+    def get_shape(self, name: str):
+        return self.weight.shape
+
+def test_tensor_parallel_embedding():
+
+    vocab_size = 17
+    weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
+    embeddings = TensorParallelEmbedding("", weights)
+
+    input_ids = torch.arange(vocab_size)
+    output = embeddings.forward(input_ids)
+    assert embeddings.min_id == 0
+    assert embeddings.max_id == 17
+    torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))
+
+    weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
+    weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
+    embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
+    assert embeddings_0_2.min_id == 0
+    assert embeddings_0_2.max_id == 9  # ceil(17 / 2)
+    torch.testing.assert_close(embeddings_0_2.weight, torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0).view(10, 256).float())
+    embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
+    assert embeddings_1_2.min_id == 9
+    assert embeddings_1_2.max_id == 17
+    torch.testing.assert_close(embeddings_1_2.weight, torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0).view(9, 256).float())
+    output_tp_0 = embeddings_0_2.forward(input_ids)
+    output_tp_1 = embeddings_1_2.forward(input_ids)
+
+    torch.testing.assert_close(output, output_tp_0 + output_tp_1)
+
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index d4fa2559..5a0de0d7 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -507,10 +507,10 @@ class TensorParallelEmbedding(nn.Module):
         world_size = process_group.size()
         rank = process_group.rank()
 
-        block_size = num_embeddings // world_size
+        block_size = (num_embeddings + world_size - 1) // world_size
         self.min_id = rank * block_size
         self.max_id = min(num_embeddings, (rank + 1) * block_size)
-        self.null_idx = block_size
+        self.null_idx = weight.shape[0]  # Usually block_size, but can be smaller when vocab_size is not evenly divisible.
         self.process_group = weights.process_group
         self.reduce = reduce
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index c4e82a6d..186733f3 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -92,7 +92,7 @@ class Weights:
         rank = self.process_group.rank()
 
         size = slice_.get_shape()[dim]
-        block_size = size // world_size
+        block_size = (size + world_size - 1) // world_size
         start = rank * block_size
         stop = (rank + 1) * block_size
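
Note (not part of the patch): below is a minimal standalone sketch of the ceil-division sharding behaviour this change introduces, using plain PyTorch only. The helpers `shard_bounds` and `partial_lookup` are illustrative names, not TGI APIs. With vocab_size=17 and world_size=2, rank 0 owns rows [0, 9), rank 1 owns rows [9, 17), and ids outside a rank's range are routed to a trailing zero row, so summing the per-rank partial lookups (what the all-reduce does when reduce=True) recovers the full embedding.

```python
import torch


def shard_bounds(num_embeddings: int, rank: int, world_size: int):
    # Ceil division: every rank gets block_size rows except possibly the last one.
    block_size = (num_embeddings + world_size - 1) // world_size
    start = rank * block_size
    stop = min(num_embeddings, (rank + 1) * block_size)
    return start, stop


def partial_lookup(weight: torch.Tensor, input_ids: torch.Tensor, rank: int, world_size: int):
    start, stop = shard_bounds(weight.shape[0], rank, world_size)
    # This rank's rows, plus one trailing zero row used as the "null" entry.
    shard = torch.cat([weight[start:stop], torch.zeros(1, weight.shape[1])], dim=0)
    null_idx = stop - start
    local_ids = torch.where(
        (input_ids >= start) & (input_ids < stop),
        input_ids - start,
        torch.full_like(input_ids, null_idx),
    )
    return shard[local_ids]


vocab_size, hidden_dim, world_size = 17, 4, 2
weight = torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
input_ids = torch.arange(vocab_size)

partials = [partial_lookup(weight, input_ids, r, world_size) for r in range(world_size)]
# Summing the per-rank partials (the job of the all-reduce) gives the full lookup.
torch.testing.assert_close(sum(partials), weight[input_ids])
```

The trailing zero row is also why the test above expects per-rank weights of shape (10, 256) and (9, 256) for a 17-row vocabulary split across two ranks.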