diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 96d5f4a3..d2c46c58 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -203,7 +203,7 @@ class Fp8Linear(torch.nn.Module):
     @classmethod
     def from_unquant(cls, weight, bias, dtype):
-        qweight, scale = fp8_quantize(weight)
+        qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
         return cls(
             qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
         )
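
For context on what the new `scalar=not FBGEMM_MM_AVAILABLE` argument does: it asks `fp8_quantize` to fall back to a single per-tensor scale when the FBGEMM row-wise FP8 kernels are not available. Below is a minimal sketch of what such a scalar fallback can look like; the helper name `fp8_quantize_scalar`, the choice of `torch.float8_e4m3fn`, and the return convention (FP8 weight plus a dequantization scale) are illustrative assumptions, not the repository's actual `fp8_quantize` implementation.

```python
import torch

# Assumption: e4m3fn is the FP8 format in use; its representable max is ~448.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def fp8_quantize_scalar(weight: torch.Tensor):
    """Per-tensor ("scalar" scale) FP8 quantization sketch.

    Hypothetical stand-in for a non-FBGEMM path: one scale for the whole
    tensor instead of one scale per row.
    """
    # Map the largest absolute value onto the FP8 dynamic range.
    scale = FP8_MAX / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    # Return the dequantization scale (reciprocal), so that
    # qweight.to(weight.dtype) * returned_scale roughly reconstructs weight.
    return qweight, scale.reciprocal().float()


if __name__ == "__main__":
    w = torch.randn(16, 32, dtype=torch.float16)
    qw, s = fp8_quantize_scalar(w)
    print(qw.dtype, s.item())
```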