Add Qwen3 for Gaudi backend (#3229)

Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
Yuan Wu 2025-05-23 14:58:35 +08:00 committed by GitHub
parent f58d7cf50e
commit 1883a62a94
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 386 additions and 2 deletions

View File

@ -109,6 +109,9 @@ try:
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen3_modeling import (
Qwen3ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
@ -293,6 +296,12 @@ class ModelType(enum.Enum):
"name": "Qwen 2.5 VL",
"url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
}
QWEN3 = {
"type": "qwen3",
"name": "Qwen 3",
"url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
}
GALACTICA = {
"type": "galactica",
"name": "Galactica",
@ -791,6 +800,18 @@ def get_model(
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
)
elif model_type == QWEN3:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == MLLAMA:
return FlashMllamaCausalLM(
model_id=model_id,

View File

@ -22,6 +22,7 @@ import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F
import habana_frameworks.torch as htorch
from transformers.cache_utils import Cache
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
@ -567,6 +568,9 @@ class Llama4TextModel(nn.Module):
)
freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, layer in enumerate(self.layers):
hidden_states = layer(
@ -582,6 +586,8 @@ class Llama4TextModel(nn.Module):
position_ids=position_ids,
hpu_attention_meta=hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states)

View File

@ -0,0 +1,356 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, List
import torch
from torch import nn
import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
SpeculativeHead,
)
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from .flash_qwen2_modeling import Qwen2MLP
from text_generation_server.layers.rotary import PositionRotaryEmbedding
class Qwen3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config, prefix, weights, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.num_heads = config.num_attention_heads
self.attention_dropout = config.attention_dropout
self.softmax_scale = self.head_dim**-0.5
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_dim,
base=config.rope_theta,
device=weights.device,
)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1
)
self.q_norm = FastRMSNorm.load(
prefix=f"{prefix}.q_norm",
weights=weights,
eps=config.rms_norm_eps,
)
self.k_norm = FastRMSNorm.load(
prefix=f"{prefix}.k_norm",
weights=weights,
eps=config.rms_norm_eps,
)
self.sliding_window = config.sliding_window
if not (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
self.sliding_window = None
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
qkv = self.query_key_value(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_dim * self.num_heads,
self.head_dim * self.num_key_value_heads,
self.head_dim * self.num_key_value_heads,
],
dim=1,
)
query_states, _ = self.q_norm(query_states.view(hidden_shape))
key_states, _ = self.k_norm(key_states.view(hidden_shape))
value_states = value_states.view(hidden_shape)
self.rotary_emb(query_states, key_states, cos, sin)
kv_cache.store(
key=key_states,
value=value_states,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
attn_output = attention(
query=query_states,
key=key_states,
value=value_states,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query_states,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
return self.o_proj(attn_output)
class Qwen3DecoderLayer(nn.Module):
def __init__(self, config, prefix, weights, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen3Attention(
config=config,
prefix=f"{prefix}.self_attn",
weights=weights,
layer_idx=layer_idx,
)
self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
) -> torch.Tensor:
residual = hidden_states
hidden_states, _ = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states, _ = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Qwen3Model(nn.Module):
def __init__(self, config, prefix: str, weights):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.layers = nn.ModuleList(
[
Qwen3DecoderLayer(
config=config,
prefix=f"{prefix}.layers.{layer_idx}",
weights=weights,
layer_idx=layer_idx,
)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids,
)
residual = None
lazy_mode = htorch.utils.internal.is_lazy()
if lazy_mode:
htorch.core.mark_step()
for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
slots,
seqlen,
hpu_attention_meta,
)
if lazy_mode:
htorch.core.mark_step()
hidden_states, _ = self.norm(hidden_states)
# add hidden states from the last decoder layer
return hidden_states
class Qwen3ForCausalLM(nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = Qwen3Model(config=config, prefix="model", weights=weights)
self.vocab_size = config.vocab_size
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
suffix = "lm_head"
self.lm_head = SpeculativeHead.load(
config,
prefix=f"{prefix}.{suffix}" if prefix else suffix,
weights=weights,
)
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = self.model(
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits

View File

@ -1412,6 +1412,7 @@ class FlashCausalLM(Model):
aliases=aliases,
weights_loader=weights_loader,
)
print(f"weights: {weights}")
prefix = None
model = model_class(prefix, config, weights)

View File

@ -10,8 +10,8 @@ fi
# Check if ATTENTION environment variable is set to paged
if [[ "$ATTENTION" == "paged" ]]; then
# Check if Llama-4 is in the command line arguments
if [[ "$*" == *"Llama-4"* ]]; then
echo 'ATTENTION=paged and Llama-4 detected'
if [[ "$*" == *"Llama-4"* || "$*" == *"Qwen3"* ]]; then
echo 'ATTENTION=paged and Llama-4 or Qwen3 detected'
pip install git+https://github.com/huggingface/transformers.git@29338949
fi
fi