diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index 1d73dcb3..3bb7bdce 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -18,6 +18,20 @@ def fetch_from_cache(cache, blocks): return cache.index_select(0, blocks) +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def attention( *, query: torch.Tensor, @@ -30,6 +44,7 @@ def attention( window_size_left: int = -1, causal: bool = True, softcap: Optional[float] = None, + num_key_value_groups: int = 1, ): fsdpa_op = ModuleFusedSDPA(FusedSDPA) bs = seqlen.input_lengths.shape[0] @@ -38,6 +53,8 @@ def attention( query = query.view(bs, -1, head_num, head_size).transpose(1, 2) key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2) value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2) + key = repeat_kv(key, num_key_value_groups) + value = repeat_kv(value, num_key_value_groups) attn_output = fsdpa_op( query, key, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py index 0c3af1ed..cce64196 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py @@ -26,7 +26,14 @@ from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, SpeculativeHead, + FastLinear, ) +from text_generation_server.utils.import_utils import ( + synchronize, + get_free_memory, +) +from loguru import logger +from text_generation_server.utils.log import log_master from text_generation_server.layers.layernorm import ( @@ -76,13 +83,57 @@ class Qwen3Attention(nn.Module): else: self.num_key_value_heads = config.num_key_value_heads - self.query_key_value = TensorParallelColumnLinear.load_multi( + self.query_proj = TensorParallelColumnLinear.load( config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, + prefix=f"{prefix}.q_proj", weights=weights, bias=False, ) + if self.num_key_value_heads != config.num_key_value_heads: + self.key_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.k_proj", + weights=weights, + bias=False, + ) + self.value_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.v_proj", + weights=weights, + bias=False, + ) + else: + self.key_proj = FastLinear.load( + config, + prefix=f"{prefix}.k_proj", + weights=weights, + bias=False, + ) + self.value_proj = FastLinear.load( + config, + prefix=f"{prefix}.v_proj", + weights=weights, + bias=False, + ) + # self.key_proj = TensorParallelColumnLinear.load( + # config, + # prefix=f"{prefix}.k_proj", + # weights=weights, + # bias=False, + # ) + # self.value_proj = 
TensorParallelColumnLinear.load( + # config, + # prefix=f"{prefix}.v_proj", + # weights=weights, + # bias=False, + # ) + # self.query_key_value = TensorParallelColumnLinear.load_multi( + # config, + # prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + # dim=0, + # weights=weights, + # bias=False, + # ) self.kv_scales = get_kv_scales(weights, f"{prefix}") @@ -131,25 +182,45 @@ class Qwen3Attention(nn.Module): seqlen, hpu_attention_meta, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + print(f"hidden_states shape: {hidden_states.shape}") input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - qkv = self.query_key_value(hidden_states) - print(f"qkv shape: {qkv.shape}") - print(f"self.head_dim: {self.head_dim}") - print(f"self.num_heads: {self.num_heads}") - print(f"self.num_key_value_heads: {self.num_key_value_heads}") - query_states, key_states, value_states = qkv.split( - [ - self.head_dim * self.num_heads, - self.head_dim * self.num_key_value_heads, - self.head_dim * self.num_key_value_heads, - ], - dim=1, + # qkv = self.query_key_value(hidden_states) + # print(f"qkv shape: {qkv.shape}") + # print(f"self.head_dim: {self.head_dim}") + # print(f"self.num_heads: {self.num_heads}") + # print(f"self.num_key_value_heads: {self.num_key_value_heads}") + # query_states, key_states, value_states = qkv.split( + # [ + # self.head_dim * self.num_heads, + # self.head_dim * self.num_key_value_heads, + # self.head_dim * self.num_key_value_heads, + # ], + # dim=1, + # ) + synchronize(hidden_states.device) + real_free_memory = get_free_memory(hidden_states.device, 1) + log_master( + logger.debug, + f"Attention forward1 Free memory real: {real_free_memory / 1e9:.2f}GB", ) + query_states = self.query_proj(hidden_states) + key_states = self.key_proj(hidden_states) + value_states = self.value_proj(hidden_states) query_states, _ = self.q_norm(query_states.view(hidden_shape)) key_states, _ = self.k_norm(key_states.view(hidden_shape)) value_states = value_states.view(hidden_shape) + print(f"query_states shape: {query_states.shape}") + print(f"key_states shape: {key_states.shape}") + print(f"value_states shape: {value_states.shape}") + synchronize(hidden_states.device) + real_free_memory = get_free_memory(hidden_states.device, 1) + log_master( + logger.debug, + f"Attention forward2 Free memory real: {real_free_memory / 1e9:.2f}GB", + ) + self.rotary_emb(query_states, key_states, cos, sin) kv_cache.store( @@ -171,6 +242,7 @@ class Qwen3Attention(nn.Module): seqlen=seqlen, softmax_scale=self.softmax_scale, window_size_left=self.max_past, + num_key_value_groups=self.num_key_value_groups, ) # Decode else: @@ -185,6 +257,7 @@ class Qwen3Attention(nn.Module): ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() + print(f"attn_output shape: {attn_output.shape}") return self.o_proj(attn_output) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py index 7174f46d..d04146f9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py @@ -17,6 +17,7 @@ from typing import List, Optional, Tuple, Type import torch from torch import nn +import torch.nn.functional as F from text_generation_server.layers.attention import ( Seqlen, 
HPUPagedAttentionMetadata, @@ -24,10 +25,17 @@ from text_generation_server.layers.attention import ( from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers import ( TensorParallelEmbedding, + TensorParallelColumnLinear, + TensorParallelRowLinear, SpeculativeHead, FastLinear, ) - +from text_generation_server.utils.import_utils import ( + synchronize, + get_free_memory, +) +from loguru import logger +from text_generation_server.utils.log import log_master from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -260,7 +268,19 @@ class Qwen3MoE(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits = self.gate(x) + # synchronize(x.device) + # real_free_memory = get_free_memory(x.device, 1) + # log_master( + # logger.debug, + # f"moe forward 1Free memory real: {real_free_memory / 1e9:.2f}GB" + # ) out = self.moe(x, gating_output=router_logits) + # synchronize(x.device) + # real_free_memory = get_free_memory(x.device, 1) + # log_master( + # logger.debug, + # f"moe forward 2 Free memory real: {real_free_memory / 1e9:.2f}GB" + # ) # Reduce sum if self.process_group.size() > 1: @@ -270,7 +290,7 @@ class Qwen3MoE(nn.Module): class Qwen3MoeMLP(nn.Module): - def __init__(self, config, intermediate_size=None): + def __init__(self, prefix, config, weights, intermediate_size=None): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -279,67 +299,104 @@ class Qwen3MoeMLP(nn.Module): if intermediate_size is not None else config.intermediate_size ) - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj + gate_up_states = self.gate_up_proj(x) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) -# class Qwen3MoeSparseMoeBlock(nn.Module): -# def __init__(self, config): -# super().__init__() -# self.num_experts = config.num_experts -# self.top_k = config.num_experts_per_tok -# self.norm_topk_prob = config.norm_topk_prob +class Qwen3MoeSparseMoeBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob -# # gating -# self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) -# self.experts = nn.ModuleList( -# [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] -# ) + # gating + # self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + self.experts = nn.ModuleList( + [ + 
Qwen3MoeMLP( + prefix=f"{prefix}.experts.{i}", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size, + ) + for i in range(self.num_experts) + ] + ) -# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: -# """ """ -# batch_size, sequence_length, hidden_dim = hidden_states.shape -# hidden_states = hidden_states.view(-1, hidden_dim) -# # router_logits: (batch * sequence_length, n_experts) -# router_logits = self.gate(hidden_states) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + input_shape = hidden_states.shape + _, hidden_dim = hidden_states.shape + # hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) -# routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) -# routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) -# if self.norm_topk_prob: # only diff with mixtral sparse moe block! -# routing_weights /= routing_weights.sum(dim=-1, keepdim=True) -# # we cast back to the input dtype -# routing_weights = routing_weights.to(hidden_states.dtype) + routing_weights = F.softmax(router_logits, dim=1, dtype=hidden_states.dtype) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + print( + f"routing_weights: {routing_weights.device}, selected_experts: {selected_experts.device}" + ) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) -# final_hidden_states = torch.zeros( -# (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device -# ) + final_hidden_states = torch.zeros( + (input_shape), dtype=hidden_states.dtype, device=hidden_states.device + ) -# # One hot encode the selected experts to create an expert mask -# # this will be used to easily index which expert is going to be sollicitated -# expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + print(f"expert_mask: {expert_mask.device}") + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) -# # Loop over all available experts in the model and perform the computation on each expert -# for expert_idx in range(self.num_experts): -# expert_layer = self.experts[expert_idx] -# idx, top_x = torch.where(expert_mask[expert_idx]) + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) -# # Index the correct hidden states and compute the expert hidden state for -# # the current expert. 
We need to make sure to multiply the output hidden -# # states by `routing_weights` on the corresponding tokens (top-1 and top-2) -# current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) -# current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - -# # However `index_add_` only support torch tensors for indexing so we'll use -# # the `top_x` tensor here. -# final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) -# final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) -# return final_hidden_states, router_logits + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape(input_shape) + return final_hidden_states # @use_kernel_forward_from_hub("RMSNorm") @@ -383,6 +440,7 @@ class Qwen3MoeDecoderLayer(nn.Module): config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights) + # self.mlp = Qwen3MoeSparseMoeBlock(f"{prefix}.mlp", config, weights) else: self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights) @@ -458,6 +516,11 @@ class Qwen3MoeModel(nn.Module): self.norm = FastRMSNorm.load( prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps ) + synchronize(weights.device) + real_free_memory = get_free_memory(weights.device, 1) + log_master( + logger.debug, f"init model Free memory real: {real_free_memory / 1e9:.2f}GB" + ) def forward( self, diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index b3a843dc..fabe9f22 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1420,7 +1420,7 @@ class FlashCausalLM(Model): raise ValueError("Cannot get the number of key/value heads") self.num_kv_heads = ( num_kv_heads // self.process_group.size() - if num_kv_heads > 1 + if num_kv_heads // self.process_group.size() > 0 else num_kv_heads ) assert self.num_kv_heads > 0
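
The `repeat_kv` helper added to `hpu.py` expands grouped key/value heads so that the fused SDPA kernel sees matching head counts for query and key/value. Below is a minimal standalone sketch of the same expansion, assuming plain PyTorch tensors shaped `(batch, num_kv_heads, seq_len, head_dim)`; it is illustrative only and does not use the HPU kernel path.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq_len, head_dim) to
    (batch, num_kv_heads * n_rep, seq_len, head_dim); equivalent to
    torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


# Example: 8 query heads sharing 2 KV heads -> n_rep (num_key_value_groups) = 4
key = torch.randn(1, 2, 16, 64)
expanded = repeat_kv(key, n_rep=4)
assert expanded.shape == (1, 8, 16, 64)
```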
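
`Qwen3MoeMLP` now loads `gate_proj` and `up_proj` as a single column-parallel weight and splits the result at runtime. The sketch below shows only that runtime split, with plain `nn.Linear` layers standing in for the tensor-parallel loaders and a single consistently named activation attribute (`act_fn`); under tensor parallelism the per-rank `intermediate_size` would additionally be divided by the process-group size.

```python
import torch
from torch import nn


class FusedGateUpMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.intermediate_size = intermediate_size
        # gate_proj and up_proj concatenated along the output dimension
        # (gate rows first, matching the load_multi prefix order in the diff).
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (tokens, 2 * intermediate) -> (tokens, 2, intermediate)
        gate_up = self.gate_up_proj(x).view(-1, 2, self.intermediate_size)
        # SwiGLU: act(gate) * up, then project back down.
        return self.down_proj(self.act_fn(gate_up[:, 0]) * gate_up[:, 1])
```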
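
The re-enabled `Qwen3MoeSparseMoeBlock` routes each token to its top-k experts and scatters the weighted expert outputs back with `index_add_`. Below is a condensed sketch of that routing path, assuming a flat `(tokens, hidden)` input and a Python list of per-expert MLP modules; it keeps the `norm_topk_prob` renormalization from the diff but is not the repository's production kernel.

```python
import torch
import torch.nn.functional as F
from torch import nn


def moe_forward(hidden_states, gate, experts, top_k, norm_topk_prob=True):
    """hidden_states: (tokens, hidden); gate: nn.Linear(hidden, num_experts);
    experts: list of per-expert MLP modules returning (tokens, hidden)."""
    num_experts = len(experts)
    router_logits = gate(hidden_states)                     # (tokens, num_experts)
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    if norm_topk_prob:
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)

    out = torch.zeros_like(hidden_states)
    # (num_experts, top_k, tokens): which tokens picked which expert at which slot.
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
    for expert_idx in range(num_experts):
        idx, top_x = torch.where(expert_mask[expert_idx])
        if top_x.numel() == 0:
            continue  # no tokens routed to this expert
        current_state = hidden_states[top_x]
        weighted = experts[expert_idx](current_state) * routing_weights[top_x, idx, None]
        out.index_add_(0, top_x, weighted.to(hidden_states.dtype))
    return out
```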
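
The `flash_causal_lm.py` change and the per-module choice between `TensorParallelColumnLinear` and `FastLinear` both hinge on whether the model's KV heads divide across the tensor-parallel world size: shard when they do, replicate when they do not. A small sketch of that decision under the same assumption (the helper name is illustrative, not the repository's API):

```python
def kv_heads_per_rank(num_kv_heads: int, world_size: int) -> int:
    """Shard KV heads across tensor-parallel ranks when possible; otherwise
    replicate them on every rank (the case where num_kv_heads < world_size)."""
    if num_kv_heads // world_size > 0:
        return num_kv_heads // world_size
    return num_kv_heads


# Example: a config with 4 KV heads.
assert kv_heads_per_rank(4, world_size=2) == 2   # sharded, 2 KV heads per rank
assert kv_heads_per_rank(4, world_size=8) == 4   # replicated on every rank
```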