From 1883a62a94fe13372fd7c012d6e53ee11c4ac048 Mon Sep 17 00:00:00 2001 From: Yuan Wu Date: Fri, 23 May 2025 14:58:35 +0800 Subject: [PATCH 1/2] Add Qwen3 for Gaudi backend (#3229) Signed-off-by: yuanwu --- .../text_generation_server/models/__init__.py | 21 ++ .../custom_modeling/flash_llama4_modeling.py | 6 + .../custom_modeling/flash_qwen3_modeling.py | 356 ++++++++++++++++++ .../models/flash_causal_lm.py | 1 + backends/gaudi/tgi-entrypoint.sh | 4 +- 5 files changed, 386 insertions(+), 2 deletions(-) create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index a9a1d0b7..c46c79fb 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -109,6 +109,9 @@ try: from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2ForCausalLM, ) + from text_generation_server.models.custom_modeling.flash_qwen3_modeling import ( + Qwen3ForCausalLM, + ) from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, ) @@ -293,6 +296,12 @@ class ModelType(enum.Enum): "name": "Qwen 2.5 VL", "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e", } + QWEN3 = { + "type": "qwen3", + "name": "Qwen 3", + "url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f", + } + GALACTICA = { "type": "galactica", "name": "Galactica", @@ -791,6 +800,18 @@ def get_model( config_class=Qwen2_5_VLConfig, processor_class=Qwen2_5_VLProcessor, ) + elif model_type == QWEN3: + return FlashCausalLM( + model_id=model_id, + model_class=Qwen3ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) elif model_type == MLLAMA: return FlashMllamaCausalLM( model_id=model_id, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py index 98994e48..11864c52 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -22,6 +22,7 @@ import torch.utils.checkpoint from torch import nn import torch.nn.functional as F +import habana_frameworks.torch as htorch from transformers.cache_utils import Cache from transformers.activations import ACT2FN from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS @@ -567,6 +568,9 @@ class Llama4TextModel(nn.Module): ) freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1)) + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() for i, layer in enumerate(self.layers): hidden_states = layer( @@ -582,6 +586,8 @@ class Llama4TextModel(nn.Module): position_ids=position_ids, hpu_attention_meta=hpu_attention_meta, ) + if lazy_mode: + htorch.core.mark_step() hidden_states, _ = self.norm(hidden_states) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py new file mode 100644 index 
00000000..66a17877 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py @@ -0,0 +1,356 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, List + +import torch +from torch import nn +import habana_frameworks.torch as htorch +from text_generation_server.layers.attention import ( + paged_attention, + attention, + Seqlen, + HPUPagedAttentionMetadata, +) +from text_generation_server.layers.attention.kv_cache import get_kv_scales +from text_generation_server.layers import ( + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelColumnLinear, + SpeculativeHead, +) + + +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) +from .flash_qwen2_modeling import Qwen2MLP +from text_generation_server.layers.rotary import PositionRotaryEmbedding + + +class Qwen3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config, prefix, weights, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.num_heads = config.num_attention_heads + self.attention_dropout = config.attention_dropout + self.softmax_scale = self.head_dim**-0.5 + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_dim, + base=config.rope_theta, + device=weights.device, + ) + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + self.query_key_value = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + self.kv_scales = get_kv_scales(weights, f"{prefix}") + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + self.max_past = ( + config.sliding_window if config.sliding_window is not None else -1 + ) + + self.q_norm = FastRMSNorm.load( + prefix=f"{prefix}.q_norm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.k_norm = FastRMSNorm.load( + prefix=f"{prefix}.k_norm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.sliding_window = config.sliding_window + if not ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + 
and self.layer_idx >= self.config.max_window_layers + ): + self.sliding_window = None + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + qkv = self.query_key_value(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_key_value_heads, + self.head_dim * self.num_key_value_heads, + ], + dim=1, + ) + + query_states, _ = self.q_norm(query_states.view(hidden_shape)) + key_states, _ = self.k_norm(key_states.view(hidden_shape)) + value_states = value_states.view(hidden_shape) + self.rotary_emb(query_states, key_states, cos, sin) + + kv_cache.store( + key=key_states, + value=value_states, + slots=slots, + kv_scales=self.kv_scales, + ) + + # Prefill + if cu_seqlen_prefill is not None: + # sdpa + attn_output = attention( + query=query_states, + key=key_states, + value=value_states, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + attn_output = paged_attention( + query_states, + kv_cache, + self.kv_head_mapping, + self.softmax_scale, + seqlen, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + return self.o_proj(attn_output) + + +class Qwen3DecoderLayer(nn.Module): + def __init__(self, config, prefix, weights, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3Attention( + config=config, + prefix=f"{prefix}.self_attn", + weights=weights, + layer_idx=layer_idx, + ) + self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights) + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) -> torch.Tensor: + residual = hidden_states + hidden_states, _ = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states, _ = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Qwen3Model(nn.Module): + def __init__(self, config, prefix: str, weights): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.layers = nn.ModuleList( + [ + Qwen3DecoderLayer( + config=config, + prefix=f"{prefix}.layers.{layer_idx}", + weights=weights, + layer_idx=layer_idx, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: 
Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + ) -> torch.Tensor: + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, + ) + + residual = None + lazy_mode = htorch.utils.internal.is_lazy() + if lazy_mode: + htorch.core.mark_step() + + for i, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + slots, + seqlen, + hpu_attention_meta, + ) + if lazy_mode: + htorch.core.mark_step() + + hidden_states, _ = self.norm(hidden_states) + + # add hidden states from the last decoder layer + return hidden_states + + +class Qwen3ForCausalLM(nn.Module): + + def __init__(self, prefix: str, config, weights): + super().__init__() + self.model = Qwen3Model(config=config, prefix="model", weights=weights) + self.vocab_size = config.vocab_size + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=f"{prefix}.{suffix}" if prefix else suffix, + weights=weights, + ) + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + inputs_embeds = self.embed_tokens(input_ids) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + hidden_states = self.model( + inputs_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + + return logits diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index f8abe5ad..685fd8a9 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1412,6 +1412,7 @@ class FlashCausalLM(Model): aliases=aliases, weights_loader=weights_loader, ) + print(f"weights: {weights}") prefix = None model = model_class(prefix, config, weights) diff --git a/backends/gaudi/tgi-entrypoint.sh b/backends/gaudi/tgi-entrypoint.sh index deb64382..d787ea8e 100644 --- a/backends/gaudi/tgi-entrypoint.sh +++ b/backends/gaudi/tgi-entrypoint.sh @@ -10,8 +10,8 @@ fi # Check if ATTENTION environment variable is set to paged if [[ "$ATTENTION" == "paged" ]]; then # Check if Llama-4 is in the command line arguments - if [[ "$*" == *"Llama-4"* ]]; then - echo 'ATTENTION=paged and Llama-4 detected' + if [[ "$*" == *"Llama-4"* || "$*" == *"Qwen3"* ]]; then + echo 'ATTENTION=paged and Llama-4 or Qwen3 detected' pip install git+https://github.com/huggingface/transformers.git@29338949 fi fi From 
f14044009a59e6524a0a317c3a455d188997dbb4 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Wed, 28 May 2025 20:54:20 +0800 Subject: [PATCH 2/2] fp8 compressed tensors w8a8 support for Gaudi backend (#3242) Signed-off-by: Wang, Yi A --- Dockerfile_gaudi | 1 + .../server/text_generation_server/cli.py | 2 + .../layers/compressed_tensors/__init__.py | 3 + .../layers/compressed_tensors/loader.py | 169 ++++++++++++ .../layers/compressed_tensors/w8an_fp.py | 253 ++++++++++++++++++ .../text_generation_server/layers/fp8.py | 104 ++++--- .../layers/gptq/__init__.py | 57 ++++ .../models/flash_causal_lm.py | 1 - .../utils/quantization.py | 7 + backends/v3/src/block_allocator.rs | 10 +- 10 files changed, 548 insertions(+), 59 deletions(-) create mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py create mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py create mode 100644 backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index c4164556..442eb6b7 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -99,6 +99,7 @@ RUN cd server && \ BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \ pip install . --no-cache-dir RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix +RUN pip install compressed-tensors==0.9.1 # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index b1a41534..d4445a13 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -19,6 +19,7 @@ class Quantization(str, Enum): gptq = "gptq" awq = "awq" fp8 = "fp8" + compressed_tensors = "compressed-tensors" class Dtype(str, Enum): @@ -109,6 +110,7 @@ def serve( "gptq", "awq", "fp8", + "compressed-tensors", }: raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." 
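For context on the new `compressed-tensors` quantize option registered in cli.py above: the loader added in the following files is driven entirely by the checkpoint's `quantization_config`. A minimal sketch of how it is constructed (the config values are representative of an FP8 W8A8 checkpoint such as those produced by llm-compressor; they are not taken from this patch):

from text_generation_server.layers.compressed_tensors import CompressedTensorsLoader

# Hypothetical excerpt of a model's config.json for an FP8 W8A8 checkpoint.
# Field names follow the compressed-tensors schema; values are illustrative.
config = {
    "quantization_config": {
        "quant_method": "compressed-tensors",
        "format": "float-quantized",
        "quantization_status": "compressed",
        "ignore": ["lm_head"],
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {
                    "type": "float",
                    "num_bits": 8,
                    "strategy": "tensor",
                    "dynamic": False,
                },
                "input_activations": {
                    "type": "float",
                    "num_bits": 8,
                    "strategy": "tensor",
                    "dynamic": False,
                },
            }
        },
    }
}

# The loader validates this block, builds one W8ANFpLoader per target group
# ("Linear" here), and routes prefixes listed in `ignore` (e.g. lm_head) to the
# default unquantized loader.
loader = CompressedTensorsLoader(config)

Older checkpoints that still store this block under `compression_config` are handled by the fallback at the top of the loader below.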
diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py new file mode 100644 index 00000000..507af706 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py @@ -0,0 +1,3 @@ +from .loader import CompressedTensorsLoader + +__all__ = ["CompressedTensorsLoader"] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py new file mode 100644 index 00000000..0dccf34a --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py @@ -0,0 +1,169 @@ +from typing import Any, Dict, List, Union + +from compressed_tensors import QuantizationConfig, QuantizationStatus +from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import ( + QuantizationScheme, + QuantizationType, + find_name_or_class_matches, +) +from loguru import logger +from pydantic import ValidationError +from torch import nn + +from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + UnquantizedWeight, + Weights, + WeightsLoader, +) + +# compressed-tensors can match modules as quantization targets. However, +# they need to be objects rather than classes or class names. Since we +# need to match `Linear` targets, make an instance that can be re-used. +_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0) + + +class CompressedTensorsLoader(WeightsLoader): + """Loader for checkpoints stored in the compressed-tensors format.""" + + def __init__(self, config: Dict[str, Any]): + quantization_config_raw = config.get("quantization_config") + if quantization_config_raw is None: + # `compression_config` was renamed to `quantization_config`; support + # retained for backward compatibility. 
+ quantization_config_raw = config.get("compression_config") + if quantization_config_raw is None: + raise ValueError( + "Checkpoint does not have compressed-tensors configuration" + ) + + try: + quantization_config = QuantizationConfig.model_validate( + quantization_config_raw + ) + except ValidationError as e: + raise ValueError("Cannot parse compressed-tensors configuration") from e + + if quantization_config.quantization_status not in ( + QuantizationStatus.COMPRESSED, + QuantizationStatus.FROZEN, + ): + raise ValueError( + f"Model quantization was not finished, status was: {quantization_config.quantization_status}" + ) + + self.ignore = ( + quantization_config.ignore if quantization_config.ignore is not None else [] + ) + self.loaders = self._get_target_loaders(quantization_config) + + for target, loader in self.loaders.items(): + log_once( + logger.info, + f"Using {loader} for compressed-tensors target '{target}'", + ) + + def get_weights(self, weights: Weights, prefix: str): + loader = self._lookup_loader(prefix) + return loader.get_weights(weights, prefix) + + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + loader = self._lookup_loader(prefix) + return loader.get_weights_col_packed(weights, prefix, block_sizes) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + loader = self._lookup_loader(prefixes[0]) + return loader.get_multi_weights_col(weights, prefixes, dim) + + def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int): + loader = self._lookup_loader(prefixes[0]) + return loader.get_multi_weights(weights, prefixes, dim) + + def get_weights_row(self, weights: Weights, prefix: str): + loader = self._lookup_loader(prefix) + return loader.get_weights_row(weights, prefix) + + def _get_target_loaders( + self, quantization_config: QuantizationConfig + ) -> Dict[str, WeightsLoader]: + """ + A compressed-tensors checkpoint can use different quantizations + for different targets. This method returns a dictionary with a + loader per target. + """ + + loaders: Dict[str, WeightsLoader] = {} + + format = quantization_config.format + + for group_name, group in quantization_config.config_groups.items(): + # The group configuration can be a string, but does that ever + # happen in a serialized quantization config? + assert isinstance(group, QuantizationScheme) + + loader = self._create_loader_for_group(format, group_name, group) + + # A quantized parameter group can have multiple targets, add the + # loader for all the targets. + for target in group.targets: + if target in loaders: + raise ValueError( + f"Target '{target} has multiple configured loaders'" + ) + loaders[target] = loader + + return loaders + + def _create_loader_for_group( + self, format: str, group_name: str, group: QuantizationScheme + ) -> WeightsLoader: + """ + Find and create a loader for the group with the given quantization + scheme. + """ + # NOTE: we ignore group.output_activations because we don't support + # output quantization yet. + + input_activations = group.input_activations + weights = group.weights + if ( + format + in { + CompressionFormat.float_quantized.value, + CompressionFormat.naive_quantized.value, + } + and weights is not None + and weights.type == QuantizationType.FLOAT + and weights.num_bits == 8 + ): + # FP W8A8 or W8A16. 
+ return W8ANFpLoader(input_activations=input_activations, weights=weights) + else: + raise ValueError( + f"Group '{group_name}' has unsupported compressed-tensors configurtion" + ) + + def _lookup_loader(self, prefix: str) -> WeightsLoader: + """ + Look up the loader to use for a given parameter name (prefix). + """ + + if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0: + return DefaultWeightsLoader(UnquantizedWeight) + + # We currently only handle linear layers, so unconditionally pass + # a `Linear` instance. + targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys()) + if len(targets) == 0: + raise ValueError( + f"Cannot find compressed-tensors target for prefix: {prefix}" + ) + return self.loaders[targets[0]] diff --git a/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py new file mode 100644 index 00000000..6eb00387 --- /dev/null +++ b/backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py @@ -0,0 +1,253 @@ +from typing import List, Optional, Union + +import torch +from compressed_tensors.quantization import QuantizationArgs, QuantizationType + +from text_generation_server.layers.fp8 import ( + Fp8Weight, + _load_scalar_or_matrix_scale, + requantize_with_max_scale, +) +from text_generation_server.utils.weights import Weights, WeightsLoader + + +class W8ANFpLoader(WeightsLoader): + """ + Loader for W8A8/W8A16 FP compressed-tensors parameters. + """ + + def __init__( + self, + *, + input_activations: Optional[QuantizationArgs], + weights: QuantizationArgs, + ): + assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8 + + # We ignore the `strategy` option which sets the scales to be + # per-tensor, per-channel or per-token. What scales are supported + # is dependent on the kernels used (e.g. cutlass can do tokenwise, + # Torch cannot, and FP8-Marlin does not quantize inputs at all). + # So, instead we try to use the best-possible configuration. 
+ + self.load_weight_scale = not weights.dynamic + self.load_input_scale = ( + input_activations is not None and not input_activations.dynamic + ) + self.force_w8a16 = ( + input_activations is not None and input_activations.num_bits == 16 + ) + + def __str__(self) -> str: + def scale_to_str(scale): + return "static" if scale else "dynamic" + + quantization_type = f"W8A{16 if self.force_w8a16 else 8}" + + return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})" + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight") + + weight_scale = None + if self.load_weight_scale: + weight_scale = ( + weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(w.shape[0]) + ) + logical_widths = [w.shape[0]] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if weight_scale.numel() > 1: + weight_scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + weight_scale = weight_scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False) + if input_scale.numel() > 1: + input_scale = weights.get_packed_sharded( + f"{prefix}.input_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + input_scale = input_scale.reshape(-1).max() + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [ + weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes + ] + shapes = [x.shape for x in w] + + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + weight_scale = None + if self.load_weight_scale: + weight_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) + for p, shape in zip(prefixes, shapes) + ] + weight_scale = torch.cat(weight_scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) + for p, shape in zip(prefixes, shapes) + if weights.has_tensor(f"{p}.input_scale") 
+ ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int): + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes] + shapes = [x.shape for x in w] + + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) + + weight_scale = None + + if self.load_weight_scale: + weight_scale = [ + weights.get_tensor(f"{p}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(shape[0]) + for p, shape in zip(prefixes, shapes) + ] + weight_scale = torch.cat(weight_scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = [ + weights.get_tensor(f"{p}.input_scale", to_dtype=False) + .reshape(-1) + .expand(shape[0]) + for p, shape in zip(prefixes, shapes) + if weights.has_tensor(f"{p}.input_scale") + ] + assert len(input_scale) == 0 or len(input_scale) == len(prefixes) + input_scale = ( + torch.cat(input_scale, dim=0).reshape(-1).max() + if len(input_scale) != 0 + else None + ) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1) + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + weight_scale = weight_scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, weight_scale = requantize_with_max_scale( + w, + weight_scale.unsqueeze(-1).to(weights.device), + logical_widths, + weights.dtype, + ) + + input_scale = None + if self.load_input_scale: + input_scale = weights.get_tensor( + f"{prefix}.input_scale", to_dtype=False + ).reshape(-1) + + return Fp8Weight( + weight=w, + weight_scale=weight_scale, + input_scale=input_scale, + dtype=weights.dtype, + force_w8a16=self.force_w8a16, + ) diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py index 44d30202..8de335ac 100644 --- a/backends/gaudi/server/text_generation_server/layers/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -207,7 +207,7 @@ def requantize_with_max_scale( for idx, logical_width in enumerate(logical_widths): end = start + logical_width weight_dq = per_tensor_dequantize( - weight[start:end, :], weight_scale[idx], dtype + weight[start:end, :], weight_scale[start:end, :], dtype ) weight[start:end, :], max_w_scale_normalized = fp8_quantize( weight_dq, max_w_scale @@ -270,6 +270,11 @@ class HybridFP8UnquantLoader(WeightsLoader): ) # FP8 branch scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + scale = scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) input_scale = None if 
weights.has_tensor(f"{prefix}.input_scale"): @@ -278,10 +283,6 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .max() ) - logical_widths = [w.shape[0]] - w, scale = requantize_with_max_scale( - w, scale.unsqueeze(0), logical_widths, weights.dtype - ) return Fp8Weight( weight=w, @@ -316,6 +317,11 @@ class HybridFP8UnquantLoader(WeightsLoader): block_sizes=block_sizes, to_dtype=False, ) + scale = scale.reshape(-1).expand(w.shape[0]) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) input_scale = None if weights.has_tensor(f"{prefix}.input_scale"): @@ -330,10 +336,6 @@ class HybridFP8UnquantLoader(WeightsLoader): to_dtype=False, ) input_scale = input_scale.reshape(-1).max() - logical_widths = [w.shape[0]] - w, scale = requantize_with_max_scale( - w, scale.unsqueeze(0), logical_widths, weights.dtype - ) return Fp8Weight( weight=w, @@ -380,6 +382,11 @@ class HybridFP8UnquantLoader(WeightsLoader): ] scale = torch.cat(scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) + input_scale = [ _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape) for p, shape in zip(prefixes, shapes) @@ -392,11 +399,6 @@ class HybridFP8UnquantLoader(WeightsLoader): else None ) - logical_widths = [x[0] for x in shapes] - w, scale = requantize_with_max_scale( - w, scale.to(weights.device), logical_widths, weights.dtype - ) - return Fp8Weight( weight=w, weight_scale=scale, @@ -435,11 +437,18 @@ class HybridFP8UnquantLoader(WeightsLoader): ) scale = [ - weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1) - for p in prefixes + weights.get_tensor(f"{p}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(shape[0]) + for p, shape in zip(prefixes, shapes) ] scale = torch.cat(scale, dim=0).reshape(-1) + logical_widths = [x[0] for x in shapes] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) + input_scale = [ weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1) for p in prefixes @@ -452,11 +461,6 @@ class HybridFP8UnquantLoader(WeightsLoader): else None ) - logical_widths = [x[0] for x in shapes] - w, scale = requantize_with_max_scale( - w, scale.to(weights.device), logical_widths, weights.dtype - ) - return Fp8Weight( weight=w, weight_scale=scale, @@ -485,7 +489,15 @@ class HybridFP8UnquantLoader(WeightsLoader): weight_block_size=self.weight_block_size, ) - scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + scale = ( + weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + .reshape(-1) + .expand(w.shape[0]) + ) + logical_widths = [w.shape[0]] + w, scale = requantize_with_max_scale( + w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype + ) input_scale = None if weights.has_tensor(f"{prefix}.input_scale"): @@ -494,10 +506,7 @@ class HybridFP8UnquantLoader(WeightsLoader): .reshape(-1) .max() ) - logical_widths = [w.shape[0]] - w, scale = requantize_with_max_scale( - w, scale.unsqueeze(0), logical_widths, weights.dtype - ) + return Fp8Weight( weight=w, weight_scale=scale, @@ -615,45 +624,32 @@ class Fp8Linear(torch.nn.Module): weight_block_size=weight_block_size, ) - @classmethod - def get_shared_device_identity(cls, device): - # Input scaling factors are no longer optional in _scaled_mm starting - # from pytorch 2.5. 
Allocating a dummy tensor to pass as input_scale - if device not in cls._device_identity_cache: - cls._device_identity_cache[device] = torch.ones(1, device=device) - return cls._device_identity_cache[device] - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.weight_block_size is not None: + if self.weight_block_size is not None or self.input_scale is None: return apply_block_fp8_linear_hpu_dynamic( input, self.qweight, self.scale, self.input_scale, self.bias ) - qinput, scale = fp8_quantize( - input, - self.input_scale, - scale_upper_bound=self.scale_upper_bound, - scalar=True, - ) - - output = torch._scaled_mm( - qinput, - self.qweight.t(), - out_dtype=self.dtype, - scale_a=scale, - scale_b=self.scale, + x_fp8 = torch.ops.hpu.cast_to_fp8_v2( + input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn + )[0] + return torch.ops.hpu.fp8_gemm_v2( + A=x_fp8, + trans_A=False, + B=self.qweight, + trans_B=True, + D=None, + out_dtype=input.dtype, + A_scale_inv=self.input_scale, + B_scale_inv=self.scale, bias=self.bias, + accumulate=False, ) - if isinstance(output, tuple) and len(output) == 2: - output = output[0] - - return output - def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size): scale = weights.get_tensor(prefix, to_dtype=False) if scale.numel() > 1: scale = weights.get_sharded(prefix, dim=0, to_dtype=False) - return scale.reshape(-1) + return scale.reshape(-1).expand(shape[0]) diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index babf3d4b..96b120b2 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -276,6 +276,63 @@ class GPTQWeightsLoader(WeightsLoader): use_exllama=use_exllama, ) + def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int): + if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert): + return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim) + try: + qweight = torch.cat( + [weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1) + + self._get_gptq_params(weights) + + qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1) + + use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act + + if self.quantize == "gptq" and self.quant_method == "gptq": + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." 
+ ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", + use_exllama=use_exllama, + ) + def get_weights_row(self, weights: Weights, prefix: str): self._get_gptq_params(weights) diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 685fd8a9..f8abe5ad 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1412,7 +1412,6 @@ class FlashCausalLM(Model): aliases=aliases, weights_loader=weights_loader, ) - print(f"weights: {weights}") prefix = None model = model_class(prefix, config, weights) diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py index 022a4897..192963c4 100644 --- a/backends/gaudi/server/text_generation_server/utils/quantization.py +++ b/backends/gaudi/server/text_generation_server/utils/quantization.py @@ -122,6 +122,13 @@ def _get_quantizer_config(model_id, revision): def get_loader( quantize: Optional[str], model_id: str, revision: Optional[str] ) -> WeightsLoader: + if quantize == "compressed-tensors": + config = _get_config_json(model_id, revision, "config.json") + from text_generation_server.layers.compressed_tensors import ( + CompressedTensorsLoader, + ) + + return CompressedTensorsLoader(config) quantizer_config = _get_quantizer_config(model_id, revision) if quantize in {"awq", "gptq"}: from text_generation_server.layers.gptq import GPTQWeightsLoader diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 1628a00b..c8b29204 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -162,6 +162,11 @@ impl Allocator for SimpleAllocator { tokens: u32, _prefill_tokens: Option>>, ) -> Option { + let mut tokens = tokens; + if self.is_hpu_device { + // need 1 slot for ping-pong optimization + tokens += 1; + } // Apply window size let (required_blocks, repeats) = { let (tokens, repeats) = match self.window_size { @@ -176,8 +181,7 @@ impl Allocator for SimpleAllocator { let required_blocks = tokens.div_ceil(self.block_size); (required_blocks, repeats) }; - - let mut tokens = tokens as usize; + let tokens = tokens as usize; if required_blocks > self.free_blocks.len() as u32 { None } else { @@ -189,8 +193,6 @@ impl Allocator for SimpleAllocator { .split_off(self.free_blocks.len() - required_blocks as usize); if self.is_hpu_device { blocks.sort(); - // need 1 slot for ping-pong optimization - tokens += 1; } let mut slots = Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
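The recurring change across the fp8.py hunks above is that every weight-loading path now expands each shard's per-tensor scale over that shard's rows and then requantizes the concatenated weight to one shared maximum scale, so fused projections (e.g. q/k/v) can run through a single FP8 GEMM with a single weight scale. A self-contained sketch of that idea in plain PyTorch, simplified from `requantize_with_max_scale` (per-tensor scales only, no HPU kernels):

import torch

def requantize_to_max_scale(weight_fp8, row_scale, logical_widths, dtype=torch.bfloat16):
    """Re-quantize concatenated FP8 shards so they share one (maximum) scale.

    weight_fp8:     [sum(logical_widths), in_features] FP8 weight, shards concatenated on dim 0
    row_scale:      one scale per output row (each shard's per-tensor scale expanded over its rows)
    logical_widths: number of rows contributed by each shard
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    max_scale = row_scale.max()
    out = torch.empty_like(weight_fp8)
    start = 0
    for width in logical_widths:
        end = start + width
        # Dequantize this shard with its own scale...
        shard = weight_fp8[start:end].to(dtype) * row_scale[start:end, None].to(dtype)
        # ...then requantize it with the shared max scale.
        out[start:end] = (shard / max_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        start = end
    return out, max_scale

# Example: a fused q/k/v weight built from three shards with different original scales.
w = torch.randn(6, 4).to(torch.float8_e4m3fn)
scales = torch.tensor([0.01, 0.01, 0.02, 0.02, 0.04, 0.04])
w_requant, shared_scale = requantize_to_max_scale(w, scales, [2, 2, 2])

Taking the largest of the shard scales keeps every shard inside the FP8 range after the merge; shards that originally used a smaller scale just give up a little precision.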