mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00
fix(l4): fix fp8 logic on l4
parent 6aeb669072
commit 3d0c7b85fe
@@ -42,8 +42,10 @@ def get_fp8_linear() -> torch.nn.Module:
     return Fp8Linear
 
 
-def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
-    if FBGEMM_DYN_AVAILABLE:
+def fp8_quantize(
+    weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
+):
+    if FBGEMM_DYN_AVAILABLE and not scalar:
         qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
             weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
         )
@@ -232,7 +234,7 @@ class Fp8Linear(torch.nn.Module):
             )
             return y.to(self.dtype)
 
-        qinput, scale = fp8_quantize(input)
+        qinput, scale = fp8_quantize(input, scalar=True)
         output, _ = torch._scaled_mm(
             qinput,
             self.qweight.t(),
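This appears to be the point of the fix: on an L4 (compute capability 8.9) this `forward` falls through to the `torch._scaled_mm` path, which in the torch release this diff targets accepts only per-tensor (0-dim) scales, while fbgemm's `quantize_fp8_per_row` hands back rowwise scales. Passing `scalar=True` keeps the activation quantization per-tensor. Below is a self-contained sketch of the call pattern, assuming an FP8-capable CUDA device; shapes and names are invented and this is not TGI code.

import torch

device = "cuda"
dtype = torch.bfloat16
finfo = torch.finfo(torch.float8_e4m3fn)

a = torch.randn(16, 32, device=device, dtype=dtype)  # activations
w = torch.randn(64, 32, device=device, dtype=dtype)  # weight, row-major

def quantize(t):
    # Per-tensor (scalar) scale, as _scaled_mm requires here.
    scale = finfo.max / t.abs().max().clamp(min=1e-12)
    q = (t * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    return q, scale.float().reciprocal()

qa, scale_a = quantize(a)
qw, scale_w = quantize(w)

# On the torch release this diff targets, _scaled_mm returns (output, amax);
# newer releases return the output tensor alone.
output, _ = torch._scaled_mm(
    qa,
    qw.t(),           # second operand must be column-major, hence .t()
    scale_a=scale_a,  # 0-dim per-tensor scales
    scale_b=scale_w,
    out_dtype=dtype,
)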