Fix num_key_value_heads issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-20 02:29:12 +00:00
parent b32b78e74e
commit 05b6ed1bff
4 changed files with 220 additions and 67 deletions


@@ -18,6 +18,20 @@ def fetch_from_cache(cache, blocks):
     return cache.index_select(0, blocks)


+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 def attention(
     *,
     query: torch.Tensor,
@@ -30,6 +44,7 @@ def attention(
     window_size_left: int = -1,
     causal: bool = True,
     softcap: Optional[float] = None,
+    num_key_value_groups: int = 1,
 ):
     fsdpa_op = ModuleFusedSDPA(FusedSDPA)
     bs = seqlen.input_lengths.shape[0]
@@ -38,6 +53,8 @@ def attention(
     query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
     key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
     value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
+    key = repeat_kv(key, num_key_value_groups)
+    value = repeat_kv(value, num_key_value_groups)
     attn_output = fsdpa_op(
         query,
         key,
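
The repeat_kv helper added above expands grouped key/value heads so that the query and the key/value tensors handed to FusedSDPA carry the same number of heads. A minimal standalone sketch of the same idea (plain PyTorch; the tensor sizes are illustrative, not taken from the commit):

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # (batch, num_kv_heads, seqlen, head_dim) -> (batch, num_kv_heads * n_rep, seqlen, head_dim)
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_key_value_heads, n_rep, slen, head_dim
        )
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    # Example: 32 query heads over 8 KV heads means num_key_value_groups = 4.
    key = torch.randn(2, 8, 16, 64)
    expanded = repeat_kv(key, n_rep=4)
    assert expanded.shape == (2, 32, 16, 64)
    assert torch.equal(expanded[:, 0], expanded[:, 3])  # heads 0-3 are copies of KV head 0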


@@ -26,7 +26,14 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     SpeculativeHead,
+    FastLinear,
 )
+from text_generation_server.utils.import_utils import (
+    synchronize,
+    get_free_memory,
+)
+from loguru import logger
+from text_generation_server.utils.log import log_master
 from text_generation_server.layers.layernorm import (
@@ -76,13 +83,57 @@ class Qwen3Attention(nn.Module):
         else:
             self.num_key_value_heads = config.num_key_value_heads

-        self.query_key_value = TensorParallelColumnLinear.load_multi(
+        self.query_proj = TensorParallelColumnLinear.load(
             config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
+            prefix=f"{prefix}.q_proj",
             weights=weights,
             bias=False,
         )
+        if self.num_key_value_heads != config.num_key_value_heads:
+            self.key_proj = TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.k_proj",
+                weights=weights,
+                bias=False,
+            )
+            self.value_proj = TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.v_proj",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            self.key_proj = FastLinear.load(
+                config,
+                prefix=f"{prefix}.k_proj",
+                weights=weights,
+                bias=False,
+            )
+            self.value_proj = FastLinear.load(
+                config,
+                prefix=f"{prefix}.v_proj",
+                weights=weights,
+                bias=False,
+            )
+        # self.key_proj = TensorParallelColumnLinear.load(
+        #     config,
+        #     prefix=f"{prefix}.k_proj",
+        #     weights=weights,
+        #     bias=False,
+        # )
+        # self.value_proj = TensorParallelColumnLinear.load(
+        #     config,
+        #     prefix=f"{prefix}.v_proj",
+        #     weights=weights,
+        #     bias=False,
+        # )
+        # self.query_key_value = TensorParallelColumnLinear.load_multi(
+        #     config,
+        #     prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+        #     dim=0,
+        #     weights=weights,
+        #     bias=False,
+        # )

         self.kv_scales = get_kv_scales(weights, f"{prefix}")
@@ -131,25 +182,45 @@ class Qwen3Attention(nn.Module):
         seqlen,
         hpu_attention_meta,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        print(f"hidden_states shape: {hidden_states.shape}")
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
-        qkv = self.query_key_value(hidden_states)
-        print(f"qkv shape: {qkv.shape}")
-        print(f"self.head_dim: {self.head_dim}")
-        print(f"self.num_heads: {self.num_heads}")
-        print(f"self.num_key_value_heads: {self.num_key_value_heads}")
-        query_states, key_states, value_states = qkv.split(
-            [
-                self.head_dim * self.num_heads,
-                self.head_dim * self.num_key_value_heads,
-                self.head_dim * self.num_key_value_heads,
-            ],
-            dim=1,
-        )
+        # qkv = self.query_key_value(hidden_states)
+        # print(f"qkv shape: {qkv.shape}")
+        # print(f"self.head_dim: {self.head_dim}")
+        # print(f"self.num_heads: {self.num_heads}")
+        # print(f"self.num_key_value_heads: {self.num_key_value_heads}")
+        # query_states, key_states, value_states = qkv.split(
+        #     [
+        #         self.head_dim * self.num_heads,
+        #         self.head_dim * self.num_key_value_heads,
+        #         self.head_dim * self.num_key_value_heads,
+        #     ],
+        #     dim=1,
+        # )
+        synchronize(hidden_states.device)
+        real_free_memory = get_free_memory(hidden_states.device, 1)
+        log_master(
+            logger.debug,
+            f"Attention forward1 Free memory real: {real_free_memory / 1e9:.2f}GB",
+        )
+        query_states = self.query_proj(hidden_states)
+        key_states = self.key_proj(hidden_states)
+        value_states = self.value_proj(hidden_states)

         query_states, _ = self.q_norm(query_states.view(hidden_shape))
         key_states, _ = self.k_norm(key_states.view(hidden_shape))
         value_states = value_states.view(hidden_shape)
+        print(f"query_states shape: {query_states.shape}")
+        print(f"key_states shape: {key_states.shape}")
+        print(f"value_states shape: {value_states.shape}")
+        synchronize(hidden_states.device)
+        real_free_memory = get_free_memory(hidden_states.device, 1)
+        log_master(
+            logger.debug,
+            f"Attention forward2 Free memory real: {real_free_memory / 1e9:.2f}GB",
+        )

         self.rotary_emb(query_states, key_states, cos, sin)

         kv_cache.store(
@@ -171,6 +242,7 @@ class Qwen3Attention(nn.Module):
                 seqlen=seqlen,
                 softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
+                num_key_value_groups=self.num_key_value_groups,
             )
         # Decode
         else:
@@ -185,6 +257,7 @@ class Qwen3Attention(nn.Module):
             )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        print(f"attn_output shape: {attn_output.shape}")
         return self.o_proj(attn_output)
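
With the fused query_key_value projection split apart, the commit loads k_proj and v_proj as tensor-parallel (sharded) layers only when tensor parallelism actually divided the KV heads; otherwise it falls back to FastLinear so every rank keeps the full key/value projections, and repeat_kv bridges the head-count gap at attention time. A hedged sketch of that selection logic, reusing the loader calls that appear in the diff (the standalone helper function is illustrative, not part of the commit):

    from text_generation_server.layers import TensorParallelColumnLinear, FastLinear

    def load_kv_projections(prefix, config, weights, num_key_value_heads_per_rank):
        # Mirrors the branch added to Qwen3Attention.__init__ (sketch only).
        if num_key_value_heads_per_rank != config.num_key_value_heads:
            # Tensor parallelism split the KV heads: shard k/v projections column-wise.
            key_proj = TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
            )
            value_proj = TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
            )
        else:
            # Too few KV heads to split: keep the full projections replicated on each rank.
            key_proj = FastLinear.load(
                config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
            )
            value_proj = FastLinear.load(
                config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
            )
        return key_proj, value_proj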


@@ -17,6 +17,7 @@ from typing import List, Optional, Tuple, Type

 import torch
 from torch import nn
+import torch.nn.functional as F

 from text_generation_server.layers.attention import (
     Seqlen,
     HPUPagedAttentionMetadata,
@@ -24,10 +25,17 @@ from text_generation_server.layers.attention import (
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers import (
     TensorParallelEmbedding,
+    TensorParallelColumnLinear,
+    TensorParallelRowLinear,
     SpeculativeHead,
     FastLinear,
 )
+from text_generation_server.utils.import_utils import (
+    synchronize,
+    get_free_memory,
+)
+from loguru import logger
+from text_generation_server.utils.log import log_master
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -260,7 +268,19 @@ class Qwen3MoE(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(x)
+        # synchronize(x.device)
+        # real_free_memory = get_free_memory(x.device, 1)
+        # log_master(
+        #     logger.debug,
+        #     f"moe forward 1 Free memory real: {real_free_memory / 1e9:.2f}GB"
+        # )
         out = self.moe(x, gating_output=router_logits)
+        # synchronize(x.device)
+        # real_free_memory = get_free_memory(x.device, 1)
+        # log_master(
+        #     logger.debug,
+        #     f"moe forward 2 Free memory real: {real_free_memory / 1e9:.2f}GB"
+        # )

         # Reduce sum
         if self.process_group.size() > 1:
@@ -270,7 +290,7 @@ class Qwen3MoE(nn.Module):

 class Qwen3MoeMLP(nn.Module):
-    def __init__(self, config, intermediate_size=None):
+    def __init__(self, prefix, config, weights, intermediate_size=None):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -279,67 +299,104 @@ class Qwen3MoeMLP(nn.Module):
             if intermediate_size is not None
             else config.intermediate_size
         )
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        # Fuse gate and up proj
+        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            weights=weights,
+            dim=0,
+            bias=False,
+        )
+        self.down_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.down_proj",
+            weights=weights,
+            bias=False,
+        )
+        self.intermediate_size = (
+            config.intermediate_size // weights.process_group.size()
+        )
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x):
-        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
+        gate_up_states = self.gate_up_proj(x)
+        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+        return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])

-# class Qwen3MoeSparseMoeBlock(nn.Module):
-#     def __init__(self, config):
-#         super().__init__()
-#         self.num_experts = config.num_experts
-#         self.top_k = config.num_experts_per_tok
-#         self.norm_topk_prob = config.norm_topk_prob
-#         # gating
-#         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
-#         self.experts = nn.ModuleList(
-#             [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
-#         )
-
-#     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-#         """ """
-#         batch_size, sequence_length, hidden_dim = hidden_states.shape
-#         hidden_states = hidden_states.view(-1, hidden_dim)
-#         # router_logits: (batch * sequence_length, n_experts)
-#         router_logits = self.gate(hidden_states)
-#         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-#         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-#         if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
-#             routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-#         # we cast back to the input dtype
-#         routing_weights = routing_weights.to(hidden_states.dtype)
-#         final_hidden_states = torch.zeros(
-#             (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
-#         )
-#         # One hot encode the selected experts to create an expert mask
-#         # this will be used to easily index which expert is going to be sollicitated
-#         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-#         # Loop over all available experts in the model and perform the computation on each expert
-#         for expert_idx in range(self.num_experts):
-#             expert_layer = self.experts[expert_idx]
-#             idx, top_x = torch.where(expert_mask[expert_idx])
-
-#             # Index the correct hidden states and compute the expert hidden state for
-#             # the current expert. We need to make sure to multiply the output hidden
-#             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-#             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-#             current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
-
-#             # However `index_add_` only support torch tensors for indexing so we'll use
-#             # the `top_x` tensor here.
-#             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-#         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
-#         return final_hidden_states, router_logits
+class Qwen3MoeSparseMoeBlock(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.norm_topk_prob = config.norm_topk_prob
+        # gating
+        # self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
+        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+        self.experts = nn.ModuleList(
+            [
+                Qwen3MoeMLP(
+                    prefix=f"{prefix}.experts.{i}",
+                    config=config,
+                    weights=weights,
+                    intermediate_size=config.moe_intermediate_size,
+                )
+                for i in range(self.num_experts)
+            ]
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        input_shape = hidden_states.shape
+        _, hidden_dim = hidden_states.shape
+        # hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=hidden_states.dtype)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+        print(
+            f"routing_weights: {routing_weights.device}, selected_experts: {selected_experts.device}"
+        )
+        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        final_hidden_states = torch.zeros(
+            (input_shape), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
+        print(f"expert_mask: {expert_mask.device}")
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = (
+                expert_layer(current_state) * routing_weights[top_x, idx, None]
+            )
+
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(input_shape)
+        return final_hidden_states


 # @use_kernel_forward_from_hub("RMSNorm")
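
The new Qwen3MoeSparseMoeBlock follows the standard eager MoE dispatch: softmax over the router logits, top-k expert selection per token, optional renormalization of the top-k weights, then a loop over experts that gathers each expert's tokens via a one-hot expert mask and accumulates the weighted outputs with index_add_. A self-contained toy sketch of that routing pattern (shapes and the plain nn.Linear experts are made up for illustration; the real block uses Qwen3MoeMLP experts):

    import torch
    import torch.nn.functional as F
    from torch import nn

    num_tokens, hidden_dim, num_experts, top_k = 6, 8, 4, 2
    hidden_states = torch.randn(num_tokens, hidden_dim)
    gate = nn.Linear(hidden_dim, num_experts, bias=False)
    experts = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_experts))

    router_logits = gate(hidden_states)  # (num_tokens, num_experts)
    routing_weights = F.softmax(router_logits, dim=1)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # norm_topk_prob

    final_hidden_states = torch.zeros_like(hidden_states)
    # expert_mask[e, k, t] == 1 when token t picked expert e as its k-th choice
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
    for expert_idx in range(num_experts):
        idx, top_x = torch.where(expert_mask[expert_idx])  # tokens routed to this expert
        if top_x.numel() == 0:
            continue
        out = experts[expert_idx](hidden_states[top_x]) * routing_weights[top_x, idx, None]
        final_hidden_states.index_add_(0, top_x, out)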
@@ -383,6 +440,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
         ):
             self.mlp = Qwen3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
+            # self.mlp = Qwen3MoeSparseMoeBlock(f"{prefix}.mlp", config, weights)
         else:
             self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights)
@@ -458,6 +516,11 @@ class Qwen3MoeModel(nn.Module):
         self.norm = FastRMSNorm.load(
             prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
         )
+        synchronize(weights.device)
+        real_free_memory = get_free_memory(weights.device, 1)
+        log_master(
+            logger.debug, f"init model Free memory real: {real_free_memory / 1e9:.2f}GB"
+        )

     def forward(
         self,
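
Several hunks in this commit insert the same free-memory probe around model initialization and the attention forward. A hedged sketch of that pattern factored into a helper (the wrapper function itself is illustrative and not part of the commit; it only reuses calls that appear in the diff):

    from loguru import logger
    from text_generation_server.utils.import_utils import synchronize, get_free_memory
    from text_generation_server.utils.log import log_master

    def log_free_memory(tag: str, device) -> None:
        # Synchronize first so the reading reflects work that has actually finished,
        # then log from the master rank only.
        synchronize(device)
        real_free_memory = get_free_memory(device, 1)
        log_master(logger.debug, f"{tag} Free memory real: {real_free_memory / 1e9:.2f}GB")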


@@ -1420,7 +1420,7 @@ class FlashCausalLM(Model):
             raise ValueError("Cannot get the number of key/value heads")
         self.num_kv_heads = (
             num_kv_heads // self.process_group.size()
-            if num_kv_heads > 1
+            if num_kv_heads // self.process_group.size() > 0
             else num_kv_heads
         )
         assert self.num_kv_heads > 0
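
The guard above is the core of the fix: KV heads are divided by the tensor-parallel world size only when each rank still ends up with at least one head; otherwise every rank keeps the full KV head count (and the repeat_kv path compensates at attention time). A small worked sketch of the fixed arithmetic (the helper name is illustrative):

    def shard_kv_heads(num_kv_heads: int, world_size: int) -> int:
        # New behaviour: shard only when the division leaves at least one KV head per rank.
        if num_kv_heads // world_size > 0:
            return num_kv_heads // world_size
        return num_kv_heads

    # The old guard (`if num_kv_heads > 1`) would give 4 // 8 == 0 heads per rank and
    # trip the `assert self.num_kv_heads > 0` that follows; the new guard keeps all 4.
    assert shard_kv_heads(32, 8) == 4  # enough heads: shard as before
    assert shard_kv_heads(4, 8) == 4   # fewer heads than ranks: keep them replicated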