2024-05-13 10:44:30 +00:00
|
|
|
import torch
|
|
|
|
from torch.nn import functional as F
|
2024-05-28 09:51:31 +00:00
|
|
|
from typing import Iterable, List
|
2024-05-13 10:44:30 +00:00
|
|
|
from text_generation_server.layers.linear import get_linear, FastLinear
|
2024-06-25 11:20:57 +00:00
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
2024-06-25 10:21:29 +00:00
|
|
|
|
2024-06-25 11:20:57 +00:00
|
|
|
if SYSTEM == "ipex":
|
2024-06-25 10:21:29 +00:00
|
|
|
import intel_extension_for_pytorch as ipex
|
2024-05-28 09:51:31 +00:00
|
|
|
|
|
|
|
|
|
|
|
class LayerConcat(torch.nn.Module):
    """
    Apply multiple layers to the input and concatenate their
    outputs.
    """

    def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
        """
        `layers` is materialized into a list up front, so a one-shot iterable
        (e.g. a generator) is not exhausted by the first forward pass.
        `dim` is the dimension along which layer outputs are concatenated.
        """
        super().__init__()
        # Stored as a plain list (not a ModuleList), which keeps the original
        # behavior of not registering the layers as submodules.
        self.layers = list(layers)
        self.dim = dim

    def forward(self, x: torch.Tensor):
        """Run every layer on `x` and concatenate the outputs along `self.dim`."""
        outputs = [layer(x) for layer in self.layers]
        return torch.cat(outputs, self.dim)
|
2024-05-13 10:44:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
class SuperLayer(torch.nn.Module):
    """
    Base wrapper that owns a single linear-like layer and delegates the
    forward pass to it.
    """

    def __init__(self, linear):
        super().__init__()
        # Kept as an attribute; torch registers it as a submodule when it is
        # an nn.Module.
        self.linear = linear

    def forward(self, x):
        # `.forward` is called directly rather than `self.linear(x)`, so any
        # hooks registered on the wrapped module are not triggered here.
        inner = self.linear
        return inner.forward(x)
|
|
|
|
|
|
|
|
|
|
|
|
class TensorParallelHead(SuperLayer):
    """
    LM head whose output dimension may be sharded across the ranks of
    `process_group`.

    When `should_gather` is true, each rank computes its own slice of the
    output and the slices are all-gathered so every rank receives the full
    output. Otherwise the wrapped linear layer is applied as-is.
    """

    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        """
        Build a head from the tensors stored under `prefix`.

        With more than one rank the weight is sharded along dim 0 when the
        loader allows it; whenever the full tensor is loaded instead
        (`get_tensor`), no gathering is performed.
        """
        if config.quantize == "exl2":
            try:
                # If the piece and LM head embeddings are shared, we have
                # non-quantized weights...
                weight = weights.get_tensor(f"{prefix}.weight")
            except Exception:
                # ...otherwise they are quantized.
                weight = weights.get_weights_col(prefix)
                should_gather = weights.process_group.size() > 1
            else:
                # `get_tensor` returns the full, unsharded tensor (the same
                # call is used below whenever sharding is not possible), so
                # every rank already computes the complete output; gathering
                # it would duplicate the output world_size times.
                should_gather = False
        elif weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply the head; if `should_gather`, all-gather the per-rank outputs
        so the result is concatenated along the last dimension.
        """
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        # Fast path: write the matmul result straight into the buffer handed
        # to the collective. It computes only `input @ weight.T`, so it is
        # restricted to bias-free FastLinear layers; anything else takes the
        # generic path below, which calls the wrapped layer.
        if (
            len(input.shape) == 2
            and isinstance(self.linear, FastLinear)
            and getattr(self.linear, "bias", None) is None
        ):
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                # all_gather_into_tensor concatenates along dim 0, so gather
                # the (out_dim, batch) layout and transpose back at the end.
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)
            if SYSTEM == "ipex":
                ipex.distributed.all_gather_into_tensor(
                    world_out, gather_input, group=self.process_group
                )
            else:
                torch.distributed.all_gather_into_tensor(
                    world_out, gather_input, group=self.process_group
                )

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        output = super().forward(input)
        world_output = [torch.empty_like(output) for _ in range(world_size)]
        if SYSTEM == "ipex":
            ipex.distributed.all_gather(world_output, output, group=self.process_group)
        else:
            torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output
|
|
|
|
|
|
|
|
|
|
|
|
class TensorParallelColumnLinear(SuperLayer):
    """
    Linear layer loaded through the column-sharded weight loaders
    (`get_weights_col*`). `forward` is inherited from `SuperLayer` and
    performs no cross-rank communication.
    """

    @classmethod
    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
        """
        Specific method when the gate and up projections were joined after
        the fact into a single packed tensor under `prefix`.

        Raises:
            NotImplementedError: if `bias` is true; packed gate/up weights
                are only supported without bias.
        """
        weight = weights.get_weights_col_packed_gate_up(prefix)
        if bias:
            raise NotImplementedError("packed_gate_up only implemented without bias")
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load_qkv(
        cls,
        config,
        prefix: str,
        weights,
        bias: bool,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """
        Specific method when the QKV was joined after the fact into a single
        packed tensor under `prefix`.

        `num_heads` and `num_key_value_heads` are forwarded to
        `get_weights_col_packed_qkv` (presumably to locate the Q/K/V
        boundaries inside the packed tensor).

        Raises:
            NotImplementedError: if `bias` is true.
        """
        weight = weights.get_weights_col_packed_qkv(
            prefix,
            num_heads=num_heads,
            num_key_value_heads=num_key_value_heads,
        )
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load one column linear layer; the bias, if any, is sharded along dim 0."""
        weight = weights.get_weights_col(prefix)
        if bias:
            bias = weights.get_sharded(f"{prefix}.bias", dim=0)
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        """
        Load the layers named by `prefixes` as one combined column layer.

        For `exl2`, each prefix becomes its own linear layer and their outputs
        are concatenated along the last dimension at runtime by `LayerConcat`;
        `dim` is not used on that path. Otherwise all weights are fetched in
        one `get_multi_weights_col(prefixes, dim=dim)` call and the dim-0
        sharded biases are concatenated along `dim`.
        """
        if config.quantize == "exl2":
            linears = []
            for prefix in prefixes:
                weight = weights.get_weights_col(prefix)
                # NOTE(review): the bias is loaded whole with `get_tensor`
                # here, unlike the `get_sharded` call on the other path.
                b = weights.get_tensor(f"{prefix}.bias") if bias else None
                linears.append(get_linear(weight, b))
            linear = LayerConcat(linears)
        else:
            weight = weights.get_multi_weights_col(prefixes, dim=dim)
            if bias:
                b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
                bias = torch.cat(b, dim=dim)
            else:
                bias = None
            linear = get_linear(weight, bias)
        return cls(linear)
|
|
|
|
|
2024-05-13 10:44:30 +00:00
|
|
|
|
|
|
|
class TensorParallelRowLinear(SuperLayer):
    """
    Linear layer loaded through the row-sharded weight loader
    (`get_weights_row`). Each rank's output is a partial result that
    `forward` sums across `process_group` with an all-reduce.
    """

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """
        Build the layer from the weight under `prefix`.

        When `bias` is true, the bias tensor is attached on rank 0 only.
        """
        weight = weights.get_weights_row(prefix)

        if bias and weights.process_group.rank() == 0:
            # The partial outputs of all ranks are summed by the all-reduce in
            # `forward`, so the bias is added on a single rank (rank 0) to be
            # counted exactly once.
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        """
        Apply the local layer and, when `reduce` is true and the group has
        more than one rank, all-reduce the result in place across
        `process_group`. With `reduce=False` this rank's partial result is
        returned unreduced.
        """
        out = super().forward(input)
        if self.process_group.size() > 1 and reduce:
            if SYSTEM == "ipex":
                ipex.distributed.all_reduce(out, group=self.process_group)
            else:
                torch.distributed.all_reduce(out, group=self.process_group)
        return out
|
|
|
|
|
|
|
|
|
|
|
|
class TensorParallelEmbedding(torch.nn.Module):
    """
    Embedding table split along the vocabulary dimension across the ranks of
    the weights' process group.

    This rank owns the ids in ``[min_id, max_id)``. Any other id is redirected
    to an extra all-zero row appended to the local shard, so it contributes
    zeros; when ``reduce`` is true and there is more than one rank, the
    per-rank results are all-reduced.
    """

    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        name = f"{prefix}.weight"
        shard = weights.get_partial_sharded(name, dim=0)
        vocab_size = weights.get_shape(name)[0]

        group = weights.process_group
        n_ranks = group.size()
        my_rank = group.rank()

        # Ceil-divide the vocabulary into equally sized per-rank ranges.
        rows_per_rank = (vocab_size + n_ranks - 1) // n_ranks
        self.min_id = my_rank * rows_per_rank
        self.max_id = min(vocab_size, (my_rank + 1) * rows_per_rank)
        # Index of the padding row appended below: the local shard length,
        # usually rows_per_rank but possibly less for an uneven vocab size.
        self.null_idx = shard.shape[0]
        self.process_group = group
        self.reduce = reduce

        # Append one zero row (pad spec: last dim (0, 0), row dim (0, 1)) that
        # out-of-range ids are mapped to.
        self.weight = torch.nn.Parameter(F.pad(shard, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Shift ids into local coordinates [0, max_id - min_id); ids this rank
        # does not own point at the zero row `self.null_idx`.
        not_mine = (input < self.min_id) | (input >= self.max_id)
        local_ids = torch.where(not_mine, self.null_idx, input - self.min_id)
        out = F.embedding(local_ids, self.weight)
        if not (self.reduce and self.process_group.size() > 1):
            return out
        if SYSTEM == "ipex":
            ipex.distributed.all_reduce(out, group=self.process_group)
        else:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
|