Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights
Handling of quantized weights was split between two mechanisms:
- For quantized checkpoints, we used the new weight loader
  infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
  instead relied on conditionals in `get_linear`.
Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
`quantizer`. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.
This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:
- We can use context managers with all quantizers (see the sketch below).
- All the implementation details move down to the quantizer layers;
  `get_linear` does not need to know how to handle quantized linear
  layers.
- All quantizer weights are strongly typed; we don't pass around
  raw tensors.
- We don't have to pass around the `quantizer` string everywhere.
* Exclude non-MLP layers when using FP8 quantization with Llama
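
To make the context-manager mechanism above concrete, the pattern looks roughly like the following. This is a minimal sketch, not the repository's actual API: the `Weights` container, `use_loader`, and the loader class names are illustrative assumptions.

from contextlib import contextmanager


class Weights:
    """Illustrative weight container that delegates to a pluggable loader."""

    def __init__(self, default_loader):
        self.loader = default_loader

    @contextmanager
    def use_loader(self, loader):
        # Temporarily swap the active loader, e.g. to load Idefics2's
        # unquantized vision and connector modules inside an otherwise
        # AWQ-quantized model.
        previous = self.loader
        self.loader = loader
        try:
            yield
        finally:
            self.loader = previous


# Hypothetical usage:
#
#     weights = Weights(AWQWeightsLoader(...))
#     with weights.use_loader(UnquantizedWeightsLoader()):
#         vision_model = load_vision_model(weights)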
from dataclasses import dataclass
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weight


def get_fp8_linear() -> torch.nn.Module:
    """
    Return an FP8 linear `Module` that is compatible with the current system.
    """
    if SYSTEM == "cuda":
        major, minor = torch.cuda.get_device_capability()
        if major == 8 and minor < 9:
            # Ampere (SM 8.0/8.6) has no native FP8 support, so fall back to
            # the Marlin FP8 kernel; Ada (SM 8.9) and newer use Fp8Linear.
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear
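
Note that `get_fp8_linear` returns a class rather than an instance; a minimal usage sketch (the shape here is illustrative):

linear_cls = get_fp8_linear()
# (weight, bias), matching how Fp8Weight.get_linear constructs it further down.
layer = linear_cls(torch.randn(4096, 4096, dtype=torch.float16, device="cuda"), None)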


def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    # Calculate the scale as dtype max divided by absmax.
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Scale and clamp the tensor to bring it into the representable range of
    # the float8 data type (the default cast is unsaturated).
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    qweight = qweight.to(qdtype)
    # Return both the float8 data and the inverse scale (as float32), since
    # both are required as inputs to torch._scaled_mm.
    scale = scale.float().reciprocal()
    return qweight, scale
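
A quick sanity check of the scaling above; a sketch, assuming a PyTorch build with float8 dtypes and a CUDA device:

w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
qw, inv_scale = fp8_quantize(w)
# qw holds float8_e4m3fn values; multiplying by the inverse scale
# approximately recovers the original weight (per-tensor quantization).
w_approx = qw.to(torch.float16) * inv_scale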


@dataclass
class Fp8Weight(Weight):
    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        return get_fp8_linear()(self.weight, bias)
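
Under the loader scheme described in the commit message, a loader hands back a typed weight like this and the layer builds its own linear module. A minimal sketch (constructing `Fp8Weight` directly and the shape used here are illustrative, not the loader's real call site):

fp8_weight = Fp8Weight(weight=torch.randn(1024, 4096, dtype=torch.float16, device="cuda"))
linear = fp8_weight.get_linear(bias=None)  # Fp8Linear on Ada/Hopper, Marlin FP8 on Ampere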


class Fp8Linear(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.dtype = weight.dtype
        self.qweight, self.scale = fp8_quantize(weight)
        self.bias = bias

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Quantize the activations on the fly, then run the FP8 matmul;
        # torch._scaled_mm dequantizes the result using both inverse scales.
        qinput, scale = fp8_quantize(input)
        output, _ = torch._scaled_mm(
            qinput,
            self.qweight.t(),
            out_dtype=self.dtype,
            scale_a=scale,
            scale_b=self.scale,
            bias=self.bias,
        )
        return output
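
Finally, a minimal end-to-end sketch, assuming FP8-capable hardware (compute capability 8.9 or newer) and a PyTorch version of this era in which `torch._scaled_mm` returns an `(output, amax)` tuple; the shapes are illustrative, and `_scaled_mm` is a private API that generally expects dimensions that are multiples of 16:

weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
layer = Fp8Linear(weight, bias=None)

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
y = layer(x)  # activations and weights are both quantized to FP8 on the fly
assert y.shape == (16, 4096) and y.dtype == torch.float16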