mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: fp8 dimensions size
fp8 quantization currently limited to tensors with shapes where both dimensions are divisible by 16.
This commit is contained in:
parent
c38a7d7ddd
commit
c07f54aac2
@ -212,6 +212,8 @@ class Fp8Linear(nn.Module):
|
|||||||
self.bias = bias if bias is not None else None
|
self.bias = bias if bias is not None else None
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
if (bsz := input.shape[0]) & 15:
|
||||||
|
input = F.pad(input,(0, 0, 0, 16 - (bsz & 15)))
|
||||||
qinput, scale = fp8_quantize(input)
|
qinput, scale = fp8_quantize(input)
|
||||||
output, _ = torch._scaled_mm(
|
output, _ = torch._scaled_mm(
|
||||||
qinput,
|
qinput,
|
||||||
@ -221,7 +223,7 @@ class Fp8Linear(nn.Module):
|
|||||||
scale_b=self.scale,
|
scale_b=self.scale,
|
||||||
bias=self.bias,
|
bias=self.bias,
|
||||||
)
|
)
|
||||||
return output
|
return output[:bsz]
|
||||||
|
|
||||||
|
|
||||||
class Linear8bitLt(nn.Module):
|
class Linear8bitLt(nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user