mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-03 16:22:06 +00:00
Fp8 support.
This commit is contained in:
parent
c31cb32dd6
commit
b24bdb9f8c
@ -47,7 +47,6 @@ enum Quantization {
|
|||||||
/// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
|
/// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
|
||||||
/// perplexity performance for your model
|
/// perplexity performance for your model
|
||||||
BitsandbytesFP4,
|
BitsandbytesFP4,
|
||||||
/// [BETA]
|
|
||||||
/// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
|
/// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
|
||||||
/// This dtype has native ops and should be the fastest if available.
|
/// This dtype has native ops and should be the fastest if available.
|
||||||
/// This is currently not the fastest because of local unpacking + padding to satisfy matrix
|
/// This is currently not the fastest because of local unpacking + padding to satisfy matrix
|
||||||
|
@ -210,13 +210,8 @@ class Fp8Linear(nn.Module):
|
|||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Run the FP8 linear layer.

    Quantizes the incoming activation to FP8, pads the row (sequence)
    dimension up to a multiple of 16 — presumably an alignment
    requirement of ``torch._scaled_mm``; confirm against the torch
    docs — performs the scaled matmul against the pre-quantized
    weight, and strips the padding before returning.
    """
    qinput, scale = fp8_quantize(input)
    # Remember the true row count so any padding can be removed after
    # the matmul.
    rows = qinput.shape[0]
    remainder = rows % 16
    if remainder:
        pad_rows = 16 - remainder
        qinput = F.pad(qinput, (0, 0, 0, pad_rows), "constant", value=0)
    # NOTE(review): _scaled_mm is a private torch API; it returns a
    # (tensor, amax) pair here, of which only the tensor is used.
    output, _ = torch._scaled_mm(
        qinput,
        self.qweight.t(),
        out_dtype=self.dtype,
        scale_a=scale,
        scale_b=self.scale,
        bias=self.bias,
    )
    # Drop padded rows so callers see the original sequence length.
    return output[:rows]
|
||||||
|
|
||||||
class Linear8bitLt(nn.Module):
|
class Linear8bitLt(nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user