Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 12:24:53 +00:00
(feat) convert scales to tensorwise
commit e2454dba40
parent b57f370386
@@ -79,6 +79,32 @@ def normalize_e4m3fn_to_e4m3fnuz(
     return weight, weight_scale, input_scale


+def per_tensor_dequantize(
+    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
+) -> torch.Tensor:
+    fake_qweight = tensor.to(torch.float16)
+    dq_weight = fake_qweight * inv_scale
+    return dq_weight
+
+
+def requantize_with_max_scale(
+    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Max scale to be used for requantization.
+    max_w_scale = weight_scale.max()
+
+    start = 0
+    for idx, logical_width in enumerate(logical_widths):
+        end = start + logical_width
+        weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
+        weight[start:end, :], max_w_scale_normalized = fp8_quantize(
+            weight_dq, max_w_scale
+        )
+        start = end
+
+    return weight, max_w_scale_normalized
+
+
 def fp8_quantize(
     weight: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
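For context, a minimal, self-contained sketch of the logic these new helpers implement. It is not TGI code: quantize_fp8 is a hypothetical stand-in for fp8_quantize when a scale is supplied, the values are illustrative, and it assumes a PyTorch build with float8_e4m3fn support.

from typing import List

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def quantize_fp8(weight, scale):
    # Stand-in for fp8_quantize with a fixed scale: scale down, clamp to
    # the representable range, cast to fp8.
    qweight = (weight / scale).clamp(min=-FP8_MAX, max=FP8_MAX)
    return qweight.to(torch.float8_e4m3fn), scale


def requantize_demo(weight, weight_scale, logical_widths: List[int]):
    # Mirrors requantize_with_max_scale above: dequantize each fused slice
    # with its own scale, then requantize everything with the shared max.
    max_w_scale = weight_scale.max()
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_dq = weight[start:end, :].to(torch.float16) * weight_scale[idx]
        weight[start:end, :], _ = quantize_fp8(weight_dq, max_w_scale)
        start = end
    return weight, max_w_scale


# Two fused 4-row weights quantized with different per-tensor scales end up
# sharing one tensorwise scale (their max, 0.03).
w = torch.randn(8, 16)
scales = torch.tensor([0.01, 0.03])
qw = torch.empty(8, 16, dtype=torch.float8_e4m3fn)
qw[:4] = quantize_fp8(w[:4], scales[0])[0]
qw[4:] = quantize_fp8(w[4:], scales[1])[0]
qw, shared_scale = requantize_demo(qw, scales, [4, 4])
print(shared_scale)  # tensor(0.0300)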
@@ -145,13 +171,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                 .reshape(-1)
                 .expand(w.shape[0])
-            )
+            ).max()

             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
-                input_scale = weights.get_tensor(
-                    f"{prefix}.input_scale", to_dtype=False
-                ).reshape(-1)
+                input_scale = (
+                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+                    .reshape(-1)
+                    .max()
+                )

             return Fp8Weight(
                 weight=w,
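The added .max() calls are the heart of the change: scales that the checkpoint stores per output row, or that the loader expands to per-row, collapse to a single tensorwise scalar. A tiny illustration with made-up values:

import torch

stored = torch.tensor([0.02])            # scalar scale from the checkpoint
per_row = stored.reshape(-1).expand(4)   # per-row view: [0.02, 0.02, 0.02, 0.02]
tensorwise = per_row.max()               # tensor(0.0200), a single 0-dim scalar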
@@ -185,7 +213,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 block_sizes=block_sizes,
                 to_dtype=False,
             )
-            scale = scale.reshape(-1).expand(w.shape[0])
+            scale = scale.reshape(-1).expand(w.shape[0]).max()

             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@@ -227,9 +255,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                .max()
+                .unsqueeze(0)
                 for p, shape in zip(prefixes, shapes)
             ]
-            scale = torch.cat(scale, dim=0).reshape(-1)
+            scale = torch.cat(scale).to(weights.device)
+
+            logical_widths = [x[0] for x in shapes]
+
+            w, scale = requantize_with_max_scale(w, scale, logical_widths)

             input_scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
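In this multi-weight path, each prefix now contributes a single scalar (its max scale), and the concatenated 1-D tensor hands requantize_with_max_scale one entry per logical weight. Roughly, with stand-in tensors for whatever _load_scalar_or_matrix_scale returns:

import torch

loaded = [torch.tensor([0.01, 0.01]), torch.tensor(0.03)]  # one per fused prefix
scale = torch.cat([s.max().unsqueeze(0) for s in loaded])
print(scale)  # tensor([0.0100, 0.0300]) -- one entry per logical weight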
@@ -263,12 +297,14 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                 .reshape(-1)
                 .expand(w.shape[0])
-            )
+            ).max()
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
-                input_scale = weights.get_tensor(
-                    f"{prefix}.input_scale", to_dtype=False
-                ).reshape(-1)
+                input_scale = (
+                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+                    .reshape(-1)
+                    .max()
+                )

             return Fp8Weight(
                 weight=w,