mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 07:42:06 +00:00
Fix the crash issue of Qwen/Qwen3-235B-A22B
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent
1a5ef906ae
commit
7f346a88e3
@ -19,9 +19,12 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
|
attention,
|
||||||
|
paged_attention,
|
||||||
Seqlen,
|
Seqlen,
|
||||||
HPUPagedAttentionMetadata,
|
HPUPagedAttentionMetadata,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
@ -37,40 +40,7 @@ from text_generation_server.layers.layernorm import (
|
|||||||
from .flash_qwen2_modeling import Qwen2MLP
|
from .flash_qwen2_modeling import Qwen2MLP
|
||||||
from .flash_qwen3_modeling import Qwen3Attention
|
from .flash_qwen3_modeling import Qwen3Attention
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
|
|
||||||
# import torch
|
|
||||||
# import torch.nn.functional as F
|
|
||||||
# from torch import nn
|
|
||||||
|
|
||||||
# from ...activations import ACT2FN
|
|
||||||
# from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
|
||||||
# from ...generation import GenerationMixin
|
|
||||||
# from ...integrations import use_kernel_forward_from_hub
|
|
||||||
# from ...modeling_attn_mask_utils import AttentionMaskConverter
|
|
||||||
# from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
|
||||||
# from ...modeling_outputs import (
|
|
||||||
# BaseModelOutputWithPast,
|
|
||||||
# CausalLMOutputWithPast,
|
|
||||||
# MoeCausalLMOutputWithPast,
|
|
||||||
# MoeModelOutputWithPast,
|
|
||||||
# QuestionAnsweringModelOutput,
|
|
||||||
# SequenceClassifierOutputWithPast,
|
|
||||||
# TokenClassifierOutput,
|
|
||||||
# )
|
|
||||||
# from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
||||||
# from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
||||||
# from ...processing_utils import Unpack
|
|
||||||
# from ...utils import (
|
|
||||||
# LossKwargs,
|
|
||||||
# add_code_sample_docstrings,
|
|
||||||
# add_start_docstrings,
|
|
||||||
# add_start_docstrings_to_model_forward,
|
|
||||||
# can_return_tuple,
|
|
||||||
# is_torch_flex_attn_available,
|
|
||||||
# logging,
|
|
||||||
# replace_return_docstrings,
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
def rotate_half(x):
|
||||||
@ -107,132 +77,131 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
class Qwen3MoeAttention(nn.Module):
|
||||||
"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
|
||||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
|
||||||
"""
|
|
||||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
|
||||||
if n_rep == 1:
|
|
||||||
return hidden_states
|
|
||||||
hidden_states = hidden_states[:, :, None, :, :].expand(
|
|
||||||
batch, num_key_value_heads, n_rep, slen, head_dim
|
|
||||||
)
|
|
||||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
|
||||||
|
|
||||||
|
def __init__(self, config, prefix, weights, layer_idx):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.head_dim = getattr(
|
||||||
|
config, "head_dim", config.hidden_size // config.num_attention_heads
|
||||||
|
)
|
||||||
|
self.num_key_value_heads = config.num_key_value_heads
|
||||||
|
self.num_key_value_groups = (
|
||||||
|
config.num_attention_heads // config.num_key_value_heads
|
||||||
|
)
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.attention_dropout = config.attention_dropout
|
||||||
|
self.is_causal = True
|
||||||
|
|
||||||
def eager_attention_forward(
|
self.q_proj = FastLinear.load(
|
||||||
module: nn.Module,
|
config, f"{prefix}.q_proj", weights, bias=config.attention_bias
|
||||||
query: torch.Tensor,
|
)
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor],
|
|
||||||
scaling: float,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
|
||||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
|
||||||
|
|
||||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
self.k_proj = FastLinear.load(
|
||||||
if attention_mask is not None:
|
config, f"{prefix}.k_proj", weights, bias=config.attention_bias
|
||||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
)
|
||||||
attn_weights = attn_weights + causal_mask
|
self.v_proj = FastLinear.load(
|
||||||
|
config, f"{prefix}.v_proj", weights, bias=config.attention_bias
|
||||||
|
)
|
||||||
|
self.o_proj = FastLinear.load(
|
||||||
|
config, f"{prefix}.o_proj", weights, bias=config.attention_bias
|
||||||
|
)
|
||||||
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
query.dtype
|
config=config,
|
||||||
)
|
dim=self.head_dim,
|
||||||
attn_weights = nn.functional.dropout(
|
base=config.rope_theta,
|
||||||
attn_weights, p=dropout, training=module.training
|
device=weights.device,
|
||||||
)
|
)
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
||||||
|
|
||||||
return attn_output, attn_weights
|
self.q_norm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.q_norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.k_norm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.k_norm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
# class Qwen3MoeAttention(nn.Module):
|
self.max_past = (
|
||||||
# """Multi-headed attention from 'Attention Is All You Need' paper"""
|
config.sliding_window if config.sliding_window is not None else -1
|
||||||
|
)
|
||||||
|
|
||||||
# def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
|
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||||
# super().__init__()
|
self.kv_head_mapping = torch.arange(
|
||||||
# self.config = config
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
# self.layer_idx = layer_idx
|
).repeat_interleave(self.num_key_value_groups)
|
||||||
# self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
|
||||||
# self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
|
||||||
# self.scaling = self.head_dim**-0.5
|
|
||||||
# self.attention_dropout = config.attention_dropout
|
|
||||||
# self.is_causal = True
|
|
||||||
|
|
||||||
# self.q_proj = nn.Linear(
|
self.sliding_window = config.sliding_window
|
||||||
# config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
if not (
|
||||||
# )
|
self.config.use_sliding_window
|
||||||
# self.k_proj = nn.Linear(
|
and getattr(self.config, "sliding_window", None) is not None
|
||||||
# config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
and self.layer_idx >= self.config.max_window_layers
|
||||||
# )
|
):
|
||||||
# self.v_proj = nn.Linear(
|
self.sliding_window = None
|
||||||
# config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
|
||||||
# )
|
|
||||||
# self.o_proj = nn.Linear(
|
|
||||||
# config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
|
||||||
# )
|
|
||||||
# self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
|
|
||||||
# self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
|
|
||||||
# self.sliding_window = config.sliding_window
|
|
||||||
# if not (
|
|
||||||
# self.config.use_sliding_window
|
|
||||||
# and getattr(self.config, "sliding_window", None) is not None
|
|
||||||
# and self.layer_idx >= self.config.max_window_layers
|
|
||||||
# ):
|
|
||||||
# self.sliding_window = None
|
|
||||||
|
|
||||||
# def forward(
|
def forward(
|
||||||
# self,
|
self,
|
||||||
# hidden_states: torch.Tensor,
|
hidden_states,
|
||||||
# position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
cos,
|
||||||
# attention_mask: Optional[torch.Tensor],
|
sin,
|
||||||
# cache_position: Optional[torch.LongTensor] = None,
|
cu_seqlen_prefill,
|
||||||
# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
kv_cache,
|
||||||
# input_shape = hidden_states.shape[:-1]
|
slots,
|
||||||
# hidden_shape = (*input_shape, -1, self.head_dim)
|
seqlen,
|
||||||
|
hpu_attention_meta,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
input_shape = hidden_states.shape[:-1]
|
||||||
|
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||||
|
|
||||||
# query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
query_states, _ = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
|
||||||
# key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
key_states, _ = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
|
||||||
# value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
value_states = self.v_proj(hidden_states).view(hidden_shape)
|
||||||
|
|
||||||
# cos, sin = position_embeddings
|
self.rotary_emb(query_states, key_states, cos, sin)
|
||||||
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
# if past_key_value is not None:
|
kv_cache.store(
|
||||||
# # sin and cos are specific to RoPE models; cache_position needed for the static cache
|
key=key_states,
|
||||||
# cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
value=value_states,
|
||||||
# key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
slots=slots,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
)
|
||||||
|
|
||||||
# attention_interface: Callable = eager_attention_forward
|
# Prefill
|
||||||
# if self.config._attn_implementation != "eager":
|
if cu_seqlen_prefill is not None:
|
||||||
# if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
# sdpa
|
||||||
# logger.warning_once(
|
attn_output = attention(
|
||||||
# "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
query=query_states,
|
||||||
# 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
key=key_states,
|
||||||
# )
|
value=value_states,
|
||||||
# else:
|
kv_cache=kv_cache,
|
||||||
# attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
kv_scales=self.kv_scales,
|
||||||
|
seqlen=seqlen,
|
||||||
|
softmax_scale=self.scaling,
|
||||||
|
window_size_left=self.max_past,
|
||||||
|
num_key_value_groups=self.num_key_value_groups,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
attn_output = paged_attention(
|
||||||
|
query_states,
|
||||||
|
kv_cache,
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.scaling,
|
||||||
|
seqlen,
|
||||||
|
kv_scales=self.kv_scales,
|
||||||
|
hpu_attention_meta=hpu_attention_meta,
|
||||||
|
)
|
||||||
|
|
||||||
# attn_output, attn_weights = attention_interface(
|
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||||
# self,
|
attn_output = self.o_proj(attn_output)
|
||||||
# query_states,
|
return attn_output
|
||||||
# key_states,
|
|
||||||
# value_states,
|
|
||||||
# attention_mask,
|
|
||||||
# dropout=0.0 if not self.training else self.attention_dropout,
|
|
||||||
# scaling=self.scaling,
|
|
||||||
# sliding_window=self.sliding_window, # diff with Llama
|
|
||||||
# **kwargs,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
|
||||||
# attn_output = self.o_proj(attn_output)
|
|
||||||
# return attn_output, attn_weights
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3MoE(nn.Module):
|
class Qwen3MoE(nn.Module):
|
||||||
@ -415,9 +384,21 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
self.self_attn = Qwen3Attention(
|
if config.num_key_value_heads // weights.process_group.size() > 0:
|
||||||
config, prefix=f"{prefix}.self_attn", weights=weights, layer_idx=layer_idx
|
self.self_attn = Qwen3Attention(
|
||||||
)
|
config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
weights=weights,
|
||||||
|
layer_idx=layer_idx,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.self_attn = Qwen3MoeAttention(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
weights=weights,
|
||||||
|
layer_idx=layer_idx,
|
||||||
|
)
|
||||||
|
|
||||||
moe_layer_cls = (
|
moe_layer_cls = (
|
||||||
SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
|
SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user