mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00

commit ccaf9ff507 (parent f478aa77ad)
Add support for scalar FP8 weight scales
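Some FP8 checkpoints store a single scalar weight_scale per tensor instead of one scale per output row. As the diff below shows, the loader now handles both: a scalar scale is expanded to w.shape[0] entries so Fp8Weight always carries a per-row scale vector, and get_linear() makes the scale contiguous because expand() yields a non-contiguous view.
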
@@ -87,9 +87,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
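
The key change is the trailing .expand(w.shape[0]): a checkpoint that stores a scalar scale yields a 1-element tensor after .reshape(-1), and expand broadcasts it to one scale per output row without copying (a per-row vector of size w.shape[0] passes through unchanged). A minimal sketch of that behavior in plain PyTorch, with illustrative values:

import torch

# Scalar per-tensor scale as loaded from such a checkpoint: numel() == 1.
scale = torch.tensor([0.05])
num_rows = 4096  # stands in for w.shape[0]

per_row = scale.reshape(-1).expand(num_rows)  # zero-copy, stride-0 view
assert per_row.shape == (num_rows,)
assert not per_row.is_contiguous()  # motivates the .contiguous() call further down
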
@@ -113,9 +115,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.numel() > 1:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
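
For packed QKV weights, the scale is now read in full first; only when it has more than one element (a true per-row scale) does it need the block-aware get_packed_sharded resharding, while a scalar applies uniformly to every row of every block. A toy illustration of why the scalar case can skip the block logic (this is not the real Weights API, just stand-in tensors):

import torch

# A fused QKV tensor packs Q, K and V rows together, so a per-row scale
# must be split with the same block sizes as the weight.
q_rows, k_rows, v_rows = 8, 4, 4
per_row_scale = torch.rand(q_rows + k_rows + v_rows)
q_s, k_s, v_s = per_row_scale.split([q_rows, k_rows, v_rows])

# A scalar scale already applies to every row of every block; just expand
# it to the packed size.
scalar_scale = torch.tensor([0.05])
packed = scalar_scale.reshape(-1).expand(q_rows + k_rows + v_rows)
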
@@ -132,16 +141,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
+
         # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)
 
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
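
shapes is captured before the weights are concatenated because each prefix's scale must be expanded to that prefix's own row count; only then does concatenating the scales line up with the concatenated weight rows. A small sketch of the alignment, using stand-in tensors and illustrative sizes:

import torch

# Two column-sharded weights about to be fused (e.g. gate and up projections).
w_a = torch.randn(8, 16)
w_b = torch.randn(8, 16)
shapes = [w_a.shape, w_b.shape]  # recorded before torch.cat discards them

# One scalar scale per tensor, as in such checkpoints.
scales = [torch.tensor([0.03]), torch.tensor([0.07])]
expanded = [s.reshape(-1).expand(shape[0]) for s, shape in zip(scales, shapes)]

scale = torch.cat(expanded, dim=0).reshape(-1)
assert scale.shape[0] == w_a.shape[0] + w_b.shape[0]  # one scale per fused row
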
@@ -157,9 +169,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,

@@ -182,6 +196,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
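
The comment in the diff is the whole story: expand() returns a stride-0 view, and a kernel that assumes densely laid-out memory would read it incorrectly without complaint. Making the scale contiguous materializes a dense copy first. A quick demonstration:

import torch

scale = torch.tensor([0.05]).expand(8)  # stride-0 view from a scalar
assert not scale.is_contiguous()

dense = scale.contiguous()  # materializes 8 real elements
assert dense.is_contiguous()
assert torch.equal(dense, torch.full((8,), 0.05))
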
@@ -256,3 +273,10 @@ class Fp8Linear(torch.nn.Module):
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.numel() > 1:
+        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    return scale.reshape(-1).expand(shape[0])
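
A quick check of the helper's two branches with a stand-in loader (FakeWeights is illustrative only; the real Weights reads from checkpoint shards, and get_sharded would return just this rank's slice):

import torch

class FakeWeights:
    """Stand-in for the Weights API; single-shard, dict-backed."""
    def __init__(self, tensors):
        self.tensors = tensors
    def get_tensor(self, name, to_dtype=False):
        return self.tensors[name]
    def get_sharded(self, name, dim=0, to_dtype=False):
        return self.tensors[name]  # one shard == the full tensor

weights = FakeWeights({
    "scalar.weight_scale": torch.tensor([0.02]),  # scalar branch
    "matrix.weight_scale": torch.rand(4),         # per-row branch
})
shape = torch.Size([4, 16])

for name in ("scalar.weight_scale", "matrix.weight_scale"):
    scale = weights.get_tensor(name, to_dtype=False)
    if scale.numel() > 1:
        scale = weights.get_sharded(name, dim=0, to_dtype=False)
    assert scale.reshape(-1).expand(shape[0]).shape == (4,)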