diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index ed5114ce..bf5a0989 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -76,7 +76,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -102,7 +104,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
             # FP8 branch
             scale = weights.get_packed_sharded(
                 f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            )
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -128,7 +130,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
                 for p in prefixes
             ]
-            scale = torch.cat(scale, dim=0)
+            scale = torch.cat(scale, dim=0).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -144,7 +146,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
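
A note on why the repeated .reshape(-1) likely matters (my reading, not stated in the patch itself): per-tensor FP8 scales are commonly serialized as 0-dim scalar tensors, while per-channel scales are 1-D, and downstream code that concatenates or indexes weight_scale needs one consistent shape. The sketch below demonstrates the failure mode under that assumption; the tensor values are made up for illustration.

import torch

# A per-tensor scale often loads as a 0-dim tensor; per-channel scales are 1-D.
scalar_scale = torch.tensor(0.5)           # shape: torch.Size([])
channel_scale = torch.tensor([0.5, 0.25])  # shape: torch.Size([2])

# torch.cat rejects 0-dim inputs, so mixing the two shapes fails:
# torch.cat([scalar_scale, channel_scale])
# -> RuntimeError: zero-dimensional tensor ... cannot be concatenated

# reshape(-1) normalizes both cases to 1-D, matching what the diff does:
print(scalar_scale.reshape(-1).shape)   # torch.Size([1])
print(channel_scale.reshape(-1).shape)  # torch.Size([2])
print(torch.cat([scalar_scale.reshape(-1), channel_scale.reshape(-1)]))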