Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-12 04:44:52 +00:00

Commit: 7a7cd5f299
Parent: f7728565b1
Message: (review comments) Fix compression_config load, type hints
@@ -123,12 +123,12 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
-            try:
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
                 input_scale = weights.get_tensor(
                     f"{prefix}.input_scale", to_dtype=False
                 ).reshape(-1)
-            except Exception:
-                input_scale = None

             return Fp8Weight(
                 weight=w,
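Across these loader hunks the pattern is the same: instead of wrapping the optional input_scale load in a try/except, the code defaults to None and loads only when weights.has_tensor confirms the tensor exists. A minimal sketch of that control flow, assuming only a weights object exposing has_tensor/get_tensor as the diff's Weights class does:

from typing import Optional

import torch


def load_optional_input_scale(weights, prefix: str) -> Optional[torch.Tensor]:
    # Mirrors the new control flow: default to None, load only when present.
    input_scale = None
    if weights.has_tensor(f"{prefix}.input_scale"):
        input_scale = weights.get_tensor(f"{prefix}.input_scale").reshape(-1)
    return input_scale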
@@ -163,7 +163,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     to_dtype=False,
                 )
             scale = scale.reshape(-1).expand(w.shape[0])
-            try:
+            input_scale = None
+            if weights.get_tensor(f"{prefix}.input_scale"):
                 input_scale = weights.get_tensor(
                     f"{prefix}.input_scale", to_dtype=False
                 )
@@ -175,8 +177,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                         to_dtype=False,
                     )
                 input_scale = input_scale.reshape(-1).max()
-            except Exception:
-                input_scale = None

             return Fp8Weight(
                 weight=w,
@@ -207,14 +207,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
-            try:
-                input_scale = [
-                    _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
-                    for p, shape in zip(prefixes, shapes)
-                ]
-                input_scale = torch.cat(input_scale, dim=0).reshape(-1).max()
-            except Exception:
-                input_scale = None
+            input_scale = [
+                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
+                for p, shape in zip(prefixes, shapes)
+                if weights.has_tensor(f"{p}.input_scale")
+            ]
+            input_scale = (
+                torch.cat(input_scale, dim=0).reshape(-1).max()
+                if len(input_scale) != 0
+                else None
+            )

             return Fp8Weight(
                 weight=w,
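For the multi-prefix (sharded) case, the try/except is replaced by filtering the prefixes that actually provide an input scale and reducing them to a single max, with None as the fallback when nothing was found. A rough standalone sketch of that reduction step; the helper name is illustrative, not the repository's:

from typing import List, Optional

import torch


def reduce_input_scales(scales: List[torch.Tensor]) -> Optional[torch.Tensor]:
    # Concatenate whatever per-shard input scales were found and keep the max;
    # an empty list means no shard carried an input scale.
    return (
        torch.cat(scales, dim=0).reshape(-1).max()
        if len(scales) != 0
        else None
    )


# Example: two shards expose a scalar input scale, a third provides none.
print(reduce_input_scales([torch.tensor([0.5]), torch.tensor([0.75])]))  # tensor(0.7500)
print(reduce_input_scales([]))  # None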
@@ -237,12 +240,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 .reshape(-1)
                 .expand(w.shape[0])
             )
-            try:
+            input_scale = None
+            if weights.has_tensor(f"{prefix}.input_scale"):
                 input_scale = weights.get_tensor(
                     f"{prefix}.input_scale", to_dtype=False
                 ).reshape(-1)
-            except Exception:
-                input_scale = None

             return Fp8Weight(
                 weight=w,
@@ -272,12 +274,12 @@ class Fp8Weight(Weight):
         # memory. Can be non-contiguous when we e.g. expand from scalars.
         self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
-            self.weight,
-            self.weight_scale,
-            self.input_scale,
-            self.activation_scale_ub,
-            bias,
-            self.dtype,
+            weight=self.weight,
+            scale=self.weight_scale,
+            dtype=self.dtype,
+            bias=bias,
+            input_scale=self.input_scale,
+            scale_upper_bound=self.activation_scale_ub,
         )

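The Fp8Weight call into from_fp8 now passes every argument by keyword. The from_fp8 implementations changed below (Fp8Linear and GPTQMarlinFP8Linear) list their positional parameters in different orders and take the fp8-specific extras through **kwargs, so a keyword call is the order-independent way to drive both. A toy sketch of the idea, using stand-in classes rather than the repository's:

import torch


class DenseFp8Backend:
    @classmethod
    def from_fp8(cls, weight, scale, dtype, bias=None, **kwargs):
        # Optional extras arrive via **kwargs and default to None.
        input_scale = kwargs.get("input_scale", None)
        scale_upper_bound = kwargs.get("scale_upper_bound", None)
        return cls()


class MarlinFp8Backend:
    @classmethod
    def from_fp8(cls, weight, scale, bias, dtype, **kwargs):
        # Different positional order; ignores the fp8-only extras entirely.
        return cls()


# The same keyword call works against either backend.
for backend in (DenseFp8Backend, MarlinFp8Backend):
    backend.from_fp8(
        weight=torch.empty(2, 2),
        scale=torch.ones(2),
        dtype=torch.float16,
        bias=None,
        input_scale=None,
        scale_upper_bound=None,
    )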
@@ -286,12 +288,12 @@ class Fp8Linear(torch.nn.Module):

     def __init__(
         self,
-        qweight,
-        scale,
-        input_scale,
-        scale_upper_bound,
-        bias,
-        dtype,
+        qweight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        input_scale: Optional[torch.Tensor] = None,
+        scale_upper_bound: Optional[float] = None,
     ) -> None:
         super().__init__()
         if FBGEMM_MM_AVAILABLE:
@@ -327,14 +329,24 @@ class Fp8Linear(torch.nn.Module):
         return cls(
             qweight=qweight,
             scale=scale,
+            dtype=dtype,
+            bias=bias,
             input_scale=None,
             scale_upper_bound=None,
-            bias=bias,
-            dtype=dtype,
         )

     @classmethod
-    def from_fp8(cls, weight, scale, input_scale, scale_upper_bound, bias, dtype):
+    def from_fp8(
+        cls,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> "Fp8Linear":
+        input_scale = kwargs.get("input_scale", None)
+        scale_upper_bound = kwargs.get("scale_upper_bound", None)
+
         if FBGEMM_DYN_AVAILABLE:
             # fbgemm needs float32 scales.
             scale = scale.float()
@@ -391,7 +403,7 @@ class Fp8Linear(torch.nn.Module):
                 bias=self.bias,
             )

-            if type(output) is tuple and len(output) == 2:
+            if isinstance(output, tuple) and len(output) == 2:
                 output = output[0]
         else:
             device_identity = None
@@ -405,7 +417,7 @@ class Fp8Linear(torch.nn.Module):
                 scale_b=device_identity,
                 out_dtype=torch.float32,
             )
-            if type(output) is tuple and len(output) == 2:
+            if isinstance(output, tuple) and len(output) == 2:
                 output = output[0]

             output = output * scale * self.scale.t()
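The two type(output) is tuple checks become isinstance checks. The guarded call here is the scaled matmul, which on some torch releases returns an (output, amax) tuple and on others a plain tensor, so the code normalizes the result either way; isinstance is the idiomatic form of that check. A small sketch of the normalization, assuming only that the result may or may not be a 2-tuple:

import torch


def unwrap_scaled_mm_output(output):
    # Some torch versions return (out, amax) from the scaled matmul,
    # others return the tensor directly; normalize to a single tensor.
    if isinstance(output, tuple) and len(output) == 2:
        output = output[0]
    return output


print(unwrap_scaled_mm_output(torch.ones(2)))
print(unwrap_scaled_mm_output((torch.ones(2), torch.tensor(1.0))))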
@@ -62,7 +62,14 @@ class GPTQMarlinFP8Linear(nn.Module):
         return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

     @classmethod
-    def from_fp8(cls, weight, scale, _input_scale, _scale_upper_bound, bias, dtype):
+    def from_fp8(
+        cls,
+        weight: torch.Tensor,
+        scale: torch.Tensor,
+        bias: torch.Tensor,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
         return cls(qweight=weight, scales=scale.to(dtype), bias=bias)

     def forward(self, A: torch.Tensor) -> torch.Tensor:
@@ -342,22 +342,19 @@ def get_model(
     model_type = config_dict.get("model_type", None)

     quantization_config = config_dict.get("quantization_config", None)
-    compression_config = config_dict.get("compression_config", None)
     if quantization_config is not None and quantize is None:
         method = quantization_config.get("quant_method", None)
+        config_groups = quantization_config.get("config_groups", None)
         if method in {"gptq", "awq", "exl2"}:
             log_master(logger.info, f"Auto selecting quantization method {method}")
             quantize = method
         elif method == "fbgemm_fp8" or method == "fp8":
             log_master(logger.info, "Auto selecting quantization method fp8")
             quantize = "fp8"
-        else:
-            log_master(logger.warning, f"Unknown quantization method {method}")
-    elif compression_config is not None:
-        # TODO: at some point we should probably fully parse the compression
-        # configuration to know which parameters are compressed.
-        config_groups = compression_config.get("config_groups")
-        if config_groups is not None:
+        elif config_groups is not None:
+            # Compression config renamed to quantization_config
+            # TODO: at some point we should probably fully parse the compression
+            # configuration to know which parameters are compressed.
             for _, group in config_groups.items():
                 weights_config = group.get("weights")
                 if weights_config is not None:
@@ -370,6 +367,8 @@ def get_model(
                     )
                     quantize = "fp8"
                     break
+        else:
+            log_master(logger.warning, f"Unknown quantization method {method}")

     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
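Taken together, the two get_model hunks mean that compressed-tensors checkpoints are now recognized through quantization_config (the compression_config key was renamed upstream), and the unknown-method warning only fires when neither a known quant_method nor config_groups is present. A rough sketch of the resulting detection flow; the float-weights test inside the group loop is an assumption, since the actual condition sits in context lines not shown in this diff:

from typing import Optional


def detect_quantize(config_dict: dict, quantize: Optional[str] = None) -> Optional[str]:
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is None or quantize is not None:
        return quantize

    method = quantization_config.get("quant_method", None)
    config_groups = quantization_config.get("config_groups", None)
    if method in {"gptq", "awq", "exl2"}:
        return method
    if method in {"fbgemm_fp8", "fp8"}:
        return "fp8"
    if config_groups is not None:
        # Compression config renamed to quantization_config: look for an
        # fp8-weight group. The exact per-group check is assumed here.
        for group in config_groups.values():
            weights_config = group.get("weights")
            if weights_config is not None and weights_config.get("type") == "float":
                return "fp8"
        return quantize
    # Neither a known method nor config_groups: warn (logging elided here).
    return quantize


# Example shaped like a compressed-tensors float8 checkpoint config.
cfg = {
    "quantization_config": {
        "quant_method": "compressed-tensors",
        "config_groups": {"group_0": {"weights": {"type": "float", "num_bits": 8}}},
    }
}
print(detect_quantize(cfg))  # fp8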
@@ -197,7 +197,7 @@ class Weights:
             slice_ = f.get_slice(tensor_name)
         return slice_

-    def _has_tensor(self, tensor_name: str):
+    def has_tensor(self, tensor_name: str):
         try:
             self.get_filename(tensor_name)
         except Exception:
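The helper on Weights is made public (renamed from _has_tensor to has_tensor) because the fp8 loaders above now call it to probe for optional input-scale tensors. Only its try/except skeleton is visible in this hunk; a self-contained sketch of what such a probe boils down to, with the return values assumed and a toy class standing in for Weights:

class ToyWeights:
    """Illustrative stand-in; the real Weights resolves tensor names to files."""

    def __init__(self, known_names):
        self._known = set(known_names)

    def get_filename(self, tensor_name: str) -> str:
        if tensor_name not in self._known:
            raise RuntimeError(f"weight {tensor_name} does not exist")
        return "model.safetensors"

    def has_tensor(self, tensor_name: str) -> bool:
        # Presence probe: a failed filename lookup means the tensor is absent.
        # The explicit returns are assumed, not shown in the diff.
        try:
            self.get_filename(tensor_name)
        except Exception:
            return False
        return True


w = ToyWeights({"model.layers.0.q_proj.input_scale"})
print(w.has_tensor("model.layers.0.q_proj.input_scale"))  # True
print(w.has_tensor("model.layers.0.q_proj.missing"))      # False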