Mirror of https://github.com/huggingface/text-generation-inference.git
Quantized weights were loaded in the `Weights` class, but this had become unwieldy: every higher-level method for loading weights was a long conditional covering all the different quantizers. This change moves the loading of quantized weights out of the `Weights` class.

This is done by defining a simple `WeightsLoader` interface that is implemented by `Exl2WeightsLoader`, `GPTQWeightsLoader`, and `MarlinWeightsLoader`, each living in its quantizer's module. The `Weights` class still provides the low-level load operations (such as loading tensors or sharded tensors), but delegates loads that need quantizer-specific weight processing to a loader; the loaders in turn use the low-level functionality provided by `Weights`.

I initially tried a hierarchy in which a class like `GPTQWeights` would inherit from `Weights`, but that is not very flexible (e.g. it does not work well with the new weight storage mock used in tests), and the implicit indirections made the code harder to follow.
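As a rough illustration of the delegation described above, here is a minimal sketch of the pattern, not the actual interface from this change: the class and method names (`WeightsLoaderSketch`, `WeightsSketch`, `get_weights_row`, `load_tensor`) are made up for the example, and a plain dict stands in for the quantizer-specific weight objects the real loaders return.

from abc import ABC, abstractmethod

import torch


class WeightsLoaderSketch(ABC):
    # Quantizer-specific weight processing lives behind this interface.
    @abstractmethod
    def get_weights_row(self, weights: "WeightsSketch", prefix: str):
        ...


class DefaultLoaderSketch(WeightsLoaderSketch):
    # Unquantized weights need no extra processing.
    def get_weights_row(self, weights: "WeightsSketch", prefix: str):
        return weights.load_tensor(f"{prefix}.weight")


class GPTQLoaderSketch(WeightsLoaderSketch):
    # A quantized loader gathers its quantizer's tensors; a dict stands in
    # here for the quantizer-specific weight object a real loader would build.
    def get_weights_row(self, weights: "WeightsSketch", prefix: str):
        return {
            "qweight": weights.load_tensor(f"{prefix}.qweight"),
            "scales": weights.load_tensor(f"{prefix}.scales"),
        }


class WeightsSketch:
    # Low-level tensor access; quantizer-specific loads are delegated to the loader.
    def __init__(self, tensors: dict, loader: WeightsLoaderSketch):
        self.tensors = tensors
        self.loader = loader

    def load_tensor(self, name: str) -> torch.Tensor:
        return self.tensors[name]

    def get_weights_row(self, prefix: str):
        # Call sites stay quantizer-agnostic: the loader decides what "a row
        # of weights" means for its quantization scheme.
        return self.loader.get_weights_row(self, prefix)


weights = WeightsSketch({"linear.weight": torch.ones(4, 4)}, DefaultLoaderSketch())
assert weights.get_weights_row("linear").shape == (4, 4)

Swapping the loader is the only change needed to switch quantization schemes; nothing above the weight storage has to branch on the quantizer.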
84 lines · 2.6 KiB · Python
import torch

from text_generation_server.layers import (
    TensorParallelEmbedding,
)
from text_generation_server.utils.weights import DefaultWeightsLoader


class ProcessGroup:
    # Minimal stand-in for a torch.distributed process group: just rank and size.
    def __init__(self, rank: int, world_size: int):
        self._rank = rank
        self.world_size = world_size

    def size(self) -> int:
        return self.world_size

    def rank(self) -> int:
        return self._rank


class Weights:
    # Mock weight storage: a dense [vocab_size, hidden_dim] matrix filled with
    # 0 .. vocab_size * hidden_dim - 1, sharded along dimension 0.
    def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
        self.weight = (
            torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
        )
        self.process_group = ProcessGroup(rank, world_size)

    def get_partial_sharded(self, name: str, dim: int):
        assert dim == 0

        rank = self.process_group.rank()
        world_size = self.process_group.size()
        size = self.weight.shape[dim]

        # Contiguous block sharding: each rank gets ceil(size / world_size) rows
        # (the last rank may get fewer).
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        return self.weight[start:stop]

    def get_shape(self, name: str):
        return self.weight.shape


def test_weight_hub_files_offline_error():
    vocab_size = 17
    weights = Weights(
        rank=0,
        world_size=1,
        vocab_size=vocab_size,
        hidden_dim=256,
    )
    embeddings = TensorParallelEmbedding("", weights)

    # With a single shard the embedding reproduces the full weight matrix.
    input_ids = torch.arange(vocab_size)
    output = embeddings.forward(input_ids)
    assert embeddings.min_id == 0
    assert embeddings.max_id == 17
    torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))

    # Two shards: rank 0 holds ids [0, 9), rank 1 holds ids [9, 17); each shard
    # gets an extra zero row appended for ids outside its range.
    weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
    weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
    embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
    assert embeddings_0_2.min_id == 0
    assert embeddings_0_2.max_id == 9
    torch.testing.assert_close(
        embeddings_0_2.weight,
        torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
        .view(10, 256)
        .float(),
    )
    embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
    assert embeddings_1_2.min_id == 9
    assert embeddings_1_2.max_id == 17
    torch.testing.assert_close(
        embeddings_1_2.weight,
        torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
        .view(9, 256)
        .float(),
    )
    output_tp_0 = embeddings_0_2.forward(input_ids)
    output_tp_1 = embeddings_1_2.forward(input_ids)

    # With reduce=False each shard returns zeros for ids outside its range, so
    # summing the shard outputs reconstructs the single-shard output.
    torch.testing.assert_close(output, output_tp_0 + output_tp_1)