From 935d56abfe1de9402a441a88f5afe21ed77ebf6f Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 12 Apr 2024 08:13:30 +0200
Subject: [PATCH] Fp8 Support (#1726)

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

---------

Co-authored-by: Dong Shin
---
 docs/source/basic_tutorials/launcher.md       |  1 +
 launcher/src/main.rs                          |  8 ++++
 server/text_generation_server/cli.py          |  3 ++
 server/text_generation_server/utils/layers.py | 44 +++++++++++++++++++
 4 files changed, 56 insertions(+)

diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index 86394ff7d..40ee55d76 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -66,6 +66,7 @@ Options:
 - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
 - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for your model
+ - fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above. This dtype has native ops and should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations
 ```

 ## SPECULATE
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 309e2cb71..f25debe84 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -49,6 +49,11 @@ enum Quantization {
     /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
     /// perplexity performance for your model
     BitsandbytesFP4,
+    /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above.
+    /// This dtype has native ops and should be the fastest if available.
+    /// This is currently not the fastest because of local unpacking + padding to satisfy matrix
+    /// multiplication limitations.
+    Fp8,
 }

 impl std::fmt::Display for Quantization {
@@ -75,6 +80,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Eetq => {
                 write!(f, "eetq")
             }
+            Quantization::Fp8 => {
+                write!(f, "fp8")
+            }
         }
     }
 }
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 4a308e08a..990c31be5 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -19,6 +19,9 @@ app = typer.Typer()
 class Quantization(str, Enum):
     bitsandbytes = "bitsandbytes"
     gptq = "gptq"
+    awq = "awq"
+    eetq = "eetq"
+    fp8 = "fp8"


 class Dtype(str, Enum):
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index f29e55c56..2b95bc743 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -182,6 +182,48 @@ class EETQLinear(nn.Module):
         return output


+def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
+    device = weight.device
+    # weight, scale = quant_weights(weight, torch.int8, False)
+    finfo = torch.finfo(qdtype)
+    # Calculate the scale as dtype max divided by absmax
+    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
+    # scale and clamp the tensor to bring it to
+    # the representative range of float8 data type
+    # (as default cast is unsaturated)
+    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+    # Return both float8 data and the inverse scale (as float),
+    # as both required as inputs to torch._scaled_mm
+    qweight = qweight.to(qdtype)
+    scale = scale.float().reciprocal()
+    return qweight, scale
+
+
+class Fp8Linear(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        self.dtype = weight.dtype
+        self.qweight, self.scale = fp8_quantize(weight)
+
+        self.bias = bias if bias is not None else None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        qinput, scale = fp8_quantize(input)
+        output, _ = torch._scaled_mm(
+            qinput,
+            self.qweight.t(),
+            out_dtype=self.dtype,
+            scale_a=scale,
+            scale_b=self.scale,
+            bias=self.bias,
+        )
+        return output
+
+
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
@@ -293,6 +335,8 @@ def get_linear(weight, bias, quantize):
             raise ImportError(
                 "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
             )
+    elif quantize == "fp8":
+        linear = Fp8Linear(weight, bias)
     elif quantize == "bitsandbytes":
         warn_deprecate_bnb()
         linear = Linear8bitLt(
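
Note (reviewer sketch, not part of the diff above): the FP8 path boils down to dynamic per-tensor quantization: scale by `finfo.max / absmax`, saturate, cast to `torch.float8_e4m3fn`, and pass the inverse scales to `torch._scaled_mm`, which `Fp8Linear.forward` does for both the stored weight and each incoming activation. The standalone sketch below exercises that same recipe outside the server. It assumes an H100-class GPU and a PyTorch build where the private `torch._scaled_mm` accepts `scale_a`/`scale_b` keywords and returns an `(output, amax)` tuple, as this patch relies on; that API has changed across releases, and the shapes are chosen as multiples of 16 to satisfy its size constraints.

```python
# Minimal sketch (not part of this patch): dynamic per-tensor FP8 quantization
# followed by a scaled matmul, mirroring fp8_quantize / Fp8Linear above.
# Assumes a CUDA device with FP8 support (H100+) and a torch._scaled_mm that
# returns an (output, amax) tuple -- this private API varies between releases.
import torch


def fp8_quantize(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    # Per-tensor scale: dtype max over the tensor's absmax (clamped to avoid /0).
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Saturate before casting, since the default FP8 cast is unsaturated.
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # torch._scaled_mm expects the inverse (dequantization) scale.
    return qweight, scale.float().reciprocal()


if __name__ == "__main__":
    torch.manual_seed(0)
    # _scaled_mm requires matrix dimensions that are multiples of 16.
    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")    # activations
    w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")  # [out, in] weight

    qx, x_inv_scale = fp8_quantize(x)
    qw, w_inv_scale = fp8_quantize(w)

    # The weight is passed transposed so the second operand is column-major,
    # exactly as Fp8Linear.forward does with self.qweight.t().
    y_fp8, _ = torch._scaled_mm(
        qx,
        qw.t(),
        out_dtype=torch.float16,
        scale_a=x_inv_scale,
        scale_b=w_inv_scale,
    )
    y_ref = x @ w.t()
    print("max abs error vs fp16 matmul:", (y_fp8 - y_ref).abs().max().item())
```

Returning `scale.float().reciprocal()` rather than the forward scale is deliberate: `_scaled_mm` folds the two scales back into the accumulator, so they act as dequantization factors for the FP8 operands.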