Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 14:52:20 +00:00)
enable fp8
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 8d221b7b79
commit 69773767c5
@@ -101,6 +101,7 @@ def serve(
         "bitsandbytes-fp4",
         "gptq",
         "awq",
+        "fp8",
     }:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
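The practical effect of adding "fp8" to this set is that an explicit `--dtype` no longer conflicts with `--quantize fp8`. Below is a standalone sketch of the same check — a hypothetical helper, not the actual CLI code — listing only the set entries visible in this hunk:

# Hypothetical helper illustrating the validation above (not the real CLI code);
# only the quantize values visible in this hunk are listed.
def check_dtype_quantize(dtype, quantize):
    if dtype is not None and quantize not in {
        "bitsandbytes-fp4",
        "gptq",
        "awq",
        "fp8",  # newly allowed together with an explicit dtype
    }:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
        )

check_dtype_quantize("bfloat16", "fp8")  # no longer raises after this change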
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from typing import Optional, Tuple, Type, Union, List
 
 import torch
-from loguru import logger
 
 from text_generation_server.utils.weights import (
     Weight,
@@ -10,18 +9,16 @@ from text_generation_server.utils.weights import (
     UnquantizedWeight,
     Weights,
 )
-from text_generation_server.utils.log import log_once
 
-quantization = None
+from vllm_hpu_extension.ops import scaled_fp8_quant
+from vllm_hpu_extension.ops import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
+import habana_frameworks.torch.utils.experimental as htexp
 
 w8a8_block_fp8_matmul = None
 per_token_group_quant_fp8 = None
 
 quant_dtype: torch.dtype = torch.float8_e4m3fn
 
-CUTLASS_FP8_AVAILABLE = False
 
 
 def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
     """
     Return an FP8 linear `Module` that is compatible with the current system.
@@ -43,7 +40,13 @@ def per_tensor_dequantize(
     inv_scale: Union[float, torch.Tensor],
     dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    fake_qweight = tensor.to(dtype)
+    device = tensor.device
+    dtype = torch.bfloat16
+    if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
+        # dequant on cpu to avoid nan on gaudi2
+        tensor = tensor.to("cpu")
+
+    fake_qweight = tensor.to(dtype).to(device)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
 
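The new dequantization path keeps the weight on its original device except on Gaudi2, where the cast runs on CPU to avoid NaNs. A minimal, device-agnostic sketch of the same flow, with the `htexp` device check replaced by a plain flag since `habana_frameworks` is only available on HPU systems (names and example values are mine):

import torch

def per_tensor_dequantize_sketch(
    tensor: torch.Tensor,
    inv_scale: float,
    on_gaudi2: bool = False,  # stands in for the htexp Gaudi2 device check
) -> torch.Tensor:
    device = tensor.device
    dtype = torch.bfloat16
    if on_gaudi2:
        # dequantize on CPU to avoid NaNs produced by the Gaudi2 FP8 cast
        tensor = tensor.to("cpu")
    fake_qweight = tensor.to(dtype).to(device)
    return fake_qweight * inv_scale

# usage: dequantize an FP8 weight back to bf16 (needs a PyTorch build with float8_e4m3fn)
q = torch.randn(4, 8).to(torch.float8_e4m3fn)
w = per_tensor_dequantize_sketch(q, inv_scale=0.05)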
@@ -55,7 +58,10 @@ def requantize_with_max_scale(
     dtype: torch.dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requanitzation.
-    max_w_scale = weight_scale.max().float()
+    max_w_scale = weight_scale.max()
+
+    if is_hpu_gaudi2():
+        max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
 
     start = 0
     for idx, logical_width in enumerate(logical_widths):
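Only the head of `requantize_with_max_scale` appears in this hunk; the loop over `logical_widths` that follows (outside the diff) re-quantizes each shard against the shared maximum scale. A rough sketch of that idea — not the actual function body, and without the Gaudi2 scale-factor correction shown above:

import torch

def requantize_with_max_scale_sketch(weight, weight_scale, logical_widths):
    # Collapse per-shard scales into one scale (the max), then re-quantize each
    # shard so the whole fused tensor can share that single scale.
    max_w_scale = weight_scale.max()
    parts = []
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        # dequantize the shard with its own scale ...
        dq = weight[start:end].to(torch.bfloat16) * weight_scale[idx]
        # ... and requantize it with the shared max scale
        parts.append((dq / max_w_scale).to(torch.float8_e4m3fn))
        start = end
    return torch.cat(parts), max_w_scale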
@@ -84,9 +90,8 @@ def fp8_quantize(
     argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
     be used without modification).
     """
-    if quantization is not None:
-        shape = weight.shape
-        qweight, scale = quantization.scaled_fp8_quant(
-            weight.reshape(-1, shape[-1]),
-            scale=scale,
-            scale_ub=scale_upper_bound,
+    shape = weight.shape
+    qweight, scale = scaled_fp8_quant(
+        weight.reshape(-1, shape[-1]),
+        scale=scale,
+        scale_ub=scale_upper_bound,
@@ -96,26 +101,6 @@ def fp8_quantize(
 
     return qweight.reshape(shape), scale
 
-    finfo = torch.finfo(qdtype)
-
-    if scale is None:
-        # Calculate the scale as dtype max divided by absmax
-        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
-        # scale and clamp the tensor to bring it to
-        # the representative range of float8 data type
-        # (as default cast is unsaturated)
-        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
-        scale = scale.float().reciprocal()
-    else:
-        # Use reciprocal to avoid more expensive division.
-        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
-
-    # Return both float8 data and the inverse scale (as float),
-    # as both required as inputs to torch._scaled_mm
-    qweight = qweight.to(qdtype)
-
-    return qweight, scale
-
 
 class HybridFP8UnquantLoader(WeightsLoader):
     """Weight loader that loads FP8 and unquantized Torch tensors."""
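For reference, the deleted pure-PyTorch fallback is reassembled below as a standalone function. The body is the removed code; the signature (parameter names `scale`, `scale_upper_bound`, `qdtype`) is inferred from the surrounding file and should be treated as an assumption:

from typing import Optional

import torch

def fp8_quantize_torch_fallback(
    weight: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    scale_upper_bound: Optional[float] = None,
    qdtype: torch.dtype = torch.float8_e4m3fn,
):
    # Reassembled from the removed lines above: dynamic per-tensor FP8
    # quantization without the HPU/cutlass kernels.
    finfo = torch.finfo(qdtype)
    if scale is None:
        # Calculate the scale as dtype max divided by absmax.
        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
        # Scale and clamp the tensor to bring it into the representative range
        # of the float8 data type (the default cast is unsaturated).
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        scale = scale.float().reciprocal()
    else:
        # Use the reciprocal to avoid a more expensive division.
        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale, as both are required
    # as inputs to torch._scaled_mm.
    qweight = qweight.to(qdtype)
    return qweight, scale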
@@ -153,6 +138,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .max()
             )
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )
 
             return Fp8Weight(
                 weight=w,
@@ -201,6 +190,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     to_dtype=False,
                 )
                 input_scale = input_scale.reshape(-1).max()
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )
 
             return Fp8Weight(
                 weight=w,
@@ -259,6 +252,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 else None
             )
 
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.to(weights.device), logical_widths, weights.dtype
+            )
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -296,7 +294,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .max()
             )
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(0), logical_widths, weights.dtype
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
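All four loader paths above now funnel their per-shard scales through `requantize_with_max_scale`. A hedged usage sketch of that call pattern, assuming the file is importable as `text_generation_server.layers.fp8` and that the Gaudi dependencies (`vllm_hpu_extension`, `habana_frameworks`) are installed; the shapes and scale values are made up:

import torch
from text_generation_server.layers.fp8 import requantize_with_max_scale

# A fused q/k/v weight built from three shards, each with its own scale.
q_w, k_w, v_w = torch.randn(8, 16), torch.randn(4, 16), torch.randn(4, 16)
w = torch.cat([q_w, k_w, v_w], dim=0).to(torch.float8_e4m3fn)
scale = torch.tensor([0.02, 0.01, 0.03])  # one scale per logical shard
logical_widths = [8, 4, 4]                # output rows per shard

# Collapse the per-shard scales into a single shared scale, as the loaders do.
w, max_scale = requantize_with_max_scale(w, scale, logical_widths, torch.bfloat16)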
@@ -353,27 +354,19 @@ class Fp8Linear(torch.nn.Module):
         weight_block_size: Optional[List[int]] = None,
     ) -> None:
         super().__init__()
-        if CUTLASS_FP8_AVAILABLE:
-            log_once(logger.info, "Using cutlass w8a8 kernels")
-
         self.dtype = dtype
         self.qweight = qweight
         self.scale = scale.float()
         self.input_scale = input_scale.float() if input_scale is not None else None
         self.weight_block_size = weight_block_size
 
-        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
-            self.scale_upper_bound = torch.tensor(
-                scale_upper_bound, dtype=torch.float32, device=qweight.device
-            )
-        else:
-            self.scale_upper_bound = scale_upper_bound
+        self.scale_upper_bound = scale_upper_bound
 
         self.bias = bias if bias is not None else None
 
     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
+        qweight, scale = fp8_quantize(weight, scalar=True)
         return cls(
             qweight=qweight,
             scale=scale,
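With the cutlass branch gone, `from_unquant` always quantizes with a single per-tensor scale (`scalar=True`). A hedged construction sketch, under the same import and HPU-environment assumptions as above:

import torch
from text_generation_server.layers.fp8 import Fp8Linear

weight = torch.randn(256, 128, dtype=torch.bfloat16)  # [out_features, in_features]
bias = None
# Quantizes the weight on the fly and wraps it in an FP8 linear module.
fp8_linear = Fp8Linear.from_unquant(weight, bias, dtype=torch.bfloat16)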
@@ -434,14 +427,6 @@ class Fp8Linear(torch.nn.Module):
             if self.bias is not None:
                 output = output + self.bias
             return output.to(dtype=input.dtype)
-        if CUTLASS_FP8_AVAILABLE:
-            # cutlass FP8 supports per-token scales, so get non-scalar scales.
-            qinput, scale = fp8_quantize(
-                input, scale_upper_bound=self.scale_upper_bound, scalar=False
-            )
-            return quantization.cutlass_scaled_mm(
-                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
-            )
-
         qinput, scale = fp8_quantize(
             input,
@@ -470,4 +455,4 @@ def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size
 
     if scale.numel() > 1:
         scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
-    return scale.reshape(-1).expand(shape[0])
+    return scale.reshape(-1)
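The last hunk returns the flat scale instead of broadcasting a scalar across the output dimension; callers that reduce with `.max()` (as in the loader hunks above) no longer need the expanded copy. A tiny illustration of the difference:

import torch

scale = torch.tensor([0.02])  # scalar weight scale loaded from a checkpoint
out_features = 4

old = scale.reshape(-1).expand(out_features)  # shape (4,): the value repeated per row
new = scale.reshape(-1)                       # shape (1,): the scalar kept as-is
print(old.shape, new.shape)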