From 7f346a88e3eeb6301126502874085c38ae6260ea Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Fri, 6 Jun 2025 06:14:01 +0000
Subject: [PATCH] Fix the crash issue of Qwen/Qwen3-235B-A22B

Signed-off-by: yuanwu
---
 .../flash_qwen3_moe_modeling.py | 277 ++++++++----------
 1 file changed, 129 insertions(+), 148 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
index 1a264fbe..28a96523 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
@@ -19,9 +19,12 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 from text_generation_server.layers.attention import (
+    attention,
+    paged_attention,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers import (
     TensorParallelEmbedding,
@@ -37,40 +40,7 @@ from text_generation_server.layers.layernorm import (
 from .flash_qwen2_modeling import Qwen2MLP
 from .flash_qwen3_modeling import Qwen3Attention
 from transformers.activations import ACT2FN
-
-
-# import torch
-# import torch.nn.functional as F
-# from torch import nn
-
-# from ...activations import ACT2FN
-# from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
-# from ...generation import GenerationMixin
-# from ...integrations import use_kernel_forward_from_hub
-# from ...modeling_attn_mask_utils import AttentionMaskConverter
-# from ...modeling_flash_attention_utils import FlashAttentionKwargs
-# from ...modeling_outputs import (
-#     BaseModelOutputWithPast,
-#     CausalLMOutputWithPast,
-#     MoeCausalLMOutputWithPast,
-#     MoeModelOutputWithPast,
-#     QuestionAnsweringModelOutput,
-#     SequenceClassifierOutputWithPast,
-#     TokenClassifierOutput,
-# )
-# from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-# from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-# from ...processing_utils import Unpack
-# from ...utils import (
-#     LossKwargs,
-#     add_code_sample_docstrings,
-#     add_start_docstrings,
-#     add_start_docstrings_to_model_forward,
-#     can_return_tuple,
-#     is_torch_flex_attn_available,
-#     logging,
-#     replace_return_docstrings,
-# )
+from text_generation_server.layers.rotary import PositionRotaryEmbedding


 def rotate_half(x):
@@ -107,132 +77,131 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed


-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(
-        batch, num_key_value_heads, n_rep, slen, head_dim
-    )
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
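For reference, a minimal standalone sketch (not part of the patch) of the equivalence that the removed repeat_kv docstring describes; the tensor sizes below are illustrative:

    import torch

    batch, num_key_value_heads, seqlen, head_dim, n_rep = 2, 4, 3, 8, 2
    kv = torch.randn(batch, num_key_value_heads, seqlen, head_dim)

    # expand + reshape duplicates each KV head n_rep times along dim 1 ...
    expanded = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seqlen, head_dim)
    repeated = expanded.reshape(batch, num_key_value_heads * n_rep, seqlen, head_dim)

    # ... which matches torch.repeat_interleave along the head dimension.
    assert torch.equal(repeated, torch.repeat_interleave(kv, repeats=n_rep, dim=1))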
+class Qwen3MoeAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+    def __init__(self, config, prefix, weights, layer_idx):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = True

-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    dropout: float = 0.0,
-    **kwargs,
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
+        self.q_proj = FastLinear.load(
+            config, f"{prefix}.q_proj", weights, bias=config.attention_bias
+        )

-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
+        self.k_proj = FastLinear.load(
+            config, f"{prefix}.k_proj", weights, bias=config.attention_bias
+        )
+        self.v_proj = FastLinear.load(
+            config, f"{prefix}.v_proj", weights, bias=config.attention_bias
+        )
+        self.o_proj = FastLinear.load(
+            config, f"{prefix}.o_proj", weights, bias=config.attention_bias
+        )

-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
-        query.dtype
-    )
-    attn_weights = nn.functional.dropout(
-        attn_weights, p=dropout, training=module.training
-    )
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=self.head_dim,
+            base=config.rope_theta,
+            device=weights.device,
+        )

-    return attn_output, attn_weights
+        self.q_norm = FastRMSNorm.load(
+            prefix=f"{prefix}.q_norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
+        self.k_norm = FastRMSNorm.load(
+            prefix=f"{prefix}.k_norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )

-# class Qwen3MoeAttention(nn.Module):
-#     """Multi-headed attention from 'Attention Is All You Need' paper"""
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )

-#     def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
-#         super().__init__()
-#         self.config = config
-#         self.layer_idx = layer_idx
-#         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-#         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-#         self.scaling = self.head_dim**-0.5
-#         self.attention_dropout = config.attention_dropout
-#         self.is_causal = True
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_key_value_groups)
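The kv_head_mapping built above tells paged attention which KV-cache head each query head should read. A standalone sketch with illustrative sizes (e.g. 8 query heads sharing 4 KV heads, so num_key_value_groups = 2):

    import torch

    num_key_value_heads = 4   # illustrative
    num_key_value_groups = 2  # num_attention_heads // num_key_value_heads

    kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(
        num_key_value_groups
    )
    # tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int32): query head i reads KV head i // 2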
-#         self.q_proj = nn.Linear(
-#             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
-#         )
-#         self.k_proj = nn.Linear(
-#             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-#         )
-#         self.v_proj = nn.Linear(
-#             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
-#         )
-#         self.o_proj = nn.Linear(
-#             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
-#         )
-#         self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
-#         self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
-#         self.sliding_window = config.sliding_window
-#         if not (
-#             self.config.use_sliding_window
-#             and getattr(self.config, "sliding_window", None) is not None
-#             and self.layer_idx >= self.config.max_window_layers
-#         ):
-#             self.sliding_window = None
+        self.sliding_window = config.sliding_window
+        if not (
+            self.config.use_sliding_window
+            and getattr(self.config, "sliding_window", None) is not None
+            and self.layer_idx >= self.config.max_window_layers
+        ):
+            self.sliding_window = None

-#     def forward(
-#         self,
-#         hidden_states: torch.Tensor,
-#         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
-#         attention_mask: Optional[torch.Tensor],
-#         cache_position: Optional[torch.LongTensor] = None,
-#     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-#         input_shape = hidden_states.shape[:-1]
-#         hidden_shape = (*input_shape, -1, self.head_dim)
+    def forward(
+        self,
+        hidden_states,
+        cos,
+        sin,
+        cu_seqlen_prefill,
+        kv_cache,
+        slots,
+        seqlen,
+        hpu_attention_meta,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)

-#         query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
-#         key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
-#         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        query_states, _ = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
+        key_states, _ = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
+        value_states = self.v_proj(hidden_states).view(hidden_shape)

-#         cos, sin = position_embeddings
-#         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        self.rotary_emb(query_states, key_states, cos, sin)
+        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-#         if past_key_value is not None:
-#             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-#             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-#             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+        kv_cache.store(
+            key=key_states,
+            value=value_states,
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )
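Unlike the commented-out upstream forward, the forward added above runs on flattened token-major activations and applies q_norm/k_norm over head_dim only, before RoPE. A standalone sketch of those shapes (sizes and eps are illustrative assumptions; the learned RMSNorm weight is omitted):

    import torch

    num_tokens, num_heads, head_dim, eps = 5, 8, 128, 1e-6
    q = torch.randn(num_tokens, num_heads * head_dim)  # q_proj output, no batch dimension

    q = q.view(num_tokens, -1, head_dim)               # hidden_shape: (num_tokens, num_heads, head_dim)
    rms = torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + eps)
    q_normed = q * rms                                 # per-head RMSNorm over head_dim
    assert q_normed.shape == (num_tokens, num_heads, head_dim)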

-#         attention_interface: Callable = eager_attention_forward
-#         if self.config._attn_implementation != "eager":
-#             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-#                 logger.warning_once(
-#                     "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-#                     'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-#                 )
-#             else:
-#                 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            # sdpa
+            attn_output = attention(
+                query=query_states,
+                key=key_states,
+                value=value_states,
+                kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
+                seqlen=seqlen,
+                softmax_scale=self.scaling,
+                window_size_left=self.max_past,
+                num_key_value_groups=self.num_key_value_groups,
+            )
+        # Decode
+        else:
+            attn_output = paged_attention(
+                query_states,
+                kv_cache,
+                self.kv_head_mapping,
+                self.scaling,
+                seqlen,
+                kv_scales=self.kv_scales,
+                hpu_attention_meta=hpu_attention_meta,
+            )

-#         attn_output, attn_weights = attention_interface(
-#             self,
-#             query_states,
-#             key_states,
-#             value_states,
-#             attention_mask,
-#             dropout=0.0 if not self.training else self.attention_dropout,
-#             scaling=self.scaling,
-#             sliding_window=self.sliding_window,  # diff with Llama
-#             **kwargs,
-#         )
-
-#         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-#         attn_output = self.o_proj(attn_output)
-#         return attn_output, attn_weights
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output


 class Qwen3MoE(nn.Module):
@@ -415,9 +384,21 @@ class Qwen3MoeDecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.self_attn = Qwen3Attention(
-            config, prefix=f"{prefix}.self_attn", weights=weights, layer_idx=layer_idx
-        )
+        if config.num_key_value_heads // weights.process_group.size() > 0:
+            self.self_attn = Qwen3Attention(
+                config,
+                prefix=f"{prefix}.self_attn",
+                weights=weights,
+                layer_idx=layer_idx,
+            )
+        else:
+            self.self_attn = Qwen3MoeAttention(
+                config,
+                prefix=f"{prefix}.self_attn",
+                weights=weights,
+                layer_idx=layer_idx,
+            )
+
         moe_layer_cls = (
             SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
         )
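The new guard in Qwen3MoeDecoderLayer is the core of this fix: when the tensor-parallel world size exceeds the model's KV head count (Qwen3-235B-A22B has fewer KV heads than a typical multi-card setup has ranks), the integer division leaves zero KV heads per rank and the sharded Qwen3Attention cannot be built, so the layer falls back to the Qwen3MoeAttention added above, which loads the attention weights unsharded via FastLinear. A standalone sketch of the dispatch condition (the head count and world sizes are illustrative assumptions):

    num_key_value_heads = 4  # e.g. a config with few KV heads, as on Qwen3-235B-A22B

    for world_size in (1, 2, 4, 8):
        if num_key_value_heads // world_size > 0:
            choice = "Qwen3Attention (tensor-parallel sharded)"
        else:
            choice = "Qwen3MoeAttention (unsharded fallback)"
        print(f"world_size={world_size}: {choice}")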