fp8 compressed tensors w8a8 support for Gaudi backend (#3242)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Parent: 1883a62a94
Commit: f14044009a
@@ -99,6 +99,7 @@ RUN cd server && \
    BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
    pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
+RUN pip install compressed-tensors==0.9.1

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -19,6 +19,7 @@ class Quantization(str, Enum):
    gptq = "gptq"
    awq = "awq"
    fp8 = "fp8"
+    compressed_tensors = "compressed-tensors"


class Dtype(str, Enum):
@@ -109,6 +110,7 @@ def serve(
        "gptq",
        "awq",
        "fp8",
+        "compressed-tensors",
    }:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
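A quick example, separate from the diff itself: because `Quantization` is a `str` enum, the string passed on the command line round-trips directly to the new member. The sketch below is illustrative (members other than the ones shown in the hunk are elided).

from enum import Enum

class Quantization(str, Enum):
    gptq = "gptq"
    awq = "awq"
    fp8 = "fp8"
    compressed_tensors = "compressed-tensors"

# `--quantize compressed-tensors` parses straight to the enum member.
print(Quantization("compressed-tensors"))  # Quantization.compressed_tensors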
@@ -0,0 +1,3 @@
from .loader import CompressedTensorsLoader

__all__ = ["CompressedTensorsLoader"]
@@ -0,0 +1,169 @@
from typing import Any, Dict, List, Union

from compressed_tensors import QuantizationConfig, QuantizationStatus
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
    QuantizationScheme,
    QuantizationType,
    find_name_or_class_matches,
)
from loguru import logger
from pydantic import ValidationError
from torch import nn

from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    UnquantizedWeight,
    Weights,
    WeightsLoader,
)

# compressed-tensors can match modules as quantization targets. However,
# they need to be objects rather than classes or class names. Since we
# need to match `Linear` targets, make an instance that can be re-used.
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)


class CompressedTensorsLoader(WeightsLoader):
    """Loader for checkpoints stored in the compressed-tensors format."""

    def __init__(self, config: Dict[str, Any]):
        quantization_config_raw = config.get("quantization_config")
        if quantization_config_raw is None:
            # `compression_config` was renamed to `quantization_config`; support
            # is retained for backward compatibility.
            quantization_config_raw = config.get("compression_config")
        if quantization_config_raw is None:
            raise ValueError(
                "Checkpoint does not have compressed-tensors configuration"
            )

        try:
            quantization_config = QuantizationConfig.model_validate(
                quantization_config_raw
            )
        except ValidationError as e:
            raise ValueError("Cannot parse compressed-tensors configuration") from e

        if quantization_config.quantization_status not in (
            QuantizationStatus.COMPRESSED,
            QuantizationStatus.FROZEN,
        ):
            raise ValueError(
                f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
            )

        self.ignore = (
            quantization_config.ignore if quantization_config.ignore is not None else []
        )
        self.loaders = self._get_target_loaders(quantization_config)

        for target, loader in self.loaders.items():
            log_once(
                logger.info,
                f"Using {loader} for compressed-tensors target '{target}'",
            )

    def get_weights(self, weights: Weights, prefix: str):
        loader = self._lookup_loader(prefix)
        return loader.get_weights(weights, prefix)

    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        loader = self._lookup_loader(prefix)
        return loader.get_weights_col_packed(weights, prefix, block_sizes)

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        loader = self._lookup_loader(prefixes[0])
        return loader.get_multi_weights_col(weights, prefixes, dim)

    def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
        loader = self._lookup_loader(prefixes[0])
        return loader.get_multi_weights(weights, prefixes, dim)

    def get_weights_row(self, weights: Weights, prefix: str):
        loader = self._lookup_loader(prefix)
        return loader.get_weights_row(weights, prefix)

    def _get_target_loaders(
        self, quantization_config: QuantizationConfig
    ) -> Dict[str, WeightsLoader]:
        """
        A compressed-tensors checkpoint can use different quantizations
        for different targets. This method returns a dictionary with a
        loader per target.
        """
        loaders: Dict[str, WeightsLoader] = {}

        format = quantization_config.format

        for group_name, group in quantization_config.config_groups.items():
            # The group configuration can be a string, but does that ever
            # happen in a serialized quantization config?
            assert isinstance(group, QuantizationScheme)

            loader = self._create_loader_for_group(format, group_name, group)

            # A quantized parameter group can have multiple targets; add the
            # loader for all of them.
            for target in group.targets:
                if target in loaders:
                    raise ValueError(
                        f"Target '{target}' has multiple configured loaders"
                    )
                loaders[target] = loader

        return loaders

    def _create_loader_for_group(
        self, format: str, group_name: str, group: QuantizationScheme
    ) -> WeightsLoader:
        """
        Find and create a loader for the group with the given quantization
        scheme.
        """
        # NOTE: we ignore group.output_activations because we don't support
        # output quantization yet.

        input_activations = group.input_activations
        weights = group.weights
        if (
            format
            in {
                CompressionFormat.float_quantized.value,
                CompressionFormat.naive_quantized.value,
            }
            and weights is not None
            and weights.type == QuantizationType.FLOAT
            and weights.num_bits == 8
        ):
            # FP W8A8 or W8A16.
            return W8ANFpLoader(input_activations=input_activations, weights=weights)
        else:
            raise ValueError(
                f"Group '{group_name}' has unsupported compressed-tensors configuration"
            )

    def _lookup_loader(self, prefix: str) -> WeightsLoader:
        """
        Look up the loader to use for a given parameter name (prefix).
        """
        if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
            return DefaultWeightsLoader(UnquantizedWeight)

        # We currently only handle linear layers, so unconditionally pass
        # a `Linear` instance.
        targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
        if len(targets) == 0:
            raise ValueError(
                f"Cannot find compressed-tensors target for prefix: {prefix}"
            )
        return self.loaders[targets[0]]
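A quick example, separate from the diff itself: a minimal sketch of the checkpoint-side configuration this loader consumes. The nested `quantization_config` below follows the compressed-tensors field names, but the concrete values are illustrative assumptions rather than a real checkpoint.

from text_generation_server.layers.compressed_tensors import CompressedTensorsLoader

# Illustrative config.json contents for a static FP8 W8A8 checkpoint
# (field names follow the compressed-tensors schema; values are assumed).
example_config = {
    "quantization_config": {
        "quant_method": "compressed-tensors",
        "format": "float-quantized",
        "quantization_status": "compressed",
        "ignore": ["lm_head"],
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {
                    "num_bits": 8,
                    "type": "float",
                    "strategy": "tensor",
                    "dynamic": False,
                    "symmetric": True,
                },
                "input_activations": {
                    "num_bits": 8,
                    "type": "float",
                    "strategy": "tensor",
                    "dynamic": False,
                    "symmetric": True,
                },
            }
        },
    }
}

loader = CompressedTensorsLoader(example_config)
# Prefixes matching the "Linear" target are routed to a W8ANFpLoader,
# while "lm_head" (listed in `ignore`) falls back to
# DefaultWeightsLoader(UnquantizedWeight).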
@@ -0,0 +1,253 @@
from typing import List, Optional, Union

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType

from text_generation_server.layers.fp8 import (
    Fp8Weight,
    _load_scalar_or_matrix_scale,
    requantize_with_max_scale,
)
from text_generation_server.utils.weights import Weights, WeightsLoader


class W8ANFpLoader(WeightsLoader):
    """
    Loader for W8A8/W8A16 FP compressed-tensors parameters.
    """

    def __init__(
        self,
        *,
        input_activations: Optional[QuantizationArgs],
        weights: QuantizationArgs,
    ):
        assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8

        # We ignore the `strategy` option which sets the scales to be
        # per-tensor, per-channel or per-token. What scales are supported
        # is dependent on the kernels used (e.g. cutlass can do tokenwise,
        # Torch cannot, and FP8-Marlin does not quantize inputs at all).
        # So, instead we try to use the best-possible configuration.

        self.load_weight_scale = not weights.dynamic
        self.load_input_scale = (
            input_activations is not None and not input_activations.dynamic
        )
        self.force_w8a16 = (
            input_activations is not None and input_activations.num_bits == 16
        )

    def __str__(self) -> str:
        def scale_to_str(scale):
            return "static" if scale else "dynamic"

        quantization_type = f"W8A{16 if self.force_w8a16 else 8}"

        return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})"

    def get_weights(self, weights: "Weights", prefix: str):
        w = weights.get_tensor(f"{prefix}.weight")

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )
            logical_widths = [w.shape[0]]
            w, weight_scale = requantize_with_max_scale(
                w,
                weight_scale.unsqueeze(-1).to(weights.device),
                logical_widths,
                weights.dtype,
            )

        input_scale = None
        if self.load_input_scale:
            input_scale = weights.get_tensor(
                f"{prefix}.input_scale", to_dtype=False
            ).reshape(-1)

        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            if weight_scale.numel() > 1:
                weight_scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
            logical_widths = [w.shape[0]]
            w, weight_scale = requantize_with_max_scale(
                w,
                weight_scale.unsqueeze(-1).to(weights.device),
                logical_widths,
                weights.dtype,
            )

        input_scale = None
        if self.load_input_scale:
            input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
            if input_scale.numel() > 1:
                input_scale = weights.get_packed_sharded(
                    f"{prefix}.input_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            input_scale = input_scale.reshape(-1).max()

        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
            logical_widths = [x[0] for x in shapes]
            w, weight_scale = requantize_with_max_scale(
                w,
                weight_scale.unsqueeze(-1).to(weights.device),
                logical_widths,
                weights.dtype,
            )

        input_scale = None
        if self.load_input_scale:
            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )

    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        weight_scale = None

        if self.load_weight_scale:
            weight_scale = [
                weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(shape[0])
                for p, shape in zip(prefixes, shapes)
            ]
            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
            logical_widths = [x[0] for x in shapes]
            w, weight_scale = requantize_with_max_scale(
                w,
                weight_scale.unsqueeze(-1).to(weights.device),
                logical_widths,
                weights.dtype,
            )

        input_scale = None
        if self.load_input_scale:
            input_scale = [
                weights.get_tensor(f"{p}.input_scale", to_dtype=False)
                .reshape(-1)
                .expand(shape[0])
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )

    def get_weights_row(self, weights: "Weights", prefix: str):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        weight_scale = None
        if self.load_weight_scale:
            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
            logical_widths = [w.shape[0]]
            w, weight_scale = requantize_with_max_scale(
                w,
                weight_scale.unsqueeze(-1).to(weights.device),
                logical_widths,
                weights.dtype,
            )

        input_scale = None
        if self.load_input_scale:
            input_scale = weights.get_tensor(
                f"{prefix}.input_scale", to_dtype=False
            ).reshape(-1)

        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )
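A quick example, separate from the diff itself: how the three loader flags follow from a group's `QuantizationArgs`. This is a sketch that assumes the `QuantizationArgs` constructor defaults are sufficient; the argument values are made up to represent a static-weight / dynamic-activation group.

from compressed_tensors.quantization import QuantizationArgs, QuantizationType

# Static FP8 weights, dynamic FP8 activations (illustrative values).
weight_args = QuantizationArgs(num_bits=8, type=QuantizationType.FLOAT, dynamic=False)
input_args = QuantizationArgs(num_bits=8, type=QuantizationType.FLOAT, dynamic=True)

load_weight_scale = not weight_args.dynamic                            # True: read `{prefix}.weight_scale`
load_input_scale = input_args is not None and not input_args.dynamic  # False: quantize inputs on the fly
force_w8a16 = input_args is not None and input_args.num_bits == 16    # False: activations stay in FP8

print(load_weight_scale, load_input_scale, force_w8a16)  # True False False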
@@ -207,7 +207,7 @@ def requantize_with_max_scale(
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_dq = per_tensor_dequantize(
-            weight[start:end, :], weight_scale[idx], dtype
+            weight[start:end, :], weight_scale[start:end, :], dtype
        )
        weight[start:end, :], max_w_scale_normalized = fp8_quantize(
            weight_dq, max_w_scale
@@ -270,6 +270,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
            )
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            scale = scale.reshape(-1).expand(w.shape[0])
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
@@ -278,10 +283,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                    .reshape(-1)
                    .max()
                )
-            logical_widths = [w.shape[0]]
-            w, scale = requantize_with_max_scale(
-                w, scale.unsqueeze(0), logical_widths, weights.dtype
-            )

            return Fp8Weight(
                weight=w,
@@ -316,6 +317,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
@@ -330,10 +336,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                    to_dtype=False,
                )
                input_scale = input_scale.reshape(-1).max()
-            logical_widths = [w.shape[0]]
-            w, scale = requantize_with_max_scale(
-                w, scale.unsqueeze(0), logical_widths, weights.dtype
-            )

            return Fp8Weight(
                weight=w,
@@ -380,6 +382,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)

+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )
+
            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
@@ -392,11 +399,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                else None
            )

-            logical_widths = [x[0] for x in shapes]
-            w, scale = requantize_with_max_scale(
-                w, scale.to(weights.device), logical_widths, weights.dtype
-            )
-
            return Fp8Weight(
                weight=w,
                weight_scale=scale,
@@ -435,11 +437,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
            )

            scale = [
-                weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1)
-                for p in prefixes
+                weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(shape[0])
+                for p, shape in zip(prefixes, shapes)
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)

+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )
+
            input_scale = [
                weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
                for p in prefixes
@@ -452,11 +461,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                else None
            )

-            logical_widths = [x[0] for x in shapes]
-            w, scale = requantize_with_max_scale(
-                w, scale.to(weights.device), logical_widths, weights.dtype
-            )
-
            return Fp8Weight(
                weight=w,
                weight_scale=scale,
@@ -485,7 +489,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
                weight_block_size=self.weight_block_size,
            )

-        scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+        scale = (
+            weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            .reshape(-1)
+            .expand(w.shape[0])
+        )
+        logical_widths = [w.shape[0]]
+        w, scale = requantize_with_max_scale(
+            w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+        )

        input_scale = None
        if weights.has_tensor(f"{prefix}.input_scale"):
@@ -494,10 +506,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                .reshape(-1)
                .max()
            )
-        logical_widths = [w.shape[0]]
-        w, scale = requantize_with_max_scale(
-            w, scale.unsqueeze(0), logical_widths, weights.dtype
-        )

        return Fp8Weight(
            weight=w,
            weight_scale=scale,
@@ -615,45 +624,32 @@ class Fp8Linear(torch.nn.Module):
            weight_block_size=weight_block_size,
        )

    @classmethod
    def get_shared_device_identity(cls, device):
        # Input scaling factors are no longer optional in _scaled_mm starting
        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
        if device not in cls._device_identity_cache:
            cls._device_identity_cache[device] = torch.ones(1, device=device)
        return cls._device_identity_cache[device]

    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.weight_block_size is not None:
+        if self.weight_block_size is not None or self.input_scale is None:
            return apply_block_fp8_linear_hpu_dynamic(
                input, self.qweight, self.scale, self.input_scale, self.bias
            )

-        qinput, scale = fp8_quantize(
-            input,
-            self.input_scale,
-            scale_upper_bound=self.scale_upper_bound,
-            scalar=True,
-        )
-
-        output = torch._scaled_mm(
-            qinput,
-            self.qweight.t(),
-            out_dtype=self.dtype,
-            scale_a=scale,
-            scale_b=self.scale,
+        x_fp8 = torch.ops.hpu.cast_to_fp8_v2(
+            input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn
+        )[0]
+        return torch.ops.hpu.fp8_gemm_v2(
+            A=x_fp8,
+            trans_A=False,
+            B=self.qweight,
+            trans_B=True,
+            D=None,
+            out_dtype=input.dtype,
+            A_scale_inv=self.input_scale,
+            B_scale_inv=self.scale,
            bias=self.bias,
+            accumulate=False,
        )

-        if isinstance(output, tuple) and len(output) == 2:
-            output = output[0]
-
-        return output


def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
    scale = weights.get_tensor(prefix, to_dtype=False)

    if scale.numel() > 1:
        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
-    return scale.reshape(-1)
+    return scale.reshape(-1).expand(shape[0])
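A quick example, separate from the diff itself: a plain-torch sketch of why the hunks above now hand `requantize_with_max_scale` a per-row (expanded) scale. When column-parallel shards of a fused projection are concatenated, each shard carries its own FP8 scale; requantizing every row to the shared maximum lets the fused matrix use a single weight scale. The helper below is an illustration of that idea, not the function from the file being patched; it works on float tensors for clarity where the real code keeps FP8 storage.

import torch

def requantize_to_max_scale_sketch(weight_q, per_row_scale, logical_widths):
    # weight_q: concatenated quantized weight, shape (sum(logical_widths), in_features).
    # per_row_scale: one scale per output row, shape (sum(logical_widths), 1).
    max_scale = per_row_scale.max()
    out = torch.empty_like(weight_q)
    start = 0
    for width in logical_widths:
        end = start + width
        dq = weight_q[start:end] * per_row_scale[start:end]  # dequantize this shard
        out[start:end] = dq / max_scale                      # requantize with the shared scale
        start = end
    return out, max_scale

# Two fused shards (e.g. gate_proj/up_proj) with different scales end up
# sharing a single weight scale after requantization.
w = torch.randn(4, 8)
scales = torch.tensor([[0.5], [0.5], [2.0], [2.0]])
w_req, shared_scale = requantize_to_max_scale_sketch(w, scales, logical_widths=[2, 2])
print(shared_scale)  # tensor(2.)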
@@ -276,6 +276,63 @@ class GPTQWeightsLoader(WeightsLoader):
            use_exllama=use_exllama,
        )

+    def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim)
+        try:
+            qweight = torch.cat(
+                [weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1
+            )
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
+            )
+
+        scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1)
+
+        self._get_gptq_params(weights)
+
+        qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1)
+
+        use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
+
+        if self.quantize == "gptq" and self.quant_method == "gptq":
+            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+        elif self.quantize == "gptq" and self.quant_method == "awq":
+            log_once(
+                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+            )
+            from text_generation_server.layers.awq.conversion_utils import (
+                fast_awq_to_gptq,
+            )
+
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            if use_exllama:
+                g_idx = None
+            else:
+                g_idx = (
+                    torch.arange(
+                        qweight.shape[0] * (32 // self.bits),
+                        device=qweight.device,
+                    )
+                ).to(dtype=torch.int32)
+        else:
+            g_idx = None
+
+        return GPTQWeight(
+            qweight=qweight,
+            qzeros=qzeros,
+            scales=scales,
+            g_idx=g_idx,
+            bits=self.bits,
+            groupsize=self.groupsize,
+            use_awq_kernel=self.quantize == "awq",
+            use_exllama=use_exllama,
+        )
+
    def get_weights_row(self, weights: Weights, prefix: str):
        self._get_gptq_params(weights)
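A quick example, separate from the diff itself: the new `get_multi_weights` fuses several prefixes by concatenating their packed GPTQ tensors along `dim=1`, the output-feature axis of the packed layout. The shapes below are hypothetical (4-bit packing, made-up layer sizes) and only illustrate the concatenation.

import torch

# Hypothetical 4-bit GPTQ shapes for q/k/v projections:
# in_features=4096, out_features=4096, group size 128, 8 values packed per int32.
bits, in_features, out_features, groupsize = 4, 4096, 4096, 128
pack = 32 // bits

qweights = [torch.zeros(in_features // pack, out_features, dtype=torch.int32) for _ in range(3)]
scales = [torch.zeros(in_features // groupsize, out_features, dtype=torch.float16) for _ in range(3)]
qzeros = [torch.zeros(in_features // groupsize, out_features // pack, dtype=torch.int32) for _ in range(3)]

# Fusing q, k and v concatenates along the output-feature axis (dim=1),
# mirroring what get_multi_weights does for `.qweight`, `.scales` and `.qzeros`.
qweight = torch.cat(qweights, dim=1)
print(qweight.shape)                       # torch.Size([512, 12288])
print(torch.cat(scales, dim=1).shape)      # torch.Size([32, 12288])
print(torch.cat(qzeros, dim=1).shape)      # torch.Size([32, 1536])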
@@ -1412,7 +1412,6 @@ class FlashCausalLM(Model):
            aliases=aliases,
            weights_loader=weights_loader,
        )
-        print(f"weights: {weights}")

        prefix = None
        model = model_class(prefix, config, weights)
@@ -122,6 +122,13 @@ def _get_quantizer_config(model_id, revision):
def get_loader(
    quantize: Optional[str], model_id: str, revision: Optional[str]
) -> WeightsLoader:
+    if quantize == "compressed-tensors":
+        config = _get_config_json(model_id, revision, "config.json")
+        from text_generation_server.layers.compressed_tensors import (
+            CompressedTensorsLoader,
+        )
+
+        return CompressedTensorsLoader(config)
    quantizer_config = _get_quantizer_config(model_id, revision)
    if quantize in {"awq", "gptq"}:
        from text_generation_server.layers.gptq import GPTQWeightsLoader
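A quick example, separate from the diff itself: with this branch in place, launching with `--quantize compressed-tensors` routes weight loading through the new loader. The sketch assumes `get_loader` lives in `text_generation_server.utils.quantization` as in the main server, and the model id is only an example of an FP8 compressed-tensors checkpoint.

from text_generation_server.utils.quantization import get_loader

# Example FP8 compressed-tensors checkpoint id (illustrative).
weights_loader = get_loader(
    quantize="compressed-tensors",
    model_id="neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
    revision=None,
)
print(weights_loader)  # a CompressedTensorsLoader built from the checkpoint's config.json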
@@ -162,6 +162,11 @@ impl Allocator for SimpleAllocator {
        tokens: u32,
        _prefill_tokens: Option<Arc<Vec<u32>>>,
    ) -> Option<BlockAllocation> {
+        let mut tokens = tokens;
+        if self.is_hpu_device {
+            // need 1 slot for ping-pong optimization
+            tokens += 1;
+        }
        // Apply window size
        let (required_blocks, repeats) = {
            let (tokens, repeats) = match self.window_size {
@@ -176,8 +181,7 @@ impl Allocator for SimpleAllocator {
            let required_blocks = tokens.div_ceil(self.block_size);
            (required_blocks, repeats)
        };
-
-        let mut tokens = tokens as usize;
+        let tokens = tokens as usize;
        if required_blocks > self.free_blocks.len() as u32 {
            None
        } else {
@@ -189,8 +193,6 @@ impl Allocator for SimpleAllocator {
                .split_off(self.free_blocks.len() - required_blocks as usize);
            if self.is_hpu_device {
                blocks.sort();
-                // need 1 slot for ping-pong optimization
-                tokens += 1;
            }
            let mut slots =
                Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
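A quick example, separate from the diff itself: the allocator change moves the extra HPU ping-pong slot before `required_blocks` is computed, so the block reservation accounts for it. A small arithmetic sketch (block size and token count are made up):

import math

block_size, tokens, is_hpu_device = 16, 32, True

# New behaviour: pad for the ping-pong slot before sizing the reservation.
padded = tokens + 1 if is_hpu_device else tokens
required_blocks = math.ceil(padded / block_size)      # 3 blocks for 33 slots

# Old behaviour: blocks were sized from the raw token count and the +1 was
# applied only afterwards, so the extra slot was not covered by the reservation.
old_required_blocks = math.ceil(tokens / block_size)  # 2 blocks for 32 slots

print(required_blocks, old_required_blocks)  # 3 2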