Add support for scalar FP8 weight scales

This commit is contained in:
Daniël de Kok 2024-09-23 15:46:41 +00:00
parent f478aa77ad
commit ccaf9ff507

View File

@ -87,9 +87,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(
f"{prefix}.weight_scale", to_dtype=False
).reshape(-1)
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -113,9 +115,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if scale.numel() > 1:
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
).reshape(-1)
f"{prefix}.weight_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
scale = scale.reshape(-1).expand(w.shape[0])
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -132,16 +141,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = [
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = [
weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
for p in prefixes
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
scale = torch.cat(scale, dim=0).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -157,9 +169,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = weights.get_tensor(
f"{prefix}.weight_scale", to_dtype=False
).reshape(-1)
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -182,6 +196,9 @@ class Fp8Weight(Weight):
def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None:
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
# This is not checked by the fbgemm kernels, but they require contiguous
# memory. Can be non-contiguous when we e.g. expand from scalars.
self.weight_scale = self.weight_scale.contiguous()
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
)
@ -256,3 +273,10 @@ class Fp8Linear(torch.nn.Module):
bias=self.bias,
)
return output
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
scale = weights.get_tensor(prefix, to_dtype=False)
if scale.numel() > 1:
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
return scale.reshape(-1).expand(shape[0])