Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)

Fix num_key_value_heads issue

Signed-off-by: yuanwu <yuan.wu@intel.com>

parent: b32b78e74e
commit: 05b6ed1bff
@@ -18,6 +18,20 @@ def fetch_from_cache(cache, blocks):
    return cache.index_select(0, blocks)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def attention(
    *,
    query: torch.Tensor,
@@ -30,6 +44,7 @@ def attention(
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
    num_key_value_groups: int = 1,
):
    fsdpa_op = ModuleFusedSDPA(FusedSDPA)
    bs = seqlen.input_lengths.shape[0]
@@ -38,6 +53,8 @@ def attention(
    query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
    key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
    value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
    key = repeat_kv(key, num_key_value_groups)
    value = repeat_kv(value, num_key_value_groups)
    attn_output = fsdpa_op(
        query,
        key,
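The added repeat_kv mirrors torch.repeat_interleave: with grouped-query attention the key/value tensors carry num_key_value_heads heads and must be expanded to the query head count before the fused SDPA call. A minimal standalone check of the expand/reshape trick (shapes chosen purely for illustration, not taken from any model config):

    import torch

    # 2 KV heads expanded to 8 attention heads (n_rep = 4); sizes are illustrative only.
    kv = torch.randn(1, 2, 16, 64)             # (batch, num_key_value_heads, seqlen, head_dim)
    expanded = kv[:, :, None, :, :].expand(1, 2, 4, 16, 64)
    expanded = expanded.reshape(1, 8, 16, 64)  # (batch, num_attention_heads, seqlen, head_dim)
    assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=4, dim=1))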
@@ -26,7 +26,14 @@ from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    SpeculativeHead,
    FastLinear,
)
from text_generation_server.utils.import_utils import (
    synchronize,
    get_free_memory,
)
from loguru import logger
from text_generation_server.utils.log import log_master


from text_generation_server.layers.layernorm import (
@@ -76,13 +83,57 @@ class Qwen3Attention(nn.Module):
        else:
            self.num_key_value_heads = config.num_key_value_heads

        self.query_key_value = TensorParallelColumnLinear.load_multi(
        self.query_proj = TensorParallelColumnLinear.load(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        if self.num_key_value_heads != config.num_key_value_heads:
            self.key_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.k_proj",
                weights=weights,
                bias=False,
            )
            self.value_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.v_proj",
                weights=weights,
                bias=False,
            )
        else:
            self.key_proj = FastLinear.load(
                config,
                prefix=f"{prefix}.k_proj",
                weights=weights,
                bias=False,
            )
            self.value_proj = FastLinear.load(
                config,
                prefix=f"{prefix}.v_proj",
                weights=weights,
                bias=False,
            )
        # self.key_proj = TensorParallelColumnLinear.load(
        #     config,
        #     prefix=f"{prefix}.k_proj",
        #     weights=weights,
        #     bias=False,
        # )
        # self.value_proj = TensorParallelColumnLinear.load(
        #     config,
        #     prefix=f"{prefix}.v_proj",
        #     weights=weights,
        #     bias=False,
        # )
        # self.query_key_value = TensorParallelColumnLinear.load_multi(
        #     config,
        #     prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        #     dim=0,
        #     weights=weights,
        #     bias=False,
        # )

        self.kv_scales = get_kv_scales(weights, f"{prefix}")

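Read together with the FlashCausalLM change at the bottom of this diff, the split of the fused query_key_value into separate projections appears intended to keep k_proj/v_proj unsharded (FastLinear) whenever the per-rank KV head count still equals config.num_key_value_heads, i.e. the KV heads could not be divided across the tensor-parallel ranks and are replicated instead. A rough sketch of that decision with invented numbers; the helper name is hypothetical, not from the repo:

    # Illustrative sketch only: pick a sharded vs. replicated loader for k/v weights.
    def pick_kv_loader(config_kv_heads: int, per_rank_kv_heads: int) -> str:
        if per_rank_kv_heads != config_kv_heads:
            # KV heads were divided across ranks, so each rank loads its own shard.
            return "TensorParallelColumnLinear"
        # Every rank keeps all KV heads, so each rank loads the full k/v weights.
        return "FastLinear"

    # Hypothetical numbers: 8 KV heads on 4 ranks shard to 2 per rank;
    # 4 KV heads on 8 ranks cannot be split, so they are replicated.
    assert pick_kv_loader(config_kv_heads=8, per_rank_kv_heads=2) == "TensorParallelColumnLinear"
    assert pick_kv_loader(config_kv_heads=4, per_rank_kv_heads=4) == "FastLinear"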
@@ -131,25 +182,45 @@ class Qwen3Attention(nn.Module):
        seqlen,
        hpu_attention_meta,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        print(f"hidden_states shape: {hidden_states.shape}")
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        qkv = self.query_key_value(hidden_states)
        print(f"qkv shape: {qkv.shape}")
        print(f"self.head_dim: {self.head_dim}")
        print(f"self.num_heads: {self.num_heads}")
        print(f"self.num_key_value_heads: {self.num_key_value_heads}")
        query_states, key_states, value_states = qkv.split(
            [
                self.head_dim * self.num_heads,
                self.head_dim * self.num_key_value_heads,
                self.head_dim * self.num_key_value_heads,
            ],
            dim=1,
        # qkv = self.query_key_value(hidden_states)
        # print(f"qkv shape: {qkv.shape}")
        # print(f"self.head_dim: {self.head_dim}")
        # print(f"self.num_heads: {self.num_heads}")
        # print(f"self.num_key_value_heads: {self.num_key_value_heads}")
        # query_states, key_states, value_states = qkv.split(
        #     [
        #         self.head_dim * self.num_heads,
        #         self.head_dim * self.num_key_value_heads,
        #         self.head_dim * self.num_key_value_heads,
        #     ],
        #     dim=1,
        # )
        synchronize(hidden_states.device)
        real_free_memory = get_free_memory(hidden_states.device, 1)
        log_master(
            logger.debug,
            f"Attention forward1 Free memory real: {real_free_memory / 1e9:.2f}GB",
        )

        query_states = self.query_proj(hidden_states)
        key_states = self.key_proj(hidden_states)
        value_states = self.value_proj(hidden_states)
        query_states, _ = self.q_norm(query_states.view(hidden_shape))
        key_states, _ = self.k_norm(key_states.view(hidden_shape))
        value_states = value_states.view(hidden_shape)
        print(f"query_states shape: {query_states.shape}")
        print(f"key_states shape: {key_states.shape}")
        print(f"value_states shape: {value_states.shape}")
        synchronize(hidden_states.device)
        real_free_memory = get_free_memory(hidden_states.device, 1)
        log_master(
            logger.debug,
            f"Attention forward2 Free memory real: {real_free_memory / 1e9:.2f}GB",
        )

        self.rotary_emb(query_states, key_states, cos, sin)

        kv_cache.store(
@@ -171,6 +242,7 @@ class Qwen3Attention(nn.Module):
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
                window_size_left=self.max_past,
                num_key_value_groups=self.num_key_value_groups,
            )
        # Decode
        else:
@@ -185,6 +257,7 @@ class Qwen3Attention(nn.Module):
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        print(f"attn_output shape: {attn_output.shape}")
        return self.o_proj(attn_output)

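To make the reworked forward easier to follow, here is a shape walkthrough of the separate projections with made-up sizes; the hidden size, head counts, and the num_key_value_groups value are assumptions for illustration only, not values read from any Qwen3 config:

    import torch

    # Hypothetical per-rank sizes: 8 query heads, 4 KV heads, head_dim 64, 10 tokens.
    tokens, num_heads, kv_heads, head_dim = 10, 8, 4, 64
    hidden = torch.randn(tokens, num_heads * head_dim)

    q_proj = torch.nn.Linear(num_heads * head_dim, num_heads * head_dim, bias=False)
    k_proj = torch.nn.Linear(num_heads * head_dim, kv_heads * head_dim, bias=False)
    v_proj = torch.nn.Linear(num_heads * head_dim, kv_heads * head_dim, bias=False)

    query = q_proj(hidden).view(tokens, num_heads, head_dim)  # (10, 8, 64)
    key = k_proj(hidden).view(tokens, kv_heads, head_dim)     # (10, 4, 64)
    value = v_proj(hidden).view(tokens, kv_heads, head_dim)   # (10, 4, 64)

    # Each KV head serves num_heads // kv_heads query heads; repeat_kv later expands
    # key/value by this factor inside the attention kernel.
    num_key_value_groups = num_heads // kv_heads  # 2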
@@ -17,6 +17,7 @@ from typing import List, Optional, Tuple, Type

import torch
from torch import nn
import torch.nn.functional as F
from text_generation_server.layers.attention import (
    Seqlen,
    HPUPagedAttentionMetadata,
@@ -24,10 +25,17 @@ from text_generation_server.layers.attention import (
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers import (
    TensorParallelEmbedding,
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    SpeculativeHead,
    FastLinear,
)

from text_generation_server.utils.import_utils import (
    synchronize,
    get_free_memory,
)
from loguru import logger
from text_generation_server.utils.log import log_master

from text_generation_server.layers.layernorm import (
    FastRMSNorm,
@@ -260,7 +268,19 @@ class Qwen3MoE(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(x)
        # synchronize(x.device)
        # real_free_memory = get_free_memory(x.device, 1)
        # log_master(
        #     logger.debug,
        #     f"moe forward 1Free memory real: {real_free_memory / 1e9:.2f}GB"
        # )
        out = self.moe(x, gating_output=router_logits)
        # synchronize(x.device)
        # real_free_memory = get_free_memory(x.device, 1)
        # log_master(
        #     logger.debug,
        #     f"moe forward 2 Free memory real: {real_free_memory / 1e9:.2f}GB"
        # )

        # Reduce sum
        if self.process_group.size() > 1:
@@ -270,7 +290,7 @@ class Qwen3MoE(nn.Module):


class Qwen3MoeMLP(nn.Module):
    def __init__(self, config, intermediate_size=None):
    def __init__(self, prefix, config, weights, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
@@ -279,67 +299,104 @@ class Qwen3MoeMLP(nn.Module):
            if intermediate_size is not None
            else config.intermediate_size
        )
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
        gate_up_states = self.gate_up_proj(x)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])

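The rewritten Qwen3MoeMLP runs the gate and up projections as a single matmul and then splits the result, which is what the view(-1, 2, intermediate_size) is for. A self-contained sketch of the pattern (sizes and the silu activation are assumptions for illustration, not taken from the config):

    import torch
    import torch.nn.functional as F

    hidden_size, intermediate_size, tokens = 32, 64, 5
    gate_up_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
    down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

    x = torch.randn(tokens, hidden_size)
    gate_up = gate_up_proj(x).view(-1, 2, intermediate_size)  # one matmul for gate and up
    out = down_proj(F.silu(gate_up[:, 0]) * gate_up[:, 1])    # (tokens, hidden_size)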
# class Qwen3MoeSparseMoeBlock(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.num_experts = config.num_experts
#         self.top_k = config.num_experts_per_tok
#         self.norm_topk_prob = config.norm_topk_prob
class Qwen3MoeSparseMoeBlock(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # # gating
        # self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        # self.experts = nn.ModuleList(
        #     [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
        # )
        # gating
        # self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
        self.experts = nn.ModuleList(
            [
                Qwen3MoeMLP(
                    prefix=f"{prefix}.experts.{i}",
                    config=config,
                    weights=weights,
                    intermediate_size=config.moe_intermediate_size,
                )
                for i in range(self.num_experts)
            ]
        )

    # def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    #     """ """
    #     batch_size, sequence_length, hidden_dim = hidden_states.shape
    #     hidden_states = hidden_states.view(-1, hidden_dim)
    #     # router_logits: (batch * sequence_length, n_experts)
    #     router_logits = self.gate(hidden_states)
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """ """
        input_shape = hidden_states.shape
        _, hidden_dim = hidden_states.shape
        # hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        # routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        # routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
        #     routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # # we cast back to the input dtype
        # routing_weights = routing_weights.to(hidden_states.dtype)
        routing_weights = F.softmax(router_logits, dim=1, dtype=hidden_states.dtype)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        print(
            f"routing_weights: {routing_weights.device}, selected_experts: {selected_experts.device}"
        )
        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # final_hidden_states = torch.zeros(
        #     (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        # )
        final_hidden_states = torch.zeros(
            (input_shape), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # # One hot encode the selected experts to create an expert mask
        # # this will be used to easily index which expert is going to be sollicitated
        # expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)
        print(f"expert_mask: {expert_mask.device}")
        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # # Loop over all available experts in the model and perform the computation on each expert
            # for expert_idx in range(self.num_experts):
            #     expert_layer = self.experts[expert_idx]
            #     idx, top_x = torch.where(expert_mask[expert_idx])
            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = (
                expert_layer(current_state) * routing_weights[top_x, idx, None]
            )

            # # Index the correct hidden states and compute the expert hidden state for
            # # the current expert. We need to make sure to multiply the output hidden
            # # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            # current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # # However `index_add_` only support torch tensors for indexing so we'll use
            # # the `top_x` tensor here.
            # final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
            # final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
            # return final_hidden_states
            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )
        final_hidden_states = final_hidden_states.reshape(input_shape)
        return final_hidden_states


# @use_kernel_forward_from_hub("RMSNorm")
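The routing logic above is easier to see on tiny tensors. A toy example of the softmax/top-k/one-hot mask pipeline with 4 experts and top_k = 2 (numbers invented for illustration):

    import torch
    import torch.nn.functional as F

    router_logits = torch.tensor([[2.0, 0.5, 0.1, -1.0],   # token 0
                                  [0.0, 1.0, 3.0, 0.2]])   # token 1
    routing_weights = F.softmax(router_logits, dim=1)
    routing_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)
    # selected_experts -> tensor([[0, 1], [2, 1]]): experts chosen per token

    expert_mask = F.one_hot(selected_experts, num_classes=4).permute(2, 1, 0)
    # expert_mask[e] is (top_k, num_tokens); torch.where picks the tokens routed to expert e
    idx, top_x = torch.where(expert_mask[1])
    # expert 1 handles token 0 (as its 2nd choice) and token 1 (as its 2nd choice)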
@@ -383,6 +440,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
            # self.mlp = Qwen3MoeSparseMoeBlock(f"{prefix}.mlp", config, weights)

        else:
            self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights)
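The sparsity condition makes a layer MoE when (layer_idx + 1) is a multiple of config.decoder_sparse_step (and num_experts > 0). A quick check with invented values:

    # decoder_sparse_step = 2 would make every second layer a MoE layer.
    decoder_sparse_step, num_layers, num_experts = 2, 8, 128
    moe_layers = [
        layer_idx
        for layer_idx in range(num_layers)
        if num_experts > 0 and (layer_idx + 1) % decoder_sparse_step == 0
    ]
    print(moe_layers)  # [1, 3, 5, 7]; with decoder_sparse_step = 1 every layer is MoE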
@@ -458,6 +516,11 @@ class Qwen3MoeModel(nn.Module):
        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )
        synchronize(weights.device)
        real_free_memory = get_free_memory(weights.device, 1)
        log_master(
            logger.debug, f"init model Free memory real: {real_free_memory / 1e9:.2f}GB"
        )

    def forward(
        self,
@@ -1420,7 +1420,7 @@ class FlashCausalLM(Model):
            raise ValueError("Cannot get the number of key/value heads")
        self.num_kv_heads = (
            num_kv_heads // self.process_group.size()
            if num_kv_heads > 1
            if num_kv_heads // self.process_group.size() > 0
            else num_kv_heads
        )
        assert self.num_kv_heads > 0
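This last hunk is the core of the num_key_value_heads fix: the old guard only checked num_kv_heads > 1, so a small KV head count divided by a larger tensor-parallel world size collapsed to zero heads and tripped the assert below. A worked example with assumed numbers (4 KV heads, 8-way tensor parallelism):

    num_kv_heads, world_size = 4, 8

    # Old logic: 4 > 1, so num_kv_heads becomes 4 // 8 == 0 and the assert fails.
    old = num_kv_heads // world_size if num_kv_heads > 1 else num_kv_heads
    assert old == 0

    # New logic: 4 // 8 == 0 is not > 0, so the heads are kept (replicated per rank).
    new = num_kv_heads // world_size if num_kv_heads // world_size > 0 else num_kv_heads
    assert new == 4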