mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: fp8 dimensions size
fp8 quantization currently limited to tensors with shapes where both dimensions are divisible by 16.
This commit is contained in:
parent
c38a7d7ddd
commit
c07f54aac2
@ -212,6 +212,8 @@ class Fp8Linear(nn.Module):
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if (bsz := input.shape[0]) & 15:
|
||||
input = F.pad(input,(0, 0, 0, 16 - (bsz & 15)))
|
||||
qinput, scale = fp8_quantize(input)
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
@ -221,7 +223,7 @@ class Fp8Linear(nn.Module):
|
||||
scale_b=self.scale,
|
||||
bias=self.bias,
|
||||
)
|
||||
return output
|
||||
return output[:bsz]
|
||||
|
||||
|
||||
class Linear8bitLt(nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user