fixed scales loading

This commit is contained in:
OlivierDehaene 2024-07-22 13:56:12 +02:00
parent 119918cc0a
commit 74f1f6a702
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -76,7 +76,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
scale = weights.get_tensor(
f"{prefix}.weight_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -102,7 +104,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
# FP8 branch
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
)
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -128,7 +130,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
for p in prefixes
]
scale = torch.cat(scale, dim=0)
scale = torch.cat(scale, dim=0).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -144,7 +146,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
scale = weights.get_tensor(
f"{prefix}.weight_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=scale,