Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-27 13:02:12 +00:00)

Initial fp8.

Commit 50d5a3c11e (parent 30620a9a44)

@@ -47,6 +47,9 @@ enum Quantization {
     /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
     /// perplexity performance for your model
     BitsandbytesFP4,
+    /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
+    /// This dtype has native ops and should be the fastest if available.
+    Fp8,
 }
 
 impl std::fmt::Display for Quantization {
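
For context on the new doc comment: the e4m3 float8 format referenced here can only represent values up to roughly ±448, which is why the Python side below rescales tensors before casting. A minimal way to inspect that range (assumes PyTorch 2.1+ for the float8 dtype; not part of the commit):

import torch

# The e4m3 format targeted by the new Fp8 variant: 8 bits total,
# representable range about [-448, 448], so larger values must be rescaled.
finfo = torch.finfo(torch.float8_e4m3fn)
print(finfo.bits, finfo.max, finfo.min)  # 8 448.0 -448.0
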
@@ -73,6 +76,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Eetq => {
                 write!(f, "eetq")
             }
+            Quantization::Fp8 => {
+                write!(f, "fp8")
+            }
         }
     }
 }

@@ -19,6 +19,7 @@ class Quantization(str, Enum):
     gptq = "gptq"
     awq = "awq"
     eetq = "eetq"
+    fp8 = "fp8"
 
 
 class Dtype(str, Enum):
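
Since Quantization subclasses str, the literal value coming from the command line maps directly onto the new member. A small stand-in sketch (only the members visible in this hunk; the real class lives in the TGI server package):

from enum import Enum

# Stand-in mirroring the enum after this change.
class Quantization(str, Enum):
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"
    fp8 = "fp8"

assert Quantization("fp8") is Quantization.fp8  # lookup from the raw CLI string
assert Quantization.fp8 == "fp8"                # str subclass compares equal to its value

Together with the Rust enum above, this is presumably what lets the new option be selected as --quantize fp8, though the launch command itself is not part of this diff.
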
@@ -181,6 +181,40 @@ class EETQLinear(nn.Module):
         output = output + self.bias if self.bias is not None else output
         return output
 
+class Fp8Linear(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        device = weight.device
+        # weight, scale = quant_weights(weight, torch.int8, False)
+        finfo = torch.finfo(weight.dtype)
+        qdtype = torch.float8_e4m3fn
+        # Calculate the scale as dtype max divided by absmax
+        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
+        # Scale and clamp the tensor to bring it to
+        # the representative range of the float8 data type
+        # (as the default cast is unsaturated)
+        x_scl_sat = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+        # Keep both the float8 data and the inverse scale (as float),
+        # as both are required as inputs to torch._scaled_mm
+        self.dtype = weight.dtype
+        self.qweight = x_scl_sat.to(qdtype).to(device=device)
+        self.scale = scale.float().reciprocal().to(device=device)
+        self.bias = bias.cuda(device) if bias is not None else None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        finfo = torch.finfo(input.dtype)
+        scale = finfo.max / input.abs().max().clamp(min=1e-12)
+        qinput = (input * scale).clamp(min=finfo.min, max=finfo.max)
+
+        output, _ = torch._scaled_mm(qinput, self.qweight, out_dtype=torch.float16,
+                                     scale_a=scale, scale_b=self.scale)
+        output = output + self.bias if self.bias is not None else output
+        return output
+
 
 class Linear8bitLt(nn.Module):
     def __init__(
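
To make the scaling scheme in Fp8Linear concrete, here is a hedged round-trip sketch of the same per-tensor absmax quantization, runnable on CPU with PyTorch 2.1+. One deliberate difference: the scale below is derived from the e4m3 limits (torch.finfo(torch.float8_e4m3fn)), whereas the commit reads torch.finfo(weight.dtype); the actual matmul (torch._scaled_mm) needs fp8-capable hardware such as an H100 and is not exercised here.

import torch

def fp8_round_trip(w: torch.Tensor):
    # Per-tensor scale: fp8 max divided by the tensor's absmax (clamped to avoid div by zero).
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = finfo.max / w.abs().max().clamp(min=1e-12)
    # Saturate into the representable e4m3 range before the cast.
    qw = (w * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    # The inverse scale is what Fp8Linear stores in self.scale for torch._scaled_mm.
    inv_scale = scale.float().reciprocal()
    return qw, inv_scale

w = torch.randn(256, 256)                 # float32 stand-in for a weight shard
qw, inv_scale = fp8_round_trip(w)
w_hat = qw.to(torch.float32) * inv_scale  # dequantize for inspection
print((w - w_hat).abs().max())            # worst-case rounding error of per-tensor fp8
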
@@ -293,6 +327,12 @@ def get_linear(weight, bias, quantize):
             raise ImportError(
                 "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
             )
+    elif quantize == "fp8":
+        linear = Fp8Linear(weight, bias)
+        else:
+            raise ImportError(
+                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
+            )
     elif quantize == "bitsandbytes":
         warn_deprecate_bnb()
         linear = Linear8bitLt(