diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index e7e7031b..4bda27cf 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -99,6 +99,7 @@ RUN cd server && \ BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \ pip install . --no-cache-dir RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix +RUN pip install compressed-tensors==0.9.1 # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index 59e8a96f..8fa2b263 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -19,6 +19,7 @@ class Quantization(str, Enum): gptq = "gptq" awq = "awq" fp8 = "fp8" + compressed_tensors = "compressed-tensors" class Dtype(str, Enum): @@ -109,6 +110,7 @@ def serve( "gptq", "awq", "fp8", + "compressed-tensors", }: raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py new file mode 100644 index 00000000..507af706 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py @@ -0,0 +1,3 @@ +from .loader import CompressedTensorsLoader + +__all__ = ["CompressedTensorsLoader"] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py new file mode 100644 index 00000000..0dccf34a --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py @@ -0,0 +1,169 @@ +from typing import Any, Dict, List, Union + +from compressed_tensors import QuantizationConfig, QuantizationStatus +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import ( + QuantizationScheme, + QuantizationType, + find_name_or_class_matches, +) +from loguru import logger +from pydantic import ValidationError +from torch import nn + +from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + UnquantizedWeight, + Weights, + WeightsLoader, +) + +# compressed-tensors can match modules as quantization targets. However, +# they need to be objects rather than classes or class names. Since we +# need to match `Linear` targets, make an instance that can be re-used. +_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0) + + +class CompressedTensorsLoader(WeightsLoader): + """Loader for checkpoints stored in the compressed-tensors format.""" + + def __init__(self, config: Dict[str, Any]): + quantization_config_raw = config.get("quantization_config") + if quantization_config_raw is None: + # `compression_config` was renamed to `quantization_config`; support + # retained for backward compatibility. + quantization_config_raw = config.get("compression_config") + if quantization_config_raw is None: + raise ValueError( + "Checkpoint does not have compressed-tensors configuration" + ) + + try: + quantization_config = QuantizationConfig.model_validate( + quantization_config_raw + ) + except ValidationError as e: + raise ValueError("Cannot parse compressed-tensors configuration") from e + + if quantization_config.quantization_status not in ( + QuantizationStatus.COMPRESSED, + QuantizationStatus.FROZEN, + ): + raise ValueError( + f"Model quantization was not finished, status was: {quantization_config.quantization_status}" + ) + + self.ignore = ( + quantization_config.ignore if quantization_config.ignore is not None else [] + ) + self.loaders = self._get_target_loaders(quantization_config) + + for target, loader in self.loaders.items(): + log_once( + logger.info, + f"Using {loader} for compressed-tensors target '{target}'", + ) + + def get_weights(self, weights: Weights, prefix: str): + loader = self._lookup_loader(prefix) + return loader.get_weights(weights, prefix) + + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + loader = self._lookup_loader(prefix) + return loader.get_weights_col_packed(weights, prefix, block_sizes) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + loader = self._lookup_loader(prefixes[0]) + return loader.get_multi_weights_col(weights, prefixes, dim) + + def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int): + loader = self._lookup_loader(prefixes[0]) + return loader.get_multi_weights(weights, prefixes, dim) + + def get_weights_row(self, weights: Weights, prefix: str): + loader = self._lookup_loader(prefix) + return loader.get_weights_row(weights, prefix) + + def _get_target_loaders( + self, quantization_config: QuantizationConfig + ) -> Dict[str, WeightsLoader]: + """ + A compressed-tensors checkpoint can use different quantizations + for different targets. This method returns a dictionary with a + loader per target. + """ + + loaders: Dict[str, WeightsLoader] = {} + + format = quantization_config.format + + for group_name, group in quantization_config.config_groups.items(): + # The group configuration can be a string, but does that ever + # happen in a serialized quantization config? + assert isinstance(group, QuantizationScheme) + + loader = self._create_loader_for_group(format, group_name, group) + + # A quantized parameter group can have multiple targets, add the + # loader for all the targets. + for target in group.targets: + if target in loaders: + raise ValueError( + f"Target '{target} has multiple configured loaders'" + ) + loaders[target] = loader + + return loaders + + def _create_loader_for_group( + self, format: str, group_name: str, group: QuantizationScheme + ) -> WeightsLoader: + """ + Find and create a loader for the group with the given quantization + scheme. + """ + # NOTE: we ignore group.output_activations because we don't support + # output quantization yet. + + input_activations = group.input_activations + weights = group.weights + if ( + format + in { + CompressionFormat.float_quantized.value, + CompressionFormat.naive_quantized.value, + } + and weights is not None + and weights.type == QuantizationType.FLOAT + and weights.num_bits == 8 + ): + # FP W8A8 or W8A16. + return W8ANFpLoader(input_activations=input_activations, weights=weights) + else: + raise ValueError( + f"Group '{group_name}' has unsupported compressed-tensors configurtion" + ) + + def _lookup_loader(self, prefix: str) -> WeightsLoader: + """ + Look up the loader to use for a given parameter name (prefix). + """ + + if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0: + return DefaultWeightsLoader(UnquantizedWeight) + + # We currently only handle linear layers, so unconditionally pass + # a `Linear` instance. + targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys()) + if len(targets) == 0: + raise ValueError( + f"Cannot find compressed-tensors target for prefix: {prefix}" + ) + return self.loaders[targets[0]] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py new file mode 100644 index 00000000..6eb00387 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py @@ -0,0 +1,253 @@ +from typing import List, Optional, Union + +import torch +from compressed_tensors.quantization import QuantizationArgs, QuantizationType + +from text_generation_server.layers.fp8 import ( + Fp8Weight, + _load_scalar_or_matrix_scale, + requantize_with_max_scale, +) +from text_generation_server.utils.weights import Weights, WeightsLoader + + +class W8ANFpLoader(WeightsLoader): + """ + Loader for W8A8/W8A16 FP compressed-tensors parameters. + """ + + def __init__( + self, + *, + input_activations: Optional[QuantizationArgs], + weights: QuantizationArgs, + ): + assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8 + + # We ignore the `strategy` option which sets the scales to be + # per-tensor, per-channel or per-token. What scales are supported + # is dependent on the kernels used (e.g. cutlass can do tokenwise, + # Torch cannot, and FP8-Marlin does not quantize inputs at all). + # So, instead we try to use the best-possible configuration. + + self.load_weight_scale = not weights.dynamic + self.load_input_scale = ( + input_activations is not None and not input_activations.dynamic + ) + self.force_w8a16 = ( + input_activations is not None and input_activations.num_bits == 16 + ) + + def __str__(self) -> str: + def scale_to_str(scale): + return "static" if scale else "dynamic" + + quantization_type = f"W8A{16 if self.force_w8a16 else 8}" + + return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})" + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight") + + weight_scale = None + if self.load_weight_scale: + weight_scale = ( + weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(w.shape[0]) + ) + logical_widths = [w.shape[0]] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if weight_scale.numel() > 1: + weight_scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + weight_scale = weight_scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + if input_scale.numel() > 1: + input_scale = weights.get_packed_sharded( + f"{prefix}.input_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + input_scale = input_scale.reshape(-1).max() + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [ + weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes + ] + shapes = [x.shape for x in w] + + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + weight_scale = None + if self.load_weight_scale: + weight_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) + for p, shape in zip(prefixes, shapes) + ] + weight_scale = torch.cat(weight_scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) + for p, shape in zip(prefixes, shapes) + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes] + shapes = [x.shape for x in w] + + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + weight_scale = None + + if self.load_weight_scale: + weight_scale = [ + weights.get_tensor(f"{p}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(shape[0]) + for p, shape in zip(prefixes, shapes) + ] + weight_scale = torch.cat(weight_scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = [ + weights.get_tensor(f"{p}.input_scale", to_dtype=False) + .reshape(-1) + .expand(shape[0]) + for p, shape in zip(prefixes, shapes) + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1) + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + weight_scale = weight_scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py index 44d30202..8de335ac 100644 --- a/backends/gaudi/server/text_generation_server/layers/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -207,7 +207,7 @@ def requantize_with_max_scale( for idx, logical_width in enumerate(logical_widths): end = start + logical_width weight_dq = per_tensor_dequantize( - weight[start:end, :], weight_scale[idx], dtype + weight[start:end, :], weight_scale[start:end, :], dtype ) weight[start:end, :], max_w_scale_normalized = fp8_quantize( weight_dq, max_w_scale @@ -270,6 +270,11 @@ class HybridFP8UnquantLoader(WeightsLoader): ) # FP8 branch scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + scale = scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) input_scale = None if weights.has_tensor(f"{prefix}.input_scale"): @@ -278,10 +283,6 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .max() ) - logical_widths = [w.shape[0]] - w, scale = requantize_with_max_scale( - w, scale.unsqueeze(0), logical_widths, weights.dtype - ) return Fp8Weight( weight=w, @@ -316,6 +317,11 @@ class HybridFP8UnquantLoader(WeightsLoader): block_sizes=block_sizes, to_dtype=False, ) + scale = scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) input_scale = None if weights.has_tensor(f"{prefix}.input_scale"): @@ -330,10 +336,6 @@ class HybridFP8UnquantLoader(WeightsLoader): to_dtype=False, ) input_scale = input_scale.reshape(-1).max() - logical_widths = [w.shape[0]] - w, scale = requantize_with_max_scale( - w, scale.unsqueeze(0), logical_widths, weights.dtype - ) return Fp8Weight( weight=w, @@ -380,6 +382,11 @@ class HybridFP8UnquantLoader(WeightsLoader): ] scale = torch.cat(scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) + input_scale = [ _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) for p, shape in zip(prefixes, shapes) @@ -392,11 +399,6 @@ class HybridFP8UnquantLoader(WeightsLoader): else None ) - logical_widths = [x[0] for x in shapes] - w, scale = requantize_with_max_scale( - w, scale.to(weights.device), logical_widths, weights.dtype - ) - return Fp8Weight( weight=w, weight_scale=scale, @@ -435,11 +437,18 @@ class HybridFP8UnquantLoader(WeightsLoader): ) scale = [ - weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1) - for p in prefixes + weights.get_tensor(f"{p}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(shape[0]) + for p, shape in zip(prefixes, shapes) ] scale = torch.cat(scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) + input_scale = [ weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1) for p in prefixes @@ -452,11 +461,6 @@ class HybridFP8UnquantLoader(WeightsLoader): else None ) - logical_widths = [x[0] for x in shapes] - w, scale = requantize_with_max_scale( - w, scale.to(weights.device), logical_widths, weights.dtype - ) - return Fp8Weight( weight=w, weight_scale=scale, @@ -485,7 +489,15 @@ class HybridFP8UnquantLoader(WeightsLoader): weight_block_size=self.weight_block_size, ) - scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + scale = ( + weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(w.shape[0]) + ) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) input_scale = None if weights.has_tensor(f"{prefix}.input_scale"): @@ -494,10 +506,7 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .max() ) - logical_widths = [w.shape[0]] - w, scale = requantize_with_max_scale( - w, scale.unsqueeze(0), logical_widths, weights.dtype - ) + return Fp8Weight( weight=w, weight_scale=scale, @@ -615,45 +624,32 @@ class Fp8Linear(torch.nn.Module): weight_block_size=weight_block_size, ) - @classmethod - def get_shared_device_identity(cls, device): - # Input scaling factors are no longer optional in _scaled_mm starting - # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale - if device not in cls._device_identity_cache: - cls._device_identity_cache[device] = torch.ones(1, device=device) - return cls._device_identity_cache[device] - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.weight_block_size is not None: + if self.weight_block_size is not None or self.input_scale is None: return apply_block_fp8_linear_hpu_dynamic( input, self.qweight, self.scale, self.input_scale, self.bias ) - qinput, scale = fp8_quantize( - input, - self.input_scale, - scale_upper_bound=self.scale_upper_bound, - scalar=True, - ) - - output = torch._scaled_mm( - qinput, - self.qweight.t(), - out_dtype=self.dtype, - scale_a=scale, - scale_b=self.scale, + x_fp8 = torch.ops.hpu.cast_to_fp8_v2( + input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn + )[0] + return torch.ops.hpu.fp8_gemm_v2( + A=x_fp8, + trans_A=False, + B=self.qweight, + trans_B=True, + D=None, + out_dtype=input.dtype, + A_scale_inv=self.input_scale, + B_scale_inv=self.scale, bias=self.bias, + accumulate=False, ) - if isinstance(output, tuple) and len(output) == 2: - output = output[0] - - return output - def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size): scale = weights.get_tensor(prefix, to_dtype=False) if scale.numel() > 1: scale = weights.get_sharded(prefix, dim=0, to_dtype=False) - return scale.reshape(-1) + return scale.reshape(-1).expand(shape[0]) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index babf3d4b..96b120b2 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -276,6 +276,63 @@ class GPTQWeightsLoader(WeightsLoader): use_exllama=use_exllama, ) + def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int): + if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert): + return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim) + try: + qweight = torch.cat( + [weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1) + + self._get_gptq_params(weights) + + qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1) + + use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act + + if self.quantize == "gptq" and self.quant_method == "gptq": + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", + use_exllama=use_exllama, + ) + def get_weights_row(self, weights: Weights, prefix: str): self._get_gptq_params(weights) diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 32a9b9bf..d2d7c836 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -309,7 +309,6 @@ class ModelType(enum.Enum): "name": "Qwen 3 Moe", "url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f", } - GALACTICA = { "type": "galactica", "name": "Galactica", @@ -832,7 +831,6 @@ def get_model( trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) - elif model_type == MLLAMA: return FlashMllamaCausalLM( model_id=model_id, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py index 98994e48..11864c52 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -22,6 +22,7 @@ import torch.utils.checkpoint from torch import nn import torch.nn.functional as F +import habana_frameworks.torch as htorch from transformers.cache_utils import Cache from transformers.activations import ACT2FN from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS @@ -567,6 +568,9 @@ class Llama4TextModel(nn.Module): ) freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1)) + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states = layer( @@ -582,6 +586,8 @@ class Llama4TextModel(nn.Module): position_ids=position_ids, hpu_attention_meta=hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py index cce64196..1b4af58a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py @@ -14,6 +14,7 @@ from typing import Optional, Tuple, List import torch from torch import nn +import habana_frameworks.torch as htorch from text_generation_server.layers.attention import ( paged_attention, attention, @@ -26,14 +27,7 @@ from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, SpeculativeHead, - FastLinear, ) -from text_generation_server.utils.import_utils import ( - synchronize, - get_free_memory, -) -from loguru import logger -from text_generation_server.utils.log import log_master from text_generation_server.layers.layernorm import ( @@ -53,12 +47,9 @@ class Qwen3Attention(nn.Module): self.head_dim = getattr( config, "head_dim", config.hidden_size // config.num_attention_heads ) - config.num_key_value_heads = getattr( - config, "num_key_value_heads", config.num_attention_heads + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads ) - # self.num_key_value_groups = ( - # config.num_attention_heads // config.num_key_value_heads - # ) self.num_heads = config.num_attention_heads self.attention_dropout = config.attention_dropout self.softmax_scale = self.head_dim**-0.5 @@ -75,65 +66,16 @@ class Qwen3Attention(nn.Module): f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() - # self.num_key_value_heads = config.num_key_value_heads - if config.num_key_value_heads > weights.process_group.size(): - self.num_key_value_heads = ( - config.num_key_value_heads // weights.process_group.size() - ) - else: - self.num_key_value_heads = config.num_key_value_heads - - self.query_proj = TensorParallelColumnLinear.load( + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + self.query_key_value = TensorParallelColumnLinear.load_multi( config, - prefix=f"{prefix}.q_proj", + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, weights=weights, bias=False, ) - if self.num_key_value_heads != config.num_key_value_heads: - self.key_proj = TensorParallelColumnLinear.load( - config, - prefix=f"{prefix}.k_proj", - weights=weights, - bias=False, - ) - self.value_proj = TensorParallelColumnLinear.load( - config, - prefix=f"{prefix}.v_proj", - weights=weights, - bias=False, - ) - else: - self.key_proj = FastLinear.load( - config, - prefix=f"{prefix}.k_proj", - weights=weights, - bias=False, - ) - self.value_proj = FastLinear.load( - config, - prefix=f"{prefix}.v_proj", - weights=weights, - bias=False, - ) - # self.key_proj = TensorParallelColumnLinear.load( - # config, - # prefix=f"{prefix}.k_proj", - # weights=weights, - # bias=False, - # ) - # self.value_proj = TensorParallelColumnLinear.load( - # config, - # prefix=f"{prefix}.v_proj", - # weights=weights, - # bias=False, - # ) - # self.query_key_value = TensorParallelColumnLinear.load_multi( - # config, - # prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - # dim=0, - # weights=weights, - # bias=False, - # ) self.kv_scales = get_kv_scales(weights, f"{prefix}") @@ -144,10 +86,11 @@ class Qwen3Attention(nn.Module): bias=False, ) - self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device - ).repeat_interleave(self.num_key_value_groups) + ).repeat_interleave(self.num_groups) self.max_past = ( config.sliding_window if config.sliding_window is not None else -1 @@ -182,45 +125,21 @@ class Qwen3Attention(nn.Module): seqlen, hpu_attention_meta, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - print(f"hidden_states shape: {hidden_states.shape}") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - # qkv = self.query_key_value(hidden_states) - # print(f"qkv shape: {qkv.shape}") - # print(f"self.head_dim: {self.head_dim}") - # print(f"self.num_heads: {self.num_heads}") - # print(f"self.num_key_value_heads: {self.num_key_value_heads}") - # query_states, key_states, value_states = qkv.split( - # [ - # self.head_dim * self.num_heads, - # self.head_dim * self.num_key_value_heads, - # self.head_dim * self.num_key_value_heads, - # ], - # dim=1, - # ) - synchronize(hidden_states.device) - real_free_memory = get_free_memory(hidden_states.device, 1) - log_master( - logger.debug, - f"Attention forward1 Free memory real: {real_free_memory / 1e9:.2f}GB", + qkv = self.query_key_value(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_key_value_heads, + self.head_dim * self.num_key_value_heads, + ], + dim=1, ) - query_states = self.query_proj(hidden_states) - key_states = self.key_proj(hidden_states) - value_states = self.value_proj(hidden_states) query_states, _ = self.q_norm(query_states.view(hidden_shape)) key_states, _ = self.k_norm(key_states.view(hidden_shape)) value_states = value_states.view(hidden_shape) - print(f"query_states shape: {query_states.shape}") - print(f"key_states shape: {key_states.shape}") - print(f"value_states shape: {value_states.shape}") - synchronize(hidden_states.device) - real_free_memory = get_free_memory(hidden_states.device, 1) - log_master( - logger.debug, - f"Attention forward2 Free memory real: {real_free_memory / 1e9:.2f}GB", - ) - self.rotary_emb(query_states, key_states, cos, sin) kv_cache.store( @@ -257,7 +176,6 @@ class Qwen3Attention(nn.Module): ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() - print(f"attn_output shape: {attn_output.shape}") return self.o_proj(attn_output) @@ -359,6 +277,11 @@ class Qwen3Model(nn.Module): ) residual = None + + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, @@ -371,6 +294,8 @@ class Qwen3Model(nn.Module): seqlen, hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states) diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py index 022a4897..192963c4 100644 --- a/backends/gaudi/server/text_generation_server/utils/quantization.py +++ b/backends/gaudi/server/text_generation_server/utils/quantization.py @@ -122,6 +122,13 @@ def _get_quantizer_config(model_id, revision): def get_loader( quantize: Optional[str], model_id: str, revision: Optional[str] ) -> WeightsLoader: + if quantize == "compressed-tensors": + config = _get_config_json(model_id, revision, "config.json") + from text_generation_server.layers.compressed_tensors import ( + CompressedTensorsLoader, + ) + + return CompressedTensorsLoader(config) quantizer_config = _get_quantizer_config(model_id, revision) if quantize in {"awq", "gptq"}: from text_generation_server.layers.gptq import GPTQWeightsLoader diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 1628a00b..c8b29204 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -162,6 +162,11 @@ impl Allocator for SimpleAllocator { tokens: u32, _prefill_tokens: Option>>, ) -> Option { + let mut tokens = tokens; + if self.is_hpu_device { + // need 1 slot for ping-pong optimization + tokens += 1; + } // Apply window size let (required_blocks, repeats) = { let (tokens, repeats) = match self.window_size { @@ -176,8 +181,7 @@ impl Allocator for SimpleAllocator { let required_blocks = tokens.div_ceil(self.block_size); (required_blocks, repeats) }; - - let mut tokens = tokens as usize; + let tokens = tokens as usize; if required_blocks > self.free_blocks.len() as u32 { None } else { @@ -189,8 +193,6 @@ impl Allocator for SimpleAllocator { .split_off(self.free_blocks.len() - required_blocks as usize); if self.is_hpu_device { blocks.sort(); - // need 1 slot for ping-pong optimization - tokens += 1; } let mut slots = Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);