Merge branch 'main' into qwen3_moe

commit 5155fef477
Yuan Wu, 2025-05-29 13:05:31 +08:00, committed by GitHub
12 changed files with 581 additions and 162 deletions

View File

@@ -99,6 +99,7 @@ RUN cd server && \
    BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
    pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
+RUN pip install compressed-tensors==0.9.1

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark

View File

@@ -19,6 +19,7 @@ class Quantization(str, Enum):
    gptq = "gptq"
    awq = "awq"
    fp8 = "fp8"
+   compressed_tensors = "compressed-tensors"


class Dtype(str, Enum):
@@ -109,6 +110,7 @@ def serve(
        "gptq",
        "awq",
        "fp8",
+       "compressed-tensors",
    }:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."

View File

@@ -0,0 +1,3 @@
from .loader import CompressedTensorsLoader

__all__ = ["CompressedTensorsLoader"]

View File

@@ -0,0 +1,169 @@
from typing import Any, Dict, List, Union

from compressed_tensors import QuantizationConfig, QuantizationStatus
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
    QuantizationScheme,
    QuantizationType,
    find_name_or_class_matches,
)
from loguru import logger
from pydantic import ValidationError
from torch import nn

from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    UnquantizedWeight,
    Weights,
    WeightsLoader,
)

# compressed-tensors can match modules as quantization targets. However,
# they need to be objects rather than classes or class names. Since we
# need to match `Linear` targets, make an instance that can be re-used.
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)


class CompressedTensorsLoader(WeightsLoader):
    """Loader for checkpoints stored in the compressed-tensors format."""

    def __init__(self, config: Dict[str, Any]):
        quantization_config_raw = config.get("quantization_config")
        if quantization_config_raw is None:
            # `compression_config` was renamed to `quantization_config`; support
            # retained for backward compatibility.
            quantization_config_raw = config.get("compression_config")
        if quantization_config_raw is None:
            raise ValueError(
                "Checkpoint does not have compressed-tensors configuration"
            )

        try:
            quantization_config = QuantizationConfig.model_validate(
                quantization_config_raw
            )
        except ValidationError as e:
            raise ValueError("Cannot parse compressed-tensors configuration") from e

        if quantization_config.quantization_status not in (
            QuantizationStatus.COMPRESSED,
            QuantizationStatus.FROZEN,
        ):
            raise ValueError(
                f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
            )

        self.ignore = (
            quantization_config.ignore if quantization_config.ignore is not None else []
        )
        self.loaders = self._get_target_loaders(quantization_config)

        for target, loader in self.loaders.items():
            log_once(
                logger.info,
                f"Using {loader} for compressed-tensors target '{target}'",
            )

    def get_weights(self, weights: Weights, prefix: str):
        loader = self._lookup_loader(prefix)
        return loader.get_weights(weights, prefix)

    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        loader = self._lookup_loader(prefix)
        return loader.get_weights_col_packed(weights, prefix, block_sizes)

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        loader = self._lookup_loader(prefixes[0])
        return loader.get_multi_weights_col(weights, prefixes, dim)

    def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
        loader = self._lookup_loader(prefixes[0])
        return loader.get_multi_weights(weights, prefixes, dim)

    def get_weights_row(self, weights: Weights, prefix: str):
        loader = self._lookup_loader(prefix)
        return loader.get_weights_row(weights, prefix)

    def _get_target_loaders(
        self, quantization_config: QuantizationConfig
    ) -> Dict[str, WeightsLoader]:
        """
        A compressed-tensors checkpoint can use different quantizations
        for different targets. This method returns a dictionary with a
        loader per target.
        """
        loaders: Dict[str, WeightsLoader] = {}

        format = quantization_config.format

        for group_name, group in quantization_config.config_groups.items():
            # The group configuration can be a string, but does that ever
            # happen in a serialized quantization config?
            assert isinstance(group, QuantizationScheme)

            loader = self._create_loader_for_group(format, group_name, group)

            # A quantized parameter group can have multiple targets, add the
            # loader for all the targets.
            for target in group.targets:
                if target in loaders:
                    raise ValueError(
                        f"Target '{target}' has multiple configured loaders"
                    )
                loaders[target] = loader

        return loaders

    def _create_loader_for_group(
        self, format: str, group_name: str, group: QuantizationScheme
    ) -> WeightsLoader:
        """
        Find and create a loader for the group with the given quantization
        scheme.
        """
        # NOTE: we ignore group.output_activations because we don't support
        # output quantization yet.
        input_activations = group.input_activations
        weights = group.weights
        if (
            format
            in {
                CompressionFormat.float_quantized.value,
                CompressionFormat.naive_quantized.value,
            }
            and weights is not None
            and weights.type == QuantizationType.FLOAT
            and weights.num_bits == 8
        ):
            # FP W8A8 or W8A16.
            return W8ANFpLoader(input_activations=input_activations, weights=weights)
        else:
            raise ValueError(
                f"Group '{group_name}' has unsupported compressed-tensors configuration"
            )

    def _lookup_loader(self, prefix: str) -> WeightsLoader:
        """
        Look up the loader to use for a given parameter name (prefix).
        """
        if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
            return DefaultWeightsLoader(UnquantizedWeight)

        # We currently only handle linear layers, so unconditionally pass
        # a `Linear` instance.
        targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
        if len(targets) == 0:
            raise ValueError(
                f"Cannot find compressed-tensors target for prefix: {prefix}"
            )
        return self.loaders[targets[0]]
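
For reviewers unfamiliar with compressed-tensors target matching, the sketch below (not part of this commit) shows how a config's `config_groups`, `targets`, and `ignore` lists resolve a parameter prefix, reusing the same `QuantizationConfig.model_validate` and `find_name_or_class_matches` calls as the loader above. The sample config dict is hypothetical; a real one comes from the checkpoint's config.json.

from compressed_tensors import QuantizationConfig
from compressed_tensors.quantization import find_name_or_class_matches
from torch import nn

# Illustrative FP8 config mirroring the fields the loader reads (assumed values).
sample_config = {
    "format": "float-quantized",
    "quantization_status": "compressed",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"num_bits": 8, "type": "float", "dynamic": False, "symmetric": True},
            "input_activations": {"num_bits": 8, "type": "float", "dynamic": False, "symmetric": True},
        }
    },
}

qconfig = QuantizationConfig.model_validate(sample_config)
linear = nn.Linear(0, 0)  # stand-in module: matching is done against the `Linear` class name

for prefix in ["model.layers.0.self_attn.q_proj", "lm_head"]:
    if find_name_or_class_matches(prefix, linear, qconfig.ignore or []):
        print(f"{prefix}: ignored, loaded unquantized")
        continue
    for group_name, group in qconfig.config_groups.items():
        if find_name_or_class_matches(prefix, linear, group.targets):
            print(f"{prefix}: quantized by {group_name} (targets={group.targets})")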

View File

@@ -0,0 +1,253 @@
from typing import List, Optional, Union

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType

from text_generation_server.layers.fp8 import (
    Fp8Weight,
    _load_scalar_or_matrix_scale,
    requantize_with_max_scale,
)
from text_generation_server.utils.weights import Weights, WeightsLoader


class W8ANFpLoader(WeightsLoader):
    """
    Loader for W8A8/W8A16 FP compressed-tensors parameters.
    """

    def __init__(
        self,
        *,
        input_activations: Optional[QuantizationArgs],
        weights: QuantizationArgs,
    ):
        assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8

        # We ignore the `strategy` option which sets the scales to be
        # per-tensor, per-channel or per-token. What scales are supported
        # is dependent on the kernels used (e.g. cutlass can do tokenwise,
        # Torch cannot, and FP8-Marlin does not quantize inputs at all).
        # So, instead we try to use the best-possible configuration.
        self.load_weight_scale = not weights.dynamic
        self.load_input_scale = (
            input_activations is not None and not input_activations.dynamic
        )
        self.force_w8a16 = (
            input_activations is not None and input_activations.num_bits == 16
        )

    def __str__(self) -> str:
        def scale_to_str(scale):
            return "static" if scale else "dynamic"

        quantization_type = f"W8A{16 if self.force_w8a16 else 8}"

        return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})"

    def get_weights(self, weights: "Weights", prefix: str):
        w = weights.get_tensor(f"{prefix}.weight")

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(w.shape[0])
            )
            logical_widths = [w.shape[0]]
            w, weight_scale = requantize_with_max_scale(
                w,
                weight_scale.unsqueeze(-1).to(weights.device),
                logical_widths,
                weights.dtype,
            )

        input_scale = None
        if self.load_input_scale:
            input_scale = weights.get_tensor(
                f"{prefix}.input_scale", to_dtype=False
            ).reshape(-1)

        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            if weight_scale.numel() > 1:
                weight_scale = weights.get_packed_sharded(
                    f"{prefix}.weight_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
            logical_widths = [w.shape[0]]
            w, weight_scale = requantize_with_max_scale(
                w,
                weight_scale.unsqueeze(-1).to(weights.device),
                logical_widths,
                weights.dtype,
            )

        input_scale = None
        if self.load_input_scale:
            input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
            if input_scale.numel() > 1:
                input_scale = weights.get_packed_sharded(
                    f"{prefix}.input_scale",
                    dim=0,
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
            input_scale = input_scale.reshape(-1).max()

        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        w = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
            logical_widths = [x[0] for x in shapes]
            w, weight_scale = requantize_with_max_scale(
                w,
                weight_scale.unsqueeze(-1).to(weights.device),
                logical_widths,
                weights.dtype,
            )

        input_scale = None
        if self.load_input_scale:
            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )

    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
        shapes = [x.shape for x in w]

        # Concat then send to the device
        w = torch.cat(w, dim=dim).to(weights.device)

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = [
                weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(shape[0])
                for p, shape in zip(prefixes, shapes)
            ]
            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
            logical_widths = [x[0] for x in shapes]
            w, weight_scale = requantize_with_max_scale(
                w,
                weight_scale.unsqueeze(-1).to(weights.device),
                logical_widths,
                weights.dtype,
            )

        input_scale = None
        if self.load_input_scale:
            input_scale = [
                weights.get_tensor(f"{p}.input_scale", to_dtype=False)
                .reshape(-1)
                .expand(shape[0])
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
            input_scale = (
                torch.cat(input_scale, dim=0).reshape(-1).max()
                if len(input_scale) != 0
                else None
            )

        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )

    def get_weights_row(self, weights: "Weights", prefix: str):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)

        weight_scale = None
        if self.load_weight_scale:
            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
            logical_widths = [w.shape[0]]
            w, weight_scale = requantize_with_max_scale(
                w,
                weight_scale.unsqueeze(-1).to(weights.device),
                logical_widths,
                weights.dtype,
            )

        input_scale = None
        if self.load_input_scale:
            input_scale = weights.get_tensor(
                f"{prefix}.input_scale", to_dtype=False
            ).reshape(-1)

        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )
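
For context on the repeated `requantize_with_max_scale` calls above: when shards quantized with different per-tensor FP8 scales are fused (e.g. q/k/v), each shard is dequantized with its own scale and requantized with the shared maximum scale. The snippet below is a rough standalone sketch of that idea, assuming a PyTorch build with float8 dtypes; the real helper lives in `text_generation_server.layers.fp8` and handles more cases.

import torch


def requantize_with_max_scale_sketch(weight_fp8, shard_scales, logical_widths):
    # One scale per logical shard; the fused tensor keeps the largest of them.
    max_scale = shard_scales.max()
    finfo = torch.finfo(torch.float8_e4m3fn)
    out = torch.empty_like(weight_fp8)
    start = 0
    for idx, width in enumerate(logical_widths):
        end = start + width
        # Dequantize the shard with its own scale ...
        dq = weight_fp8[start:end].to(torch.float32) * shard_scales[idx]
        # ... then requantize with the shared (max) scale, clamping to FP8 range.
        out[start:end] = (dq / max_scale).clamp(finfo.min, finfo.max).to(
            torch.float8_e4m3fn
        )
        start = end
    return out, max_scale


# Toy usage: a fused [q; k; v] weight with per-shard scales.
w = torch.randn(8, 4).to(torch.float8_e4m3fn)
scales = torch.tensor([0.01, 0.02, 0.04])
fused, shared_scale = requantize_with_max_scale_sketch(w, scales, logical_widths=[4, 2, 2])
print(shared_scale)  # the largest per-shard scale, 0.04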

View File

@@ -207,7 +207,7 @@ def requantize_with_max_scale(
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_dq = per_tensor_dequantize(
-            weight[start:end, :], weight_scale[idx], dtype
+            weight[start:end, :], weight_scale[start:end, :], dtype
        )
        weight[start:end, :], max_w_scale_normalized = fp8_quantize(
            weight_dq, max_w_scale
@@ -270,6 +270,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
            )
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            scale = scale.reshape(-1).expand(w.shape[0])
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
@@ -278,10 +283,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                    .reshape(-1)
                    .max()
                )
-            logical_widths = [w.shape[0]]
-            w, scale = requantize_with_max_scale(
-                w, scale.unsqueeze(0), logical_widths, weights.dtype
-            )

            return Fp8Weight(
                weight=w,
@@ -316,6 +317,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
                    block_sizes=block_sizes,
                    to_dtype=False,
                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
@@ -330,10 +336,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                        to_dtype=False,
                    )
                input_scale = input_scale.reshape(-1).max()
-            logical_widths = [w.shape[0]]
-            w, scale = requantize_with_max_scale(
-                w, scale.unsqueeze(0), logical_widths, weights.dtype
-            )

            return Fp8Weight(
                weight=w,
@@ -380,6 +382,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

            input_scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
@@ -392,11 +399,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                else None
            )
-            logical_widths = [x[0] for x in shapes]
-            w, scale = requantize_with_max_scale(
-                w, scale.to(weights.device), logical_widths, weights.dtype
-            )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
@@ -435,11 +437,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
            )
            scale = [
-                weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1)
-                for p in prefixes
+                weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(shape[0])
+                for p, shape in zip(prefixes, shapes)
            ]
            scale = torch.cat(scale, dim=0).reshape(-1)
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

            input_scale = [
                weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
                for p in prefixes
@@ -452,11 +461,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
                else None
            )
-            logical_widths = [x[0] for x in shapes]
-            w, scale = requantize_with_max_scale(
-                w, scale.to(weights.device), logical_widths, weights.dtype
-            )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
@@ -485,7 +489,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
                weight_block_size=self.weight_block_size,
            )
-            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
+            logical_widths = [w.shape[0]]
+            w, scale = requantize_with_max_scale(
+                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
+            )

            input_scale = None
            if weights.has_tensor(f"{prefix}.input_scale"):
@@ -494,10 +506,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                    .reshape(-1)
                    .max()
                )
-            logical_widths = [w.shape[0]]
-            w, scale = requantize_with_max_scale(
-                w, scale.unsqueeze(0), logical_widths, weights.dtype
-            )

            return Fp8Weight(
                weight=w,
                weight_scale=scale,
@@ -615,45 +624,32 @@ class Fp8Linear(torch.nn.Module):
            weight_block_size=weight_block_size,
        )

-    @classmethod
-    def get_shared_device_identity(cls, device):
-        # Input scaling factors are no longer optional in _scaled_mm starting
-        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-        if device not in cls._device_identity_cache:
-            cls._device_identity_cache[device] = torch.ones(1, device=device)
-        return cls._device_identity_cache[device]
-
    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.weight_block_size is not None:
+        if self.weight_block_size is not None or self.input_scale is None:
            return apply_block_fp8_linear_hpu_dynamic(
                input, self.qweight, self.scale, self.input_scale, self.bias
            )

-        qinput, scale = fp8_quantize(
-            input,
-            self.input_scale,
-            scale_upper_bound=self.scale_upper_bound,
-            scalar=True,
-        )
-        output = torch._scaled_mm(
-            qinput,
-            self.qweight.t(),
-            out_dtype=self.dtype,
-            scale_a=scale,
-            scale_b=self.scale,
-            bias=self.bias,
-            accumulate=False,
-        )
-        if isinstance(output, tuple) and len(output) == 2:
-            output = output[0]
-        return output
+        x_fp8 = torch.ops.hpu.cast_to_fp8_v2(
+            input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn
+        )[0]
+        return torch.ops.hpu.fp8_gemm_v2(
+            A=x_fp8,
+            trans_A=False,
+            B=self.qweight,
+            trans_B=True,
+            D=None,
+            out_dtype=input.dtype,
+            A_scale_inv=self.input_scale,
+            B_scale_inv=self.scale,
+            bias=self.bias,
+        )


def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
    scale = weights.get_tensor(prefix, to_dtype=False)
    if scale.numel() > 1:
        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
-    return scale.reshape(-1)
+    return scale.reshape(-1).expand(shape[0])

View File

@@ -276,6 +276,63 @@ class GPTQWeightsLoader(WeightsLoader):
            use_exllama=use_exllama,
        )

+    def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim)
+        try:
+            qweight = torch.cat(
+                [weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1
+            )
+        except RuntimeError:
+            raise RuntimeError(
+                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
+            )
+
+        scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1)
+
+        self._get_gptq_params(weights)
+
+        qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1)
+
+        use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
+
+        if self.quantize == "gptq" and self.quant_method == "gptq":
+            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
+            for w2 in w[1:]:
+                torch.testing.assert_close(w2, w[0])
+            g_idx = w[0]
+        elif self.quantize == "gptq" and self.quant_method == "awq":
+            log_once(
+                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+            )
+            from text_generation_server.layers.awq.conversion_utils import (
+                fast_awq_to_gptq,
+            )
+
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            if use_exllama:
+                g_idx = None
+            else:
+                g_idx = (
+                    torch.arange(
+                        qweight.shape[0] * (32 // self.bits),
+                        device=qweight.device,
+                    )
+                ).to(dtype=torch.int32)
+        else:
+            g_idx = None
+
+        return GPTQWeight(
+            qweight=qweight,
+            qzeros=qzeros,
+            scales=scales,
+            g_idx=g_idx,
+            bits=self.bits,
+            groupsize=self.groupsize,
+            use_awq_kernel=self.quantize == "awq",
+            use_exllama=use_exllama,
+        )
+
    def get_weights_row(self, weights: Weights, prefix: str):
        self._get_gptq_params(weights)

View File

@ -309,7 +309,6 @@ class ModelType(enum.Enum):
"name": "Qwen 3 Moe", "name": "Qwen 3 Moe",
"url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f", "url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
} }
GALACTICA = { GALACTICA = {
"type": "galactica", "type": "galactica",
"name": "Galactica", "name": "Galactica",
@ -832,7 +831,6 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
elif model_type == MLLAMA: elif model_type == MLLAMA:
return FlashMllamaCausalLM( return FlashMllamaCausalLM(
model_id=model_id, model_id=model_id,

View File

@@ -22,6 +22,7 @@ import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F

+import habana_frameworks.torch as htorch
from transformers.cache_utils import Cache
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@@ -567,6 +568,9 @@ class Llama4TextModel(nn.Module):
        )
        freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))

+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states,
@@ -582,6 +586,8 @@ class Llama4TextModel(nn.Module):
                position_ids=position_ids,
                hpu_attention_meta=hpu_attention_meta,
            )
+            if lazy_mode:
+                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states)

View File

@@ -14,6 +14,7 @@ from typing import Optional, Tuple, List
import torch
from torch import nn

+import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
@@ -26,14 +27,7 @@ from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    SpeculativeHead,
-    FastLinear,
)
-from text_generation_server.utils.import_utils import (
-    synchronize,
-    get_free_memory,
-)
-from loguru import logger
-from text_generation_server.utils.log import log_master
from text_generation_server.layers.layernorm import (
@@ -53,12 +47,9 @@ class Qwen3Attention(nn.Module):
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
-        config.num_key_value_heads = getattr(
-            config, "num_key_value_heads", config.num_attention_heads
-        )
-        # self.num_key_value_groups = (
-        #     config.num_attention_heads // config.num_key_value_heads
-        # )
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
        self.num_heads = config.num_attention_heads
        self.attention_dropout = config.attention_dropout
        self.softmax_scale = self.head_dim**-0.5
@@ -75,65 +66,16 @@ class Qwen3Attention(nn.Module):
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
-        # self.num_key_value_heads = config.num_key_value_heads
-        if config.num_key_value_heads > weights.process_group.size():
-            self.num_key_value_heads = (
-                config.num_key_value_heads // weights.process_group.size()
-            )
-        else:
-            self.num_key_value_heads = config.num_key_value_heads
-        self.query_proj = TensorParallelColumnLinear.load(
-            config,
-            prefix=f"{prefix}.q_proj",
-            weights=weights,
-            bias=False,
-        )
-        if self.num_key_value_heads != config.num_key_value_heads:
-            self.key_proj = TensorParallelColumnLinear.load(
-                config,
-                prefix=f"{prefix}.k_proj",
-                weights=weights,
-                bias=False,
-            )
-            self.value_proj = TensorParallelColumnLinear.load(
-                config,
-                prefix=f"{prefix}.v_proj",
-                weights=weights,
-                bias=False,
-            )
-        else:
-            self.key_proj = FastLinear.load(
-                config,
-                prefix=f"{prefix}.k_proj",
-                weights=weights,
-                bias=False,
-            )
-            self.value_proj = FastLinear.load(
-                config,
-                prefix=f"{prefix}.v_proj",
-                weights=weights,
-                bias=False,
-            )
-        # self.key_proj = TensorParallelColumnLinear.load(
-        #     config,
-        #     prefix=f"{prefix}.k_proj",
-        #     weights=weights,
-        #     bias=False,
-        # )
-        # self.value_proj = TensorParallelColumnLinear.load(
-        #     config,
-        #     prefix=f"{prefix}.v_proj",
-        #     weights=weights,
-        #     bias=False,
-        # )
-        # self.query_key_value = TensorParallelColumnLinear.load_multi(
-        #     config,
-        #     prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        #     dim=0,
-        #     weights=weights,
-        #     bias=False,
-        # )
+        self.num_key_value_heads = (
+            config.num_key_value_heads // weights.process_group.size()
+        )
+        self.query_key_value = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=False,
+        )

        self.kv_scales = get_kv_scales(weights, f"{prefix}")
@@ -144,10 +86,11 @@ class Qwen3Attention(nn.Module):
            bias=False,
        )
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_key_value_groups)
+        ).repeat_interleave(self.num_groups)

        self.max_past = (
            config.sliding_window if config.sliding_window is not None else -1
@@ -182,45 +125,21 @@ class Qwen3Attention(nn.Module):
        seqlen,
        hpu_attention_meta,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        print(f"hidden_states shape: {hidden_states.shape}")
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
-        # qkv = self.query_key_value(hidden_states)
-        # print(f"qkv shape: {qkv.shape}")
-        # print(f"self.head_dim: {self.head_dim}")
-        # print(f"self.num_heads: {self.num_heads}")
-        # print(f"self.num_key_value_heads: {self.num_key_value_heads}")
-        # query_states, key_states, value_states = qkv.split(
-        #     [
-        #         self.head_dim * self.num_heads,
-        #         self.head_dim * self.num_key_value_heads,
-        #         self.head_dim * self.num_key_value_heads,
-        #     ],
-        #     dim=1,
-        # )
-        synchronize(hidden_states.device)
-        real_free_memory = get_free_memory(hidden_states.device, 1)
-        log_master(
-            logger.debug,
-            f"Attention forward1 Free memory real: {real_free_memory / 1e9:.2f}GB",
-        )
-        query_states = self.query_proj(hidden_states)
-        key_states = self.key_proj(hidden_states)
-        value_states = self.value_proj(hidden_states)
+        qkv = self.query_key_value(hidden_states)
+        query_states, key_states, value_states = qkv.split(
+            [
+                self.head_dim * self.num_heads,
+                self.head_dim * self.num_key_value_heads,
+                self.head_dim * self.num_key_value_heads,
+            ],
+            dim=1,
+        )

        query_states, _ = self.q_norm(query_states.view(hidden_shape))
        key_states, _ = self.k_norm(key_states.view(hidden_shape))
        value_states = value_states.view(hidden_shape)
-        print(f"query_states shape: {query_states.shape}")
-        print(f"key_states shape: {key_states.shape}")
-        print(f"value_states shape: {value_states.shape}")
-        synchronize(hidden_states.device)
-        real_free_memory = get_free_memory(hidden_states.device, 1)
-        log_master(
-            logger.debug,
-            f"Attention forward2 Free memory real: {real_free_memory / 1e9:.2f}GB",
-        )

        self.rotary_emb(query_states, key_states, cos, sin)

        kv_cache.store(
@@ -257,7 +176,6 @@ class Qwen3Attention(nn.Module):
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        print(f"attn_output shape: {attn_output.shape}")
        return self.o_proj(attn_output)
@@ -359,6 +277,11 @@ class Qwen3Model(nn.Module):
        )

        residual = None
+
+        lazy_mode = htorch.utils.internal.is_lazy()
+        if lazy_mode:
+            htorch.core.mark_step()
+
        for i, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                hidden_states,
@@ -371,6 +294,8 @@ class Qwen3Model(nn.Module):
                seqlen,
                hpu_attention_meta,
            )
+            if lazy_mode:
+                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states)
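
A small illustration (not part of the diff) of the fused q/k/v split that replaces the separate projections above, using made-up head counts; the real values come from the model config.

import torch

# Hypothetical shapes for illustration only.
num_heads, num_key_value_heads, head_dim, tokens = 8, 2, 64, 4
qkv = torch.randn(tokens, (num_heads + 2 * num_key_value_heads) * head_dim)

query_states, key_states, value_states = qkv.split(
    [
        head_dim * num_heads,
        head_dim * num_key_value_heads,
        head_dim * num_key_value_heads,
    ],
    dim=1,
)
assert query_states.shape == (tokens, head_dim * num_heads)
assert key_states.shape == value_states.shape == (tokens, head_dim * num_key_value_heads)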

View File

@@ -122,6 +122,13 @@ def _get_quantizer_config(model_id, revision):
def get_loader(
    quantize: Optional[str], model_id: str, revision: Optional[str]
) -> WeightsLoader:
+    if quantize == "compressed-tensors":
+        config = _get_config_json(model_id, revision, "config.json")
+        from text_generation_server.layers.compressed_tensors import (
+            CompressedTensorsLoader,
+        )
+
+        return CompressedTensorsLoader(config)
    quantizer_config = _get_quantizer_config(model_id, revision)
    if quantize in {"awq", "gptq"}:
        from text_generation_server.layers.gptq import GPTQWeightsLoader

View File

@@ -162,6 +162,11 @@ impl Allocator for SimpleAllocator {
        tokens: u32,
        _prefill_tokens: Option<Arc<Vec<u32>>>,
    ) -> Option<BlockAllocation> {
+        let mut tokens = tokens;
+        if self.is_hpu_device {
+            // need 1 slot for ping-pong optimization
+            tokens += 1;
+        }
        // Apply window size
        let (required_blocks, repeats) = {
            let (tokens, repeats) = match self.window_size {
@@ -176,8 +181,7 @@ impl Allocator for SimpleAllocator {
            let required_blocks = tokens.div_ceil(self.block_size);
            (required_blocks, repeats)
        };
-        let mut tokens = tokens as usize;
+        let tokens = tokens as usize;
        if required_blocks > self.free_blocks.len() as u32 {
            None
        } else {
@@ -189,8 +193,6 @@ impl Allocator for SimpleAllocator {
                .split_off(self.free_blocks.len() - required_blocks as usize);
            if self.is_hpu_device {
                blocks.sort();
-                // need 1 slot for ping-pong optimization
-                tokens += 1;
            }
            let mut slots =
                Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
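
To see why the extra HPU slot now has to be reserved before the block count is computed rather than after, here is the arithmetic as a small Python sketch (illustrative only; the allocator itself is the Rust code above):

def required_blocks(tokens: int, block_size: int, is_hpu: bool) -> int:
    if is_hpu:
        tokens += 1  # spare slot for the ping-pong optimization
    return -(-tokens // block_size)  # ceiling division, like u32::div_ceil

# When the token count lands exactly on a block boundary, adding the extra slot
# after computing the block count (the old placement) would under-allocate.
assert required_blocks(128, 128, is_hpu=False) == 1
assert required_blocks(128, 128, is_hpu=True) == 2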