mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-02 23:42:06 +00:00
Fp8 support.
This commit is contained in:
parent
c31cb32dd6
commit
b24bdb9f8c
@ -47,7 +47,6 @@ enum Quantization {
|
||||
/// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
|
||||
/// perplexity performance for you model
|
||||
BitsandbytesFP4,
|
||||
/// [BETA]
|
||||
/// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
|
||||
/// This dtype has native ops should be the fastest if available.
|
||||
/// This is currently not the fastest because of local unpacking + padding to satisfy matrix
|
||||
|
@ -210,13 +210,8 @@ class Fp8Linear(nn.Module):
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
qinput, scale = fp8_quantize(input)
|
||||
seqlen = qinput.shape[0]
|
||||
if seqlen % 16 != 0:
|
||||
missing = 16 - seqlen % 16
|
||||
qinput = F.pad(qinput, (0, 0, 0, missing), "constant", value=0)
|
||||
output, _ = torch._scaled_mm(qinput, self.qweight.t(), out_dtype=self.dtype,
|
||||
scale_a=scale , scale_b=self.scale, bias=self.bias)
|
||||
output = output[:seqlen]
|
||||
return output
|
||||
|
||||
class Linear8bitLt(nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user