Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 14:52:20 +00:00
enable fp8

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

parent 8d221b7b79
commit 69773767c5
@@ -101,6 +101,7 @@ def serve(
         "bitsandbytes-fp4",
         "gptq",
         "awq",
+        "fp8",
     }:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
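With "fp8" added to the accepted `quantize` values, requesting FP8 no longer trips the dtype/quantize guard at startup. A minimal launch sketch (the model id is an illustrative placeholder, not part of this commit):

    text-generation-launcher --model-id meta-llama/Llama-2-7b-hf --quantize fp8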
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from typing import Optional, Tuple, Type, Union, List
 
 import torch
-from loguru import logger
 
 from text_generation_server.utils.weights import (
     Weight,
@@ -10,18 +9,16 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
     Weights,
 )
-from text_generation_server.utils.log import log_once
 
-quantization = None
+from vllm_hpu_extension.ops import scaled_fp8_quant
+from vllm_hpu_extension.ops import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
+import habana_frameworks.torch.utils.experimental as htexp
+
 w8a8_block_fp8_matmul = None
 per_token_group_quant_fp8 = None
 
 quant_dtype: torch.dtype = torch.float8_e4m3fn
 
-
-CUTLASS_FP8_AVAILABLE = False
-
-
 def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
     """
     Return an FP8 linear `Module` that is compatible with the current system.
@@ -43,7 +40,13 @@ def per_tensor_dequantize(
     inv_scale: Union[float, torch.Tensor],
     dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    fake_qweight = tensor.to(dtype)
+    device = tensor.device
+    dtype = torch.bfloat16
+    if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
+        # dequant on cpu to avoid nan on gaudi2
+        tensor = tensor.to("cpu")
+
+    fake_qweight = tensor.to(dtype).to(device)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
 
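The rewritten body upcasts through bfloat16 and, on Gaudi2, detours through the CPU because the direct on-device FP8 upcast can produce NaNs there. A device-agnostic sketch of the same dequantization math on stock PyTorch (the function name and the round-trip test are illustrative, not part of the commit):

    import torch

    FP8 = torch.float8_e4m3fn

    def per_tensor_dequantize_ref(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
        # Upcast the FP8 payload to a wide dtype, then undo quantization by
        # multiplying with the stored reciprocal scale.
        return tensor.to(torch.bfloat16) * inv_scale

    w = torch.randn(4, 8)
    scale = torch.finfo(FP8).max / w.abs().max()
    q = (w * scale).clamp(torch.finfo(FP8).min, torch.finfo(FP8).max).to(FP8)
    print(per_tensor_dequantize_ref(q, (1.0 / scale).item()))  # approximately w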
@@ -55,7 +58,10 @@ def requantize_with_max_scale(
     dtype: torch.dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requanitzation.
-    max_w_scale = weight_scale.max().float()
+    max_w_scale = weight_scale.max()
+
+    if is_hpu_gaudi2():
+        max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
 
     start = 0
     for idx, logical_width in enumerate(logical_widths):
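requantize_with_max_scale collapses the per-partition scales of stacked projections into one shared max scale; on Gaudi2 a device-specific factor from vllm_hpu_extension is folded in to adapt the scale to that chip's FP8 representation. A hedged, CPU-only reference of the core loop (the function name and the bfloat16 intermediate are illustrative):

    import torch

    FP8 = torch.float8_e4m3fn
    FP8_MIN, FP8_MAX = torch.finfo(FP8).min, torch.finfo(FP8).max

    def requantize_with_max_scale_ref(
        weight: torch.Tensor,        # FP8 payload, partitions stacked along dim 0
        weight_scale: torch.Tensor,  # one dequant scale per logical partition
        logical_widths: list,
    ):
        # Use a single shared scale: the max over all per-partition scales.
        max_w_scale = weight_scale.max()
        parts, start = [], 0
        for idx, logical_width in enumerate(logical_widths):
            end = start + logical_width
            # Dequantize with the partition's own scale, requantize with the max.
            dq = weight[start:end].to(torch.bfloat16) * weight_scale[idx]
            parts.append((dq / max_w_scale).clamp(FP8_MIN, FP8_MAX).to(FP8))
            start = end
        return torch.cat(parts), max_w_scale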
@@ -84,37 +90,16 @@ def fp8_quantize(
     argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
     be used without modification).
     """
-    if quantization is not None:
-        shape = weight.shape
-        qweight, scale = quantization.scaled_fp8_quant(
-            weight.reshape(-1, shape[-1]),
-            scale=scale,
-            scale_ub=scale_upper_bound,
-            # TODO: don't do this when we have to use the Torch kernel.
-            use_per_token_if_dynamic=not scalar,
-        )
-
-        return qweight.reshape(shape), scale
-
-    finfo = torch.finfo(qdtype)
-
-    if scale is None:
-        # Calculate the scale as dtype max divided by absmax
-        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
-        # scale and clamp the tensor to bring it to
-        # the representative range of float8 data type
-        # (as default cast is unsaturated)
-        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
-        scale = scale.float().reciprocal()
-    else:
-        # Use reciprocal to avoid more expensive division.
-        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
-
-    # Return both float8 data and the inverse scale (as float),
-    # as both required as inputs to torch._scaled_mm
-    qweight = qweight.to(qdtype)
-
-    return qweight, scale
+    shape = weight.shape
+    qweight, scale = scaled_fp8_quant(
+        weight.reshape(-1, shape[-1]),
+        scale=scale,
+        scale_ub=scale_upper_bound,
+        # TODO: don't do this when we have to use the Torch kernel.
+        use_per_token_if_dynamic=not scalar,
+    )
+
+    return qweight.reshape(shape), scale
 
 
 class HybridFP8UnquantLoader(WeightsLoader):
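fp8_quantize now always routes through the HPU kernel scaled_fp8_quant instead of branching between a CUDA extension and a Torch fallback. The deleted fallback documents the underlying math; a self-contained sketch of that dynamic per-tensor scheme, kept here purely as a reference (it is no longer the code path):

    import torch

    def fp8_quantize_ref(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
        finfo = torch.finfo(qdtype)
        # Scale = dtype max / absmax, clamped away from zero.
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
        # Saturating cast: scale into range, clamp, then cast down to FP8.
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
        # Return the reciprocal scale so that dequantization is a multiply.
        return qweight, scale.float().reciprocal()

    q, inv_scale = fp8_quantize_ref(torch.randn(16, 32))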
@@ -153,6 +138,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .max()
             )
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )
 
             return Fp8Weight(
                 weight=w,
@@ -201,6 +190,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     to_dtype=False,
                 )
                 input_scale = input_scale.reshape(-1).max()
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )
 
             return Fp8Weight(
                 weight=w,
@@ -259,6 +252,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 else None
             )
 
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.to(weights.device), logical_widths, weights.dtype
+            )
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
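In this loader path several projections are fused into one matrix (for example q/k/v), so logical_widths records each partition's output width before the shared requantization. Illustrative numbers only:

    # Output widths of three projections stacked along dim 0 (illustrative).
    shapes = [(4096, 4096), (1024, 4096), (1024, 4096)]
    logical_widths = [x[0] for x in shapes]  # [4096, 1024, 1024]
    # requantize_with_max_scale then re-encodes all partitions against the
    # single max scale so the fused GEMM can use one weight scale.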
@@ -296,7 +294,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .max()
             )
-
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -353,27 +354,19 @@ class Fp8Linear(torch.nn.Module):
         weight_block_size: Optional[List[int]] = None,
     ) -> None:
         super().__init__()
-        if CUTLASS_FP8_AVAILABLE:
-            log_once(logger.info, "Using cutlass w8a8 kernels")
-
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale.float()
         self.input_scale = input_scale.float() if input_scale is not None else None
         self.weight_block_size = weight_block_size
 
-        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
-            self.scale_upper_bound = torch.tensor(
-                scale_upper_bound, dtype=torch.float32, device=qweight.device
-            )
-        else:
-            self.scale_upper_bound = scale_upper_bound
+        self.scale_upper_bound = scale_upper_bound
 
         self.bias = bias if bias is not None else None
 
     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
+        qweight, scale = fp8_quantize(weight, scalar=True)
         return cls(
             qweight=qweight,
             scale=scale,
@@ -434,14 +427,6 @@ class Fp8Linear(torch.nn.Module):
             if self.bias is not None:
                 output = output + self.bias
             return output.to(dtype=input.dtype)
-        if CUTLASS_FP8_AVAILABLE:
-            # cutlass FP8 supports per-token scales, so get non-scalar scales.
-            qinput, scale = fp8_quantize(
-                input, scale_upper_bound=self.scale_upper_bound, scalar=False
-            )
-            return quantization.cutlass_scaled_mm(
-                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
-            )
 
         qinput, scale = fp8_quantize(
             input,
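With the cutlass branch removed, every forward pass quantizes the activation through the HPU kernel path and from_unquant always produces a scalar, per-tensor scale. A usage sketch; it runs only on a Gaudi build where vllm_hpu_extension is importable, and the shapes are illustrative:

    import torch
    from text_generation_server.layers.fp8 import Fp8Linear

    # Quantize a bf16 weight on the fly and run one linear layer.
    weight = torch.randn(4096, 4096, dtype=torch.bfloat16)
    layer = Fp8Linear.from_unquant(weight, bias=None, dtype=torch.bfloat16)
    y = layer(torch.randn(1, 4096, dtype=torch.bfloat16))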
@@ -470,4 +455,4 @@ def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
 
     if scale.numel() > 1:
         scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
-    return scale.reshape(-1).expand(shape[0])
+    return scale.reshape(-1)
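The loader no longer expands a scalar scale to one entry per output row; a one-element tensor is returned and broadcasting is left to the consumer. The difference in shapes, as a tiny standalone check:

    import torch

    scale = torch.tensor([0.02])
    out_features = 8
    # Before: one scale entry per output row.
    print(scale.reshape(-1).expand(out_features).shape)  # torch.Size([8])
    # After: the scalar is kept as a 1-element tensor.
    print(scale.reshape(-1).shape)                       # torch.Size([1])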