diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 86394ff7d..40ee55d76 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -66,6 +66,7 @@ Options: - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model + - fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above This dtype has native ops should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations ``` ## SPECULATE diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 836b03813..cf876fbdd 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -47,6 +47,11 @@ enum Quantization { /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better /// perplexity performance for you model BitsandbytesFP4, + /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above + /// This dtype has native ops should be the fastest if available. + /// This is currently not the fastest because of local unpacking + padding to satisfy matrix + /// multiplication limitations. + Fp8, } impl std::fmt::Display for Quantization { @@ -73,6 +78,9 @@ impl std::fmt::Display for Quantization { Quantization::Eetq => { write!(f, "eetq") } + Quantization::Fp8 => { + write!(f, "fp8") + } } } } diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index e8b126d92..bb0963d40 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -19,6 +19,7 @@ class Quantization(str, Enum): gptq = "gptq" awq = "awq" eetq = "eetq" + fp8 = "fp8" class Dtype(str, Enum): diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index f29e55c56..2b95bc743 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -182,6 +182,48 @@ class EETQLinear(nn.Module): return output +def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): + device = weight.device + # weight, scale = quant_weights(weight, torch.int8, False) + finfo = torch.finfo(qdtype) + # Calculate the scale as dtype max divided by absmax + scale = finfo.max / weight.abs().max().clamp(min=1e-12) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as default cast is unsaturated) + qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) + # Return both float8 data and the inverse scale (as float), + # as both required as inputs to torch._scaled_mm + qweight = qweight.to(qdtype) + scale = scale.float().reciprocal() + return qweight, scale + + +class Fp8Linear(nn.Module): + def __init__( + self, + weight, + bias, + ) -> None: + super().__init__() + self.dtype = weight.dtype + self.qweight, self.scale = fp8_quantize(weight) + + self.bias = bias if bias is not None else None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + qinput, scale = fp8_quantize(input) + output, _ = torch._scaled_mm( + qinput, + self.qweight.t(), + out_dtype=self.dtype, + scale_a=scale, + scale_b=self.scale, + bias=self.bias, + ) + return output + + class Linear8bitLt(nn.Module): def __init__( self, @@ -293,6 +335,8 @@ def get_linear(weight, bias, quantize): raise ImportError( "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" ) + elif quantize == "fp8": + linear = Fp8Linear(weight, bias) elif quantize == "bitsandbytes": warn_deprecate_bnb() linear = Linear8bitLt(