Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
fixed scales loading

parent 119918cc0a
commit 74f1f6a702
@@ -76,7 +76,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -102,7 +104,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
             # FP8 branch
             scale = weights.get_packed_sharded(
                 f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            )
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -128,7 +130,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
                 for p in prefixes
             ]
-            scale = torch.cat(scale, dim=0)
+            scale = torch.cat(scale, dim=0).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -144,7 +146,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
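
Note (not part of the commit): across all four hunks, every weight_scale returned by HybridFP8UnquantLoader is now flattened with .reshape(-1) before being wrapped in Fp8Weight. The sketch below is a minimal, standalone illustration of that shape normalization; the scale values and variable names are invented for the example, and it assumes per-tensor scales may be stored as 0-d scalars while sharded scales are already 1-D.

    # Illustrative sketch only, not code from the repository.
    import torch

    per_tensor_scale = torch.tensor(0.05)        # assumed: per-tensor scale stored as a 0-d scalar
    sharded_scales = torch.tensor([0.05, 0.07])  # assumed: per-shard scales, already 1-D

    # torch.cat rejects 0-d tensors, so concatenating unflattened scales would raise:
    #   torch.cat([per_tensor_scale, sharded_scales], dim=0)  -> RuntimeError
    # Flattening first gives every scale a consistent 1-D shape, mirroring the
    # get_tensor(...).reshape(-1) and torch.cat(scale, dim=0).reshape(-1) calls in the diff.
    flat = torch.cat([per_tensor_scale.reshape(-1), sharded_scales.reshape(-1)], dim=0)
    print(flat.shape)  # torch.Size([3])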