From 638714f964349155bbcd2892582d2c061dd81920 Mon Sep 17 00:00:00 2001 From: yuanwu Date: Tue, 13 May 2025 07:42:22 +0000 Subject: [PATCH] Add Qwen3 Signed-off-by: yuanwu --- Dockerfile_gaudi | 4 +- .../server/text_generation_server/cli.py | 4 +- .../layers/tensor_parallel.py | 1 + .../text_generation_server/models/__init__.py | 21 ++ .../custom_modeling/flash_qwen3_modeling.py | 349 ++++++++++++++++++ .../models/flash_causal_lm.py | 1 + backends/gaudi/tgi-entrypoint.sh | 4 +- 7 files changed, 378 insertions(+), 6 deletions(-) create mode 100644 backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi index 06073fe4..20c03cb3 100644 --- a/Dockerfile_gaudi +++ b/Dockerfile_gaudi @@ -122,5 +122,5 @@ ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh -ENTRYPOINT ["/tgi-entrypoint.sh"] -CMD ["--json-output"] +#ENTRYPOINT ["/tgi-entrypoint.sh"] +#CMD ["--json-output"] diff --git a/backends/gaudi/server/text_generation_server/cli.py b/backends/gaudi/server/text_generation_server/cli.py index 53837ef7..af908472 100644 --- a/backends/gaudi/server/text_generation_server/cli.py +++ b/backends/gaudi/server/text_generation_server/cli.py @@ -57,7 +57,7 @@ def serve( ), "MASTER_PORT must be set when sharded is True" # Remove default handler - logger.remove() + # logger.remove() logger.add( sys.stdout, format="{message}", @@ -193,7 +193,7 @@ def download_weights( merge_lora: bool = False, ): # Remove default handler - logger.remove() + # logger.remove() logger.add( sys.stdout, format="{message}", diff --git a/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py index 8f19174f..ae60e7aa 100644 --- a/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py +++ b/backends/gaudi/server/text_generation_server/layers/tensor_parallel.py @@ -155,6 +155,7 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): + print(f"bias: {bias}") if config.quantize == "exl2": linears = [] for prefix in prefixes: diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 6ca7b567..99317a20 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -109,6 +109,9 @@ try: from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2ForCausalLM, ) + from text_generation_server.models.custom_modeling.flash_qwen3_modeling import ( + Qwen3ForCausalLM, + ) from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, ) @@ -293,6 +296,12 @@ class ModelType(enum.Enum): "name": "Qwen 2.5 VL", "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e", } + QWEN3 = { + "type": "qwen3", + "name": "Qwen 3", + "url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f", + } + GALACTICA = { "type": "galactica", "name": "Galactica", @@ -785,6 +794,18 @@ def get_model( config_class=Qwen2_5_VLConfig, processor_class=Qwen2_5_VLProcessor, ) + elif model_type == QWEN3: + return FlashCausalLM( + model_id=model_id, + model_class=Qwen3ForCausalLM, + revision=revision, + quantize=quantize, + 
speculator=speculator, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) elif model_type == MLLAMA: return FlashMllamaCausalLM( model_id=model_id, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py new file mode 100644 index 00000000..2c8662eb --- /dev/null +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py @@ -0,0 +1,349 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, List + +import torch +from torch import nn +from text_generation_server.layers.attention import ( + paged_attention, + attention, + Seqlen, + HPUPagedAttentionMetadata, +) +from text_generation_server.layers.attention.kv_cache import get_kv_scales +from text_generation_server.layers import ( + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelColumnLinear, + SpeculativeHead, +) + + +from text_generation_server.layers.layernorm import ( + FastRMSNorm, +) +from .flash_qwen2_modeling import Qwen2MLP +from text_generation_server.layers.rotary import PositionRotaryEmbedding + + +class Qwen3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config, prefix, weights, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.num_heads = config.num_attention_heads + self.attention_dropout = config.attention_dropout + self.softmax_scale = self.head_dim**-0.5 + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_dim, + base=config.rope_theta, + device=weights.device, + ) + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + self.query_key_value = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + self.kv_scales = get_kv_scales(weights, f"{prefix}") + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + self.max_past = ( + config.sliding_window if config.sliding_window 
is not None else -1 + ) + + self.q_norm = FastRMSNorm.load( + prefix=f"{prefix}.q_norm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.k_norm = FastRMSNorm.load( + prefix=f"{prefix}.k_norm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.sliding_window = config.sliding_window + if not ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + self.sliding_window = None + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + qkv = self.query_key_value(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_key_value_heads, + self.head_dim * self.num_key_value_heads, + ], + dim=1, + ) + + query_states, _ = self.q_norm(query_states.view(hidden_shape)) + key_states, _ = self.k_norm(key_states.view(hidden_shape)) + value_states = value_states.view(hidden_shape) + self.rotary_emb(query_states, key_states, cos, sin) + + kv_cache.store( + key=key_states, + value=value_states, + slots=slots, + kv_scales=self.kv_scales, + ) + + # Prefill + if cu_seqlen_prefill is not None: + # sdpa + attn_output = attention( + query=query_states, + key=key_states, + value=value_states, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + softmax_scale=self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + attn_output = paged_attention( + query_states, + kv_cache, + self.kv_head_mapping, + self.softmax_scale, + seqlen, + kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + return self.o_proj(attn_output) + + +class Qwen3DecoderLayer(nn.Module): + def __init__(self, config, prefix, weights, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3Attention( + config=config, + prefix=f"{prefix}.self_attn", + weights=weights, + layer_idx=layer_idx, + ) + self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights) + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) -> torch.Tensor: + residual = hidden_states + hidden_states, _ = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states, _ = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Qwen3Model(nn.Module): + def __init__(self, config, prefix: str, weights): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.layers = nn.ModuleList( + [ + 
Qwen3DecoderLayer( + config=config, + prefix=f"{prefix}.layers.{layer_idx}", + weights=weights, + layer_idx=layer_idx, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + ) -> torch.Tensor: + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, + ) + + residual = None + for i, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + slots, + seqlen, + hpu_attention_meta, + ) + + hidden_states, _ = self.norm(hidden_states) + + # add hidden states from the last decoder layer + return hidden_states + + +class Qwen3ForCausalLM(nn.Module): + + def __init__(self, prefix: str, config, weights): + super().__init__() + self.model = Qwen3Model(config=config, prefix="model", weights=weights) + self.vocab_size = config.vocab_size + if config.tie_word_embeddings: + suffix = "model.embed_tokens" + else: + suffix = "lm_head" + + self.lm_head = SpeculativeHead.load( + config, + prefix=f"{prefix}.{suffix}" if prefix else suffix, + weights=weights, + ) + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + slots: torch.Tensor, + seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + inputs_embeds = self.embed_tokens(input_ids) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + hidden_states = self.model( + inputs_embeds, + position_ids, + cu_seqlen_prefill, + kv_cache, + slots, + seqlen, + hpu_attention_meta, + ) + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + + return logits diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index ad585172..976e1a65 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1391,6 +1391,7 @@ class FlashCausalLM(Model): aliases=aliases, weights_loader=weights_loader, ) + print(f"weights: {weights}") prefix = None model = model_class(prefix, config, weights) diff --git a/backends/gaudi/tgi-entrypoint.sh b/backends/gaudi/tgi-entrypoint.sh index deb64382..d787ea8e 100644 --- a/backends/gaudi/tgi-entrypoint.sh +++ b/backends/gaudi/tgi-entrypoint.sh @@ -10,8 +10,8 @@ fi # Check if ATTENTION environment variable is set to paged if [[ "$ATTENTION" == "paged" ]]; then # Check if Llama-4 is in the command line arguments - if [[ "$*" == 
*"Llama-4"* ]]; then - echo 'ATTENTION=paged and Llama-4 detected' + if [[ "$*" == *"Llama-4"* || "$*" == *"Qwen3"* ]]; then + echo 'ATTENTION=paged and Llama-4 or Qwen3 detected' pip install git+https://github.com/huggingface/transformers.git@29338949 fi fi