fp8 compressed tensors w8a8 support for Gaudi backend (#3242)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Parent: 1883a62a94
Commit: f14044009a
@@ -99,6 +99,7 @@ RUN cd server && \
    BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
    pip install . --no-cache-dir
 RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
+RUN pip install compressed-tensors==0.9.1

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -19,6 +19,7 @@ class Quantization(str, Enum):
     gptq = "gptq"
     awq = "awq"
     fp8 = "fp8"
+    compressed_tensors = "compressed-tensors"


 class Dtype(str, Enum):
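For reference, a minimal re-creation of the CLI enum (illustration only, not the server's actual module) shows how the new member maps: the Python identifier uses an underscore while the value accepted on the command line keeps the dash.

```python
from enum import Enum


# Minimal stand-in for the CLI's Quantization enum, for illustration only.
class Quantization(str, Enum):
    gptq = "gptq"
    awq = "awq"
    fp8 = "fp8"
    compressed_tensors = "compressed-tensors"


# Lookup by value works with the dashed string passed on the command line.
assert Quantization("compressed-tensors") is Quantization.compressed_tensors
```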
@@ -109,6 +110,7 @@ def serve(
         "gptq",
         "awq",
         "fp8",
+        "compressed-tensors",
     }:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
@@ -0,0 +1,3 @@
+from .loader import CompressedTensorsLoader
+
+__all__ = ["CompressedTensorsLoader"]
@@ -0,0 +1,169 @@
+from typing import Any, Dict, List, Union
+
+from compressed_tensors import QuantizationConfig, QuantizationStatus
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    QuantizationType,
+    find_name_or_class_matches,
+)
+from loguru import logger
+from pydantic import ValidationError
+from torch import nn
+
+from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
+from text_generation_server.utils.log import log_once
+from text_generation_server.utils.weights import (
+    DefaultWeightsLoader,
+    UnquantizedWeight,
+    Weights,
+    WeightsLoader,
+)
+
+# compressed-tensors can match modules as quantization targets. However,
+# they need to be objects rather than classes or class names. Since we
+# need to match `Linear` targets, make an instance that can be re-used.
+_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)
+
+
+class CompressedTensorsLoader(WeightsLoader):
+    """Loader for checkpoints stored in the compressed-tensors format."""
+
+    def __init__(self, config: Dict[str, Any]):
+        quantization_config_raw = config.get("quantization_config")
+        if quantization_config_raw is None:
+            # `compression_config` was renamed to `quantization_config`; support
+            # retained for backward compatibility.
+            quantization_config_raw = config.get("compression_config")
+            if quantization_config_raw is None:
+                raise ValueError(
+                    "Checkpoint does not have compressed-tensors configuration"
+                )
+
+        try:
+            quantization_config = QuantizationConfig.model_validate(
+                quantization_config_raw
+            )
+        except ValidationError as e:
+            raise ValueError("Cannot parse compressed-tensors configuration") from e
+
+        if quantization_config.quantization_status not in (
+            QuantizationStatus.COMPRESSED,
+            QuantizationStatus.FROZEN,
+        ):
+            raise ValueError(
+                f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
+            )
+
+        self.ignore = (
+            quantization_config.ignore if quantization_config.ignore is not None else []
+        )
+        self.loaders = self._get_target_loaders(quantization_config)
+
+        for target, loader in self.loaders.items():
+            log_once(
+                logger.info,
+                f"Using {loader} for compressed-tensors target '{target}'",
+            )
+
+    def get_weights(self, weights: Weights, prefix: str):
+        loader = self._lookup_loader(prefix)
+        return loader.get_weights(weights, prefix)
+
+    def get_weights_col_packed(
+        self,
+        weights: "Weights",
+        prefix: str,
+        block_sizes: Union[int, List[int]],
+    ):
+        loader = self._lookup_loader(prefix)
+        return loader.get_weights_col_packed(weights, prefix, block_sizes)
+
+    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+        loader = self._lookup_loader(prefixes[0])
+        return loader.get_multi_weights_col(weights, prefixes, dim)
+
+    def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
+        loader = self._lookup_loader(prefixes[0])
+        return loader.get_multi_weights(weights, prefixes, dim)
+
+    def get_weights_row(self, weights: Weights, prefix: str):
+        loader = self._lookup_loader(prefix)
+        return loader.get_weights_row(weights, prefix)
+
+    def _get_target_loaders(
+        self, quantization_config: QuantizationConfig
+    ) -> Dict[str, WeightsLoader]:
+        """
+        A compressed-tensors checkpoint can use different quantizations
+        for different targets. This method returns a dictionary with a
+        loader per target.
+        """
+
+        loaders: Dict[str, WeightsLoader] = {}
+
+        format = quantization_config.format
+
+        for group_name, group in quantization_config.config_groups.items():
+            # The group configuration can be a string, but does that ever
+            # happen in a serialized quantization config?
+            assert isinstance(group, QuantizationScheme)
+
+            loader = self._create_loader_for_group(format, group_name, group)
+
+            # A quantized parameter group can have multiple targets, add the
+            # loader for all the targets.
+            for target in group.targets:
+                if target in loaders:
+                    raise ValueError(
+                        f"Target '{target}' has multiple configured loaders"
+                    )
+                loaders[target] = loader
+
+        return loaders
+
+    def _create_loader_for_group(
+        self, format: str, group_name: str, group: QuantizationScheme
+    ) -> WeightsLoader:
+        """
+        Find and create a loader for the group with the given quantization
+        scheme.
+        """
+        # NOTE: we ignore group.output_activations because we don't support
+        # output quantization yet.
+
+        input_activations = group.input_activations
+        weights = group.weights
+        if (
+            format
+            in {
+                CompressionFormat.float_quantized.value,
+                CompressionFormat.naive_quantized.value,
+            }
+            and weights is not None
+            and weights.type == QuantizationType.FLOAT
+            and weights.num_bits == 8
+        ):
+            # FP W8A8 or W8A16.
+            return W8ANFpLoader(input_activations=input_activations, weights=weights)
+        else:
+            raise ValueError(
+                f"Group '{group_name}' has unsupported compressed-tensors configuration"
+            )
+
+    def _lookup_loader(self, prefix: str) -> WeightsLoader:
+        """
+        Look up the loader to use for a given parameter name (prefix).
+        """
+
+        if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
+            return DefaultWeightsLoader(UnquantizedWeight)
+
+        # We currently only handle linear layers, so unconditionally pass
+        # a `Linear` instance.
+        targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
+        if len(targets) == 0:
+            raise ValueError(
+                f"Cannot find compressed-tensors target for prefix: {prefix}"
+            )
+        return self.loaders[targets[0]]
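The loader above pairs each compressed-tensors target with one weights loader and falls back to an unquantized loader for ignored modules. The following is a simplified, dependency-free sketch of that dispatch; it uses plain suffix matching and made-up names, whereas the real code delegates matching to `compressed_tensors.find_name_or_class_matches`.

```python
from typing import Dict, Iterable


def lookup_loader(prefix: str, loaders: Dict[str, str], ignore: Iterable[str]) -> str:
    # Ignored modules (e.g. lm_head) are served unquantized.
    if any(prefix.endswith(pattern) for pattern in ignore):
        return "unquantized"
    # Otherwise pick the loader registered for the first matching target.
    for target, loader in loaders.items():
        if prefix.endswith(target):
            return loader
    raise ValueError(f"Cannot find compressed-tensors target for prefix: {prefix}")


loaders = {"q_proj": "W8ANFpLoader", "k_proj": "W8ANFpLoader"}
print(lookup_loader("model.layers.0.self_attn.q_proj", loaders, ignore=["lm_head"]))
```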
@@ -0,0 +1,253 @@
+from typing import List, Optional, Union
+
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationType
+
+from text_generation_server.layers.fp8 import (
+    Fp8Weight,
+    _load_scalar_or_matrix_scale,
+    requantize_with_max_scale,
+)
+from text_generation_server.utils.weights import Weights, WeightsLoader
+
+
+class W8ANFpLoader(WeightsLoader):
+    """
+    Loader for W8A8/W8A16 FP compressed-tensors parameters.
+    """
+
+    def __init__(
+        self,
+        *,
+        input_activations: Optional[QuantizationArgs],
+        weights: QuantizationArgs,
+    ):
+        assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8
+
+        # We ignore the `strategy` option which sets the scales to be
+        # per-tensor, per-channel or per-token. What scales are supported
+        # is dependent on the kernels used (e.g. cutlass can do tokenwise,
+        # Torch cannot, and FP8-Marlin does not quantize inputs at all).
+        # So, instead we try to use the best-possible configuration.
+
+        self.load_weight_scale = not weights.dynamic
+        self.load_input_scale = (
+            input_activations is not None and not input_activations.dynamic
+        )
+        self.force_w8a16 = (
+            input_activations is not None and input_activations.num_bits == 16
+        )
+
+    def __str__(self) -> str:
+        def scale_to_str(scale):
+            return "static" if scale else "dynamic"
+
+        quantization_type = f"W8A{16 if self.force_w8a16 else 8}"
+
+        return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})"
+
+    def get_weights(self, weights: "Weights", prefix: str):
+        w = weights.get_tensor(f"{prefix}.weight")
+
+        weight_scale = None
+        if self.load_weight_scale:
+            weight_scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
+            logical_widths = [w.shape[0]]
+            w, weight_scale = requantize_with_max_scale(
+                w,
+                weight_scale.unsqueeze(-1).to(weights.device),
+                logical_widths,
+                weights.dtype,
+            )
+
+        input_scale = None
+        if self.load_input_scale:
+            input_scale = weights.get_tensor(
+                f"{prefix}.input_scale", to_dtype=False
+            ).reshape(-1)
+
+        return Fp8Weight(
+            weight=w,
+            weight_scale=weight_scale,
+            input_scale=input_scale,
+            dtype=weights.dtype,
+            force_w8a16=self.force_w8a16,
+        )
+
+    def get_weights_col_packed(
+        self,
+        weights: Weights,
+        prefix: str,
+        block_sizes: Union[int, List[int]],
+    ):
+        w = weights.get_packed_sharded(
+            f"{prefix}.weight", dim=0, block_sizes=block_sizes
+        )
+
+        weight_scale = None
+        if self.load_weight_scale:
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if weight_scale.numel() > 1:
+                weight_scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
+            logical_widths = [w.shape[0]]
+            w, weight_scale = requantize_with_max_scale(
+                w,
+                weight_scale.unsqueeze(-1).to(weights.device),
+                logical_widths,
+                weights.dtype,
+            )
+
+        input_scale = None
+        if self.load_input_scale:
+            input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+            if input_scale.numel() > 1:
+                input_scale = weights.get_packed_sharded(
+                    f"{prefix}.input_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            input_scale = input_scale.reshape(-1).max()
+
+        return Fp8Weight(
+            weight=w,
+            weight_scale=weight_scale,
+            input_scale=input_scale,
+            dtype=weights.dtype,
+            force_w8a16=self.force_w8a16,
+        )
+
+    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
+        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+        w = [
+            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+        ]
+        shapes = [x.shape for x in w]
+
+        # Concat then send to the device
+        w = torch.cat(w, dim=dim).to(weights.device)
+
+        weight_scale = None
+        if self.load_weight_scale:
+            weight_scale = [
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
+            ]
+            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
+            logical_widths = [x[0] for x in shapes]
+            w, weight_scale = requantize_with_max_scale(
+                w,
+                weight_scale.unsqueeze(-1).to(weights.device),
+                logical_widths,
+                weights.dtype,
+            )
+
+        input_scale = None
+        if self.load_input_scale:
+            input_scale = [
+                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
+                for p, shape in zip(prefixes, shapes)
+                if weights.has_tensor(f"{p}.input_scale")
+            ]
+            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+            input_scale = (
+                torch.cat(input_scale, dim=0).reshape(-1).max()
+                if len(input_scale) != 0
+                else None
+            )
+
+        return Fp8Weight(
+            weight=w,
+            weight_scale=weight_scale,
+            input_scale=input_scale,
+            dtype=weights.dtype,
+            force_w8a16=self.force_w8a16,
+        )
+
+    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
+        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+        w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
+        shapes = [x.shape for x in w]
+
+        # Concat then send to the device
+        w = torch.cat(w, dim=dim).to(weights.device)
+
+        weight_scale = None
+
+        if self.load_weight_scale:
+            weight_scale = [
+                weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(shape[0])
+                for p, shape in zip(prefixes, shapes)
+            ]
+            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
+            logical_widths = [x[0] for x in shapes]
+            w, weight_scale = requantize_with_max_scale(
+                w,
+                weight_scale.unsqueeze(-1).to(weights.device),
+                logical_widths,
+                weights.dtype,
+            )
+
+        input_scale = None
+        if self.load_input_scale:
+            input_scale = [
+                weights.get_tensor(f"{p}.input_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(shape[0])
+                for p, shape in zip(prefixes, shapes)
+                if weights.has_tensor(f"{p}.input_scale")
+            ]
+            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
+            input_scale = (
+                torch.cat(input_scale, dim=0).reshape(-1).max()
+                if len(input_scale) != 0
+                else None
+            )
+
+        return Fp8Weight(
+            weight=w,
+            weight_scale=weight_scale,
+            input_scale=input_scale,
+            dtype=weights.dtype,
+            force_w8a16=self.force_w8a16,
+        )
+
+    def get_weights_row(self, weights: "Weights", prefix: str):
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
+        weight_scale = None
+        if self.load_weight_scale:
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
+            logical_widths = [w.shape[0]]
+            w, weight_scale = requantize_with_max_scale(
+                w,
+                weight_scale.unsqueeze(-1).to(weights.device),
+                logical_widths,
+                weights.dtype,
+            )
+
+        input_scale = None
+        if self.load_input_scale:
+            input_scale = weights.get_tensor(
+                f"{prefix}.input_scale", to_dtype=False
+            ).reshape(-1)
+
+        return Fp8Weight(
+            weight=w,
+            weight_scale=weight_scale,
+            input_scale=input_scale,
+            dtype=weights.dtype,
+            force_w8a16=self.force_w8a16,
+        )
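The loader expects each quantized linear layer to ship tensors named off the layer prefix: `{prefix}.weight`, `{prefix}.weight_scale`, and, for static activation quantization, `{prefix}.input_scale`. A toy illustration with made-up shapes, standing in for what the `Weights` abstraction returns (in a real checkpoint the weight payload is float8, not zeros):

```python
import torch

# Hypothetical shard contents keyed by the tensor names used in the diff.
prefix = "model.layers.0.self_attn.q_proj"
shard = {
    f"{prefix}.weight": torch.zeros(64, 32),
    f"{prefix}.weight_scale": torch.tensor(0.02),
    f"{prefix}.input_scale": torch.tensor(0.5),
}

w = shard[f"{prefix}.weight"]
# A scalar weight scale is broadcast to one value per output row, mirroring
# the .reshape(-1).expand(w.shape[0]) calls in W8ANFpLoader.
weight_scale = shard[f"{prefix}.weight_scale"].reshape(-1).expand(w.shape[0])
input_scale = shard[f"{prefix}.input_scale"].reshape(-1)
print(weight_scale.shape, input_scale.shape)  # torch.Size([64]) torch.Size([1])
```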
@@ -207,7 +207,7 @@ def requantize_with_max_scale(
     for idx, logical_width in enumerate(logical_widths):
         end = start + logical_width
         weight_dq = per_tensor_dequantize(
-            weight[start:end, :], weight_scale[idx], dtype
+            weight[start:end, :], weight_scale[start:end, :], dtype
         )
         weight[start:end, :], max_w_scale_normalized = fp8_quantize(
             weight_dq, max_w_scale
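The scale handed to `requantize_with_max_scale` is now a per-row column vector rather than one scalar per logical shard, which is why the dequantization step indexes `weight_scale[start:end, :]` instead of `weight_scale[idx]`. Below is a rough standalone sketch of the max-scale requantization idea, using ordinary float tensors in place of fp8 storage and omitting the fp8 clamping the real code performs.

```python
import torch


def requantize_with_max_scale_sketch(weight, weight_scale, logical_widths):
    # weight: concatenated shards, one row block per logical width.
    # weight_scale: per-row scales, shape [total_rows, 1].
    # Dequantize each shard with its own rows' scales, then requantize the
    # whole tensor against the single largest scale so fused projections
    # (e.g. q/k/v) can share one scale.
    max_w_scale = weight_scale.max()
    start = 0
    for logical_width in logical_widths:
        end = start + logical_width
        weight_dq = weight[start:end, :] * weight_scale[start:end, :]
        weight[start:end, :] = weight_dq / max_w_scale
        start = end
    return weight, max_w_scale


w = torch.randn(6, 4)
scales = torch.tensor([0.5, 0.5, 0.5, 2.0, 2.0, 2.0]).unsqueeze(-1)
print(requantize_with_max_scale_sketch(w, scales, logical_widths=[3, 3])[1])  # tensor(2.)
```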
@@ -270,6 +270,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 )
             # FP8 branch
             scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            scale = scale.reshape(-1).expand(w.shape[0])
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@@ -278,10 +283,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     .reshape(-1)
                     .max()
                 )
-            logical_widths = [w.shape[0]]
-            w, scale = requantize_with_max_scale(
-                w, scale.unsqueeze(0), logical_widths, weights.dtype
-            )

         return Fp8Weight(
             weight=w,
@@ -316,6 +317,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     block_sizes=block_sizes,
                     to_dtype=False,
                 )
+            scale = scale.reshape(-1).expand(w.shape[0])
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@@ -330,10 +336,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                         to_dtype=False,
                     )
                 input_scale = input_scale.reshape(-1).max()
-            logical_widths = [w.shape[0]]
-            w, scale = requantize_with_max_scale(
-                w, scale.unsqueeze(0), logical_widths, weights.dtype
-            )

         return Fp8Weight(
             weight=w,
@@ -380,6 +382,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

             input_scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                 for p, shape in zip(prefixes, shapes)
@@ -392,11 +399,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 else None
             )

-            logical_widths = [x[0] for x in shapes]
-            w, scale = requantize_with_max_scale(
-                w, scale.to(weights.device), logical_widths, weights.dtype
-            )
-
         return Fp8Weight(
             weight=w,
             weight_scale=scale,
@@ -435,11 +437,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
             )

             scale = [
-                weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1)
-                for p in prefixes
+                weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(shape[0])
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

             input_scale = [
                 weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
                 for p in prefixes
@@ -452,11 +461,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 else None
             )

-            logical_widths = [x[0] for x in shapes]
-            w, scale = requantize_with_max_scale(
-                w, scale.to(weights.device), logical_widths, weights.dtype
-            )
-
         return Fp8Weight(
             weight=w,
             weight_scale=scale,
@@ -485,7 +489,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     weight_block_size=self.weight_block_size,
                 )

-            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@@ -494,10 +506,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     .reshape(-1)
                     .max()
                 )
-            logical_widths = [w.shape[0]]
-            w, scale = requantize_with_max_scale(
-                w, scale.unsqueeze(0), logical_widths, weights.dtype
-            )
+
         return Fp8Weight(
             weight=w,
             weight_scale=scale,
@@ -615,45 +624,32 @@ class Fp8Linear(torch.nn.Module):
             weight_block_size=weight_block_size,
         )

-    @classmethod
-    def get_shared_device_identity(cls, device):
-        # Input scaling factors are no longer optional in _scaled_mm starting
-        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-        if device not in cls._device_identity_cache:
-            cls._device_identity_cache[device] = torch.ones(1, device=device)
-        return cls._device_identity_cache[device]
-
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.weight_block_size is not None:
+        if self.weight_block_size is not None or self.input_scale is None:
             return apply_block_fp8_linear_hpu_dynamic(
                 input, self.qweight, self.scale, self.input_scale, self.bias
             )

-        qinput, scale = fp8_quantize(
-            input,
-            self.input_scale,
-            scale_upper_bound=self.scale_upper_bound,
-            scalar=True,
-        )
-
-        output = torch._scaled_mm(
-            qinput,
-            self.qweight.t(),
-            out_dtype=self.dtype,
-            scale_a=scale,
-            scale_b=self.scale,
-            bias=self.bias,
-        )
-
-        if isinstance(output, tuple) and len(output) == 2:
-            output = output[0]
-
-        return output
+        x_fp8 = torch.ops.hpu.cast_to_fp8_v2(
+            input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn
+        )[0]
+        return torch.ops.hpu.fp8_gemm_v2(
+            A=x_fp8,
+            trans_A=False,
+            B=self.qweight,
+            trans_B=True,
+            D=None,
+            out_dtype=input.dtype,
+            A_scale_inv=self.input_scale,
+            B_scale_inv=self.scale,
+            bias=self.bias,
+            accumulate=False,
+        )


 def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
     scale = weights.get_tensor(prefix, to_dtype=False)

     if scale.numel() > 1:
         scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
-    return scale.reshape(-1)
+    return scale.reshape(-1).expand(shape[0])
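On Gaudi, the static-scale path now casts the activation with the calibrated input scale and calls the HPU fp8 GEMM directly. As a rough, device-agnostic reference for what such a static W8A8 linear computes (emulated in plain float math; this is not the Habana ops' exact semantics or signature):

```python
import torch


def static_fp8_linear_reference(x, w_q, w_scale, x_scale, bias=None):
    # Emulated static-scale W8A8 linear: quantize the activation with a fixed,
    # calibrated scale, matmul against the already-quantized weight, then fold
    # both dequantization scales back into the output. Real fp8 storage and
    # the HPU kernels are replaced by ordinary tensors here.
    x_q = torch.clamp(x / x_scale, -448.0, 448.0)  # e4m3 dynamic range
    y = (x_q @ w_q.t()) * (x_scale * w_scale)
    return y + bias if bias is not None else y


x = torch.randn(2, 8)
w_q = torch.randn(16, 8)  # stands in for the fp8 weight payload
y = static_fp8_linear_reference(x, w_q, w_scale=0.02, x_scale=0.5)
print(y.shape)  # torch.Size([2, 16])
```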
@@ -276,6 +276,63 @@ class GPTQWeightsLoader(WeightsLoader):
             use_exllama=use_exllama,
         )

+    def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim)
+        try:
+            qweight = torch.cat(
+                [weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1
+            )
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
+            )
+
+        scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1)
+
+        self._get_gptq_params(weights)
+
+        qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1)
+
+        use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
+
+        if self.quantize == "gptq" and self.quant_method == "gptq":
+            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+        elif self.quantize == "gptq" and self.quant_method == "awq":
+            log_once(
+                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+            )
+            from text_generation_server.layers.awq.conversion_utils import (
+                fast_awq_to_gptq,
+            )
+
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            if use_exllama:
+                g_idx = None
+            else:
+                g_idx = (
+                    torch.arange(
+                        qweight.shape[0] * (32 // self.bits),
+                        device=qweight.device,
+                    )
+                ).to(dtype=torch.int32)
+        else:
+            g_idx = None
+
+        return GPTQWeight(
+            qweight=qweight,
+            qzeros=qzeros,
+            scales=scales,
+            g_idx=g_idx,
+            bits=self.bits,
+            groupsize=self.groupsize,
+            use_awq_kernel=self.quantize == "awq",
+            use_exllama=use_exllama,
+        )
+
     def get_weights_row(self, weights: Weights, prefix: str):
         self._get_gptq_params(weights)

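The fused load concatenates the packed GPTQ tensors of several projections along `dim=1` because, with 4-bit values packed into int32, `qweight` is laid out roughly as `[in_features // 8, out_features]`, so stacking projections grows the output axis. A toy shape check with hypothetical sizes:

```python
import torch

# Hypothetical packed qweight shapes for three projections with 4-bit GPTQ
# packing (8 nibbles per int32 along the input dimension).
in_features, out_features, pack_factor = 128, 256, 8
q = torch.zeros(in_features // pack_factor, out_features, dtype=torch.int32)
k = torch.zeros_like(q)
v = torch.zeros_like(q)

fused = torch.cat([q, k, v], dim=1)
print(fused.shape)  # torch.Size([16, 768])
```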
@@ -1412,7 +1412,6 @@ class FlashCausalLM(Model):
             aliases=aliases,
             weights_loader=weights_loader,
         )
-        print(f"weights: {weights}")

         prefix = None
         model = model_class(prefix, config, weights)
@@ -122,6 +122,13 @@ def _get_quantizer_config(model_id, revision):
 def get_loader(
     quantize: Optional[str], model_id: str, revision: Optional[str]
 ) -> WeightsLoader:
+    if quantize == "compressed-tensors":
+        config = _get_config_json(model_id, revision, "config.json")
+        from text_generation_server.layers.compressed_tensors import (
+            CompressedTensorsLoader,
+        )
+
+        return CompressedTensorsLoader(config)
     quantizer_config = _get_quantizer_config(model_id, revision)
     if quantize in {"awq", "gptq"}:
         from text_generation_server.layers.gptq import GPTQWeightsLoader
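With this dispatch in place, a compressed-tensors checkpoint only needs its `config.json` to be parsed and handed to the loader. A hypothetical usage sketch, assuming the server package is installed and the config file is available locally:

```python
import json

from text_generation_server.layers.compressed_tensors import CompressedTensorsLoader

with open("config.json") as f:
    config = json.load(f)

# Raises if the config has no quantization_config/compression_config block.
loader = CompressedTensorsLoader(config)
```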
@@ -162,6 +162,11 @@ impl Allocator for SimpleAllocator {
         tokens: u32,
         _prefill_tokens: Option<Arc<Vec<u32>>>,
     ) -> Option<BlockAllocation> {
+        let mut tokens = tokens;
+        if self.is_hpu_device {
+            // need 1 slot for ping-pong optimization
+            tokens += 1;
+        }
         // Apply window size
         let (required_blocks, repeats) = {
             let (tokens, repeats) = match self.window_size {
@@ -176,8 +181,7 @@ impl Allocator for SimpleAllocator {
             let required_blocks = tokens.div_ceil(self.block_size);
             (required_blocks, repeats)
         };
-
-        let mut tokens = tokens as usize;
+        let tokens = tokens as usize;
         if required_blocks > self.free_blocks.len() as u32 {
             None
         } else {
@@ -189,8 +193,6 @@ impl Allocator for SimpleAllocator {
                 .split_off(self.free_blocks.len() - required_blocks as usize);
             if self.is_hpu_device {
                 blocks.sort();
-                // need 1 slot for ping-pong optimization
-                tokens += 1;
             }
             let mut slots =
                 Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
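Moving the extra HPU slot to the top means it is counted before the block math rather than after the blocks have been carved out. A small sketch of the resulting sizing rule:

```python
import math


def required_blocks(tokens: int, block_size: int, is_hpu_device: bool) -> int:
    # Mirrors the allocator change: on HPU one extra slot is reserved up front
    # (for the ping-pong optimization) before rounding up to whole blocks.
    if is_hpu_device:
        tokens += 1
    return math.ceil(tokens / block_size)


print(required_blocks(tokens=128, block_size=16, is_hpu_device=True))  # 9
```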