Add support for scalar FP8 weight scales

Daniël de Kok 2024-09-23 15:46:41 +00:00
parent f478aa77ad
commit ccaf9ff507


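FP8 checkpoints may store weight_scale either as a single per-tensor scalar or as a per-row (per-channel) vector. This change lets the loader accept both by expanding a scalar scale to one entry per output row, so downstream code always sees a per-row scale. A minimal sketch of the expand pattern, with made-up shapes and values (not taken from the diff):

    import torch

    # Hypothetical weight with 4096 output rows, quantized to FP8.
    w = torch.randn(4096, 512).to(torch.float8_e4m3fn)
    # A scalar (per-tensor) scale, as some checkpoints provide it.
    scalar_scale = torch.tensor(0.025)
    # Expand to one scale per output row; this is a zero-copy view.
    row_scale = scalar_scale.reshape(-1).expand(w.shape[0])
    print(row_scale.shape)  # torch.Size([4096])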
@@ -87,9 +87,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -113,9 +115,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.numel() > 1:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            scale = scale.reshape(-1).expand(w.shape[0])
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -132,16 +141,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
         # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
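In the multi-prefix (e.g. fused q/k/v) path above, the per-prefix shapes are recorded before concatenation so that a scalar scale can be expanded to each shard's own row count; after torch.cat the scale then lines up row-for-row with the concatenated weight. A rough sketch with assumed shapes (not from the diff):

    import torch

    q = torch.empty(1024, 512)
    k = torch.empty(256, 512)
    v = torch.empty(256, 512)
    shapes = [t.shape for t in (q, k, v)]
    # One scalar scale per projection, expanded to that projection's row count.
    scales = [torch.tensor(0.02), torch.tensor(0.03), torch.tensor(0.01)]
    expanded = [s.reshape(-1).expand(shape[0]) for s, shape in zip(scales, shapes)]
    scale = torch.cat(expanded, dim=0)
    print(scale.shape)  # torch.Size([1536]), matching 1024 + 256 + 256 weight rows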
@@ -157,9 +169,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -182,6 +196,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
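The contiguity comment above matters because expand() returns a broadcasted view with stride 0 rather than a real copy, and the fbgemm FP8 kernels assume contiguous memory without checking it. A small sketch of the behaviour (values are arbitrary):

    import torch

    scale = torch.tensor([0.02]).expand(4096)  # zero-stride view over one element
    print(scale.is_contiguous())               # False
    scale = scale.contiguous()                 # materializes a real 4096-element tensor
    print(scale.is_contiguous())               # True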
@@ -256,3 +273,10 @@ class Fp8Linear(torch.nn.Module):
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.numel() > 1:
+        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    return scale.reshape(-1).expand(shape[0])
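The new helper loads the scale whole when it is a scalar (every rank sees the same value) and shards it along dim 0 when it is per-row, mirroring how the weight itself is split. A standalone sketch of that dispatch, using a plain chunk() in place of the loader's get_sharded (hypothetical stand-ins, not the project's API):

    import torch

    def load_scale_sketch(full_scale, rank, world_size, rows):
        if full_scale.numel() > 1:
            # Per-row scale: take this rank's slice, like get_sharded along dim 0.
            shard = full_scale.chunk(world_size, dim=0)[rank]
        else:
            # Scalar scale: replicated on every rank, like get_tensor.
            shard = full_scale
        return shard.reshape(-1).expand(rows)

    print(load_scale_sketch(torch.tensor(0.02), rank=1, world_size=2, rows=2048).shape)
    print(load_scale_sketch(torch.rand(4096), rank=1, world_size=2, rows=2048).shape)
    # Both print torch.Size([2048]).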