Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-26 12:32:10 +00:00)
Add support for scalar FP8 weight scales (#2550)
* Add support for scalar FP8 weight scales

* Support LLM compressor FP8 checkpoints on H100

  On H100, we use fbgemm-gpu, which requires bfloat16 as the input dtype.
  However, we would not pick up FP8 quantization for models quantized with
  LLM compressor. This change adds enough parsing to detect whether models
  have FP8-quantized weights.

* Remove stray debug print
This commit is contained in:

Parent: 68cfc94f40
Commit: 32d50c2ea7
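For illustration only (not part of the commit): the change boils down to the two points in the message above. FP8-quantized checkpoints, including LLM-compressor ones, are recognized purely by the weight dtype, and a per-tensor scalar weight scale is broadened to the per-output-row layout the FP8 path already uses. The tensor names and shapes below are hypothetical; this is a minimal sketch, not the loader's actual API.

import torch

# Hypothetical checkpoint entries for one linear layer (names are illustrative).
weight = torch.randn(4, 8).to(torch.float8_e4m3fn)  # FP8-quantized weight
weight_scale = torch.tensor(0.02)                   # per-tensor (scalar) scale

# FP8 checkpoints are detected by the weight dtype alone.
if weight.dtype == torch.float8_e4m3fn:
    # A scalar scale becomes one entry per output row, matching the
    # per-channel layout the rest of the FP8 path expects.
    scale = weight_scale.reshape(-1).expand(weight.shape[0])
    assert scale.shape == (weight.shape[0],)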
@@ -87,9 +87,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
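A minimal sketch of what the new reshape/expand chain does (illustrative values, not from the repository): reshape(-1) flattens either a scalar or a per-channel scale to 1-D, and expand() broadcasts the scalar to one entry per row as a stride-0 view without copying, while being a no-op for a vector that already has the right length.

import torch

num_rows = 4096
scalar = torch.tensor(0.02)      # per-tensor scale
vector = torch.rand(num_rows)    # per-output-channel scale

for stored in (scalar, vector):
    # Same shape either way; the scalar case is a broadcasted view, no copy.
    scale = stored.reshape(-1).expand(num_rows)
    assert scale.shape == (num_rows,)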
@@ -113,9 +115,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.numel() > 1:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
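The numel() check above decides whether the scale needs the same packed sharding as the weight. A rough stand-in for that logic, assuming a fused projection split over two ranks (this simplification ignores the block_sizes handling that get_packed_sharded performs for interleaved packed tensors):

import torch

num_rows, world_size = 6, 2
per_channel_scale = torch.rand(num_rows)  # one scale per output row
scalar_scale = torch.tensor(0.02)         # one scale for the whole tensor

def shard_scale(scale: torch.Tensor, rank: int) -> torch.Tensor:
    # Simplified: a per-channel scale must be sliced like the weight,
    # a scalar can simply be reused on every rank.
    if scale.numel() > 1:
        scale = scale.chunk(world_size, dim=0)[rank]
    rows_per_rank = num_rows // world_size
    return scale.reshape(-1).expand(rows_per_rank)

assert shard_scale(per_channel_scale, rank=0).shape == (3,)
assert shard_scale(scalar_scale, rank=1).shape == (3,)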
@@ -132,16 +141,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
+
         # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)
 
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
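The shapes list is captured before concatenation so that each prefix's scale, possibly a scalar, can be expanded to that shard's row count, keeping the resulting scale vector aligned with the concatenated weight rows. A small sketch with made-up q/k/v shard sizes:

import torch

shard_rows = [8, 4, 4]                                   # rows per prefix on this rank
scales = [torch.tensor(s) for s in (0.01, 0.02, 0.03)]   # scalar scale per prefix

expanded = [s.reshape(-1).expand(rows) for s, rows in zip(scales, shard_rows)]
scale = torch.cat(expanded, dim=0).reshape(-1)
assert scale.shape == (sum(shard_rows),)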
@@ -157,9 +169,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -182,6 +196,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
        if self.weight_scale is None:
            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
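The contiguous() call matters because expand() produces a stride-0 view rather than real memory. A quick check with plain PyTorch, independent of the loader:

import torch

scale = torch.tensor(0.02).reshape(-1).expand(4096)
assert not scale.is_contiguous() and scale.stride() == (0,)

scale = scale.contiguous()   # materializes the 4096 values
assert scale.is_contiguous() and scale.stride() == (1,)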
@@ -222,6 +239,9 @@ class Fp8Linear(torch.nn.Module):
 
     @classmethod
     def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        if FBGEMM_DYN_AVAILABLE:
+            # fbgemm needs float32 scales.
+            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
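A small illustration of the scale upcast, assuming a checkpoint that stores the scale in bfloat16:

import torch

scale = torch.tensor([0.02], dtype=torch.bfloat16)
scale = scale.float()          # fbgemm-side kernels expect float32 scales
assert scale.dtype == torch.float32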
@@ -256,3 +276,10 @@ class Fp8Linear(torch.nn.Module):
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.numel() > 1:
+        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    return scale.reshape(-1).expand(shape[0])
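For intuition only (this reference dequantization is not in the diff and is not how the fbgemm or _scaled_mm kernels consume the scale): after loading, the scale always has one entry per output row, so a reference dequantization is just a row-wise multiply, whether the checkpoint stored a scalar or a per-channel vector.

import torch

w = torch.randn(4, 8).to(torch.float8_e4m3fn)
scalar_scale = torch.tensor(0.02)

# One scale entry per output row after loading.
row_scale = scalar_scale.reshape(-1).expand(w.shape[0]).contiguous().float()
w_dequant = w.float() * row_scale.unsqueeze(1)
assert w_dequant.shape == w.shape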