From a3967a57bc406de9ca636243f23d73741baf0ee0 Mon Sep 17 00:00:00 2001
From: yuanwu
Date: Thu, 8 May 2025 03:12:22 +0000
Subject: [PATCH] Fix experts issue

Signed-off-by: yuanwu
---
 .../custom_modeling/flash_llama4_modeling.py | 881 ++++++++++++++----
 .../custom_modeling/flash_llama_modeling.py  |  21 +-
 2 files changed, 706 insertions(+), 196 deletions(-)

diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index 2a74d7e5..de53350c 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union, Type

 import torch
 import math
@@ -70,15 +70,40 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE

 _CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B"
 _CONFIG_FOR_DOC = "Llama4Config"

-
+def print_0(*args, **kwargs):
+    """
+    Only print on rank 0 in distributed training.
+    Works like the built-in print() function but only executes on rank 0.
+    """
+    # Check whether we are running in a distributed environment
+    if torch.distributed.is_initialized():
+        # Get the current rank
+        if torch.distributed.get_rank() == 0:
+            print(*args, **kwargs)
+    else:
+        # Not a distributed environment, print normally
+        print(*args, **kwargs)
+
 def torch_save(tensor, name):
     # Only save on the main process (rank 0) when using distributed training
     if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
         torch.save(tensor, name)

+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    print_0(f"batch={batch}, num_key_value_heads={num_key_value_heads}, slen={slen}, head_dim={head_dim}")
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
 class Llama4TextExperts(nn.Module):
-    def __init__(self, prefix, config: Llama4TextConfig, weights):
+    def __init__(self, prefix, config: Llama4TextConfig, weights, layer_idx):
         super().__init__()
         self.process_group = weights.process_group
         self.num_experts = config.num_local_experts
@@ -102,10 +127,10 @@ class Llama4TextExperts(nn.Module):
             f"textExperts2 Free memory real: {real_free_memory / 1e9:.2f}GB"
         )

-
+        self.layer_idx = layer_idx
         self.act_fn = ACT2FN[config.hidden_act]

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, run_index) -> torch.Tensor:
         """
         This should really not be run on a single machine, as we are reaching compute bound:
         - the inputs are expected to be "sorted" per expert already.
@@ -119,24 +144,44 @@ class Llama4TextExperts(nn.Module): torch.Tensor """ gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2*self.expert_dim) + if run_index != -1: + torch_save(gate_up_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate_up_proj.pt") + + down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1) + if run_index != -1: + torch_save(down_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.down_proj.pt") + + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + if run_index != -1: + torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.hidden_states.pt") + + gate_up = torch.bmm(hidden_states, gate_up_proj) gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + if run_index != -1: + torch_save(gate, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate.pt") + torch_save(up, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.up.pt") + + next_states = torch.bmm((up * self.act_fn(gate)), down_proj) - + + + next_states = next_states.view(-1, self.hidden_size) + if run_index != -1: + torch_save(next_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.next_states.pt") + # Reduce sum if self.process_group.size() > 1: torch.distributed.all_reduce(next_states, group=self.process_group) - - next_states = next_states.view(-1, self.hidden_size) - + return next_states # Phi3MLP class Llama4TextMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, layer_idx): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -156,16 +201,43 @@ class Llama4TextMLP(nn.Module): weights=weights, bias=False, ) + self.layer_idx = layer_idx self.act_fn = ACT2FN[config.hidden_act] + # self.intermediate_size = ( + # config.intermediate_size // weights.process_group.size() + # ) - def forward(self, x): + # self.config = config + # # self.gate_up_proj = TensorParallelColumnLinear.load_multi( + # # config, + # # prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + # # weights=weights, + # # dim=0, + # # bias=False, + # # ) + # self.gate_proj = TensorParallelColumnLinear.load(config=config, prefix=f"{prefix}.gate_proj", weights=weights, bias=False) + # self.up_proj = TensorParallelColumnLinear.load(config=config, prefix=f"{prefix}.up_proj", weights=weights, bias=False) + # self.down_proj = TensorParallelRowLinear.load(config=config, prefix=f"{prefix}.down_proj", weights=weights, bias=False) + # self.activation_fn = ACT2FN[config.hidden_act] + + + + def forward(self, x, run_index, reduce=True): shape = x.shape + # gate_up_states = self.gate_up_proj(x) + # gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size) + # result = self.down_proj( + # self.activation_fn(gate_up_states[:, 0]) * gate_up_states[:, 1] + # ) + # return result + # down_proj = self.activation_fn(self.gate_proj(x)) * self.up_proj(x) + # return self.down_proj(down_proj) gate_up_states = self.gate_up_proj(x) - gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size) - result = self.down_proj( - self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1] + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=True ) - return result + @@ -205,12 +277,32 @@ class Llama4TextRMSNorm(nn.Module): class Llama4TextMoe(nn.Module): - def 
__init__(self, prefix, config, weights, layer_idx):
+    def __init__(
+        self,
+        prefix,
+        config,
+        weights,
+        layer_idx,
+        moe_layer_cls: Type[MoELayer],
+    ):
         super().__init__()
         self.top_k = config.num_experts_per_tok
         self.hidden_dim = config.hidden_size
         self.num_experts = config.num_local_experts
-
+        log_master(logger.debug, f"weights.load: {weights.loader}")
+        # self.experts = moe_layer_cls(
+        #     prefix=f"{prefix}.experts",
+        #     n_experts=config.num_local_experts,
+        #     n_expert_group=None,
+        #     renormalize=True,
+        #     topk=config.num_experts_per_tok,
+        #     topk_group=None,
+        #     weights=weights,
+        #     scoring_func="sigmoid",
+        # )
+        # assert isinstance(self.experts, MoELayer)
+
+        self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights, layer_idx=layer_idx)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
@@ -227,7 +319,7 @@ class Llama4TextMoe(nn.Module):
             logger.debug, f"TextMode2 Free memory real: {real_free_memory / 1e9:.2f}GB"
         )
-        self.shared_expert = LlamaMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights, index=layer_idx)
+        self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights, layer_idx=layer_idx)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
@@ -235,13 +327,15 @@ class Llama4TextMoe(nn.Module):
             f"TextMode3 Free memory real: {real_free_memory / 1e9:.2f}GB"
         )
         self.process_group = weights.process_group
+        self.layer_idx = layer_idx
-
-    def forward(self, hidden_states, adapter_data):
-        #seq_len, hidden_dim = hidden_states.shape
+    def forward(self, hidden_states, adapter_data, run_index):
+        seq_len, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         tokens_per_expert = hidden_states.shape[0]
         router_logits = self.router(hidden_states)
+        if run_index != -1:
+            torch_save(router_logits, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_logits.pt")
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
@@ -253,6 +347,9 @@ class Llama4TextMoe(nn.Module):
             torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
         )
         router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+        if run_index != -1:
+            torch_save(router_scores, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.router_scores.pt")
+
         router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
         routed_in = torch.gather(
@@ -260,23 +357,59 @@ class Llama4TextMoe(nn.Module):
             dim=0,
             index=router_indices,
         ).to(hidden_states.device)
+        if run_index != -1:
+            torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.gather.pt")
+
+        # we gather inputs corresponding to each expert based on the router indices
         routed_in = routed_in * router_scores.reshape(-1, 1)
-        routed_out = self.experts(routed_in)
-        out = self.shared_expert(hidden_states, adapter_data)
+        if run_index != -1:
+            torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_in.pt")
+        routed_out = self.experts(routed_in, run_index)
+        if run_index != -1:
+            torch_save(routed_out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_out.pt")
+        out = self.shared_expert(hidden_states, run_index, reduce=False)
+        if run_index != -1:
+            torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.out.pt")
         # now that we finished expert computation -> we scatter add because we
gathered previously # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound # this scales a lot better if you do EP! out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim)) + # if run_index != -1: + # torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.add.out.pt") + #Reduce sum + # if self.process_group.size() > 1: + # torch.distributed.all_reduce(out, group=self.process_group) + if run_index != -1: + torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.add.out.pt") + return out + + # shared_output = self.shared_expert(hidden_states, reduce=False) + + # router_logits = self.router(hidden_states) + + # out = self.experts(hidden_states, gating_output=router_logits) + + # if shared_output is not None: + # out = out + shared_output + + # # Reduce sum + # if self.process_group.size() > 1: + # torch.distributed.all_reduce(out, group=self.process_group) + + # return out.view(*hidden_states.shape) class Llama4TextRotaryEmbedding(nn.Module): - def __init__(self, config: 'Llama4TextConfig', device=None): + def __init__(self, config: Llama4TextConfig, device=None): super().__init__() + # BC: "rope_type" was originally "type" self.rope_type = "llama3" if config.rope_scaling is not None else "default" + self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings + self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -285,136 +418,207 @@ class Llama4TextRotaryEmbedding(nn.Module): self.original_inv_freq = self.inv_freq @torch.no_grad() - @dynamic_rope_update - def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor: - """ - Args: - x: Input tensor of shape [batch, seq_len, heads, dim] - position_ids: Position indices of shape [batch, seq_len] - Returns: - Rotary embeddings as float tensors [batch, seq_len, heads, dim] - """ - # Expand inv_freq and position_ids for broadcasting + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - - # Compute frequencies (replaces complex phase) - freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) # [batch, seq_len, dim//2] - - # Generate cos/sin components directly (replaces torch.polar) - cos_vals = torch.cos(freqs) * self.attention_scaling - sin_vals = torch.sin(freqs) * self.attention_scaling - - # Interleave cos/sin values to match original complex format - dim = x.size(-1) - if dim % 2 != 0: - raise ValueError(f"Feature dimension {dim} must be even for Rotary Embedding") - - # Stack and reshape to [batch, seq_len, dim] format - freqs_cis = torch.stack([cos_vals, sin_vals], dim=-1) # [batch, seq_len, dim//2, 2] - freqs_cis = freqs_cis.reshape(*freqs_cis.shape[:-2], dim) # [batch, seq_len, dim] - + origin_device = x.device + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" and x.device.type != "hpu" else "cpu" + inv_freq_expanded = inv_freq_expanded.to(device_type) + position_ids_expanded = position_ids_expanded.to(device_type) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation + freqs_cis = freqs_cis * self.attention_scaling return freqs_cis - + + def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, - freqs_cis: torch.Tensor, # Should be [cosθ, sinθ] instead of complex numbers + freqs_cis: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embedding to query and key tensors using floating-point operations only. 
- - Args: - xq: Query tensor of shape (batch, seq_len, n_heads, head_dim) - xk: Key tensor of shape (batch, seq_len, n_heads, head_dim) - freqs_cis: Precomputed rotation frequencies as [cosθ, sinθ] - of shape (batch, seq_len, head_dim//2, 2) - Returns: - Rotated query and key tensors with same shape as input - """ - # Verify head_dim is even - assert xq.size(-1) % 2 == 0, "Feature dimension must be even for rotary embedding" - - # Reshape to separate real and imaginary components (pairs of adjacent elements) - xq_reshaped = xq.float().reshape(*xq.shape[:-1], -1, 2) # [..., head_dim//2, 2] - xk_reshaped = xk.float().reshape(*xk.shape[:-1], -1, 2) # [..., head_dim//2, 2] - - # Extract cosθ and sinθ (assuming freqs_cis is already in [cosθ, sinθ] format) - cos_theta = freqs_cis[..., 0] # [batch, seq_len, head_dim//2] - sin_theta = freqs_cis[..., 1] # [batch, seq_len, head_dim//2] - - # Expand dimensions for broadcasting [batch, seq_len, n_heads, head_dim//2] - cos_theta = cos_theta.unsqueeze(2) # Add n_heads dimension - sin_theta = sin_theta.unsqueeze(2) - - # Rotary transformation (mathematically equivalent to complex multiplication) - # xq_rotated = [xq0*cosθ - xq1*sinθ, xq0*sinθ + xq1*cosθ] - xq_out = torch.stack([ - xq_reshaped[..., 0] * cos_theta - xq_reshaped[..., 1] * sin_theta, - xq_reshaped[..., 0] * sin_theta + xq_reshaped[..., 1] * cos_theta - ], dim=-1) - - xk_out = torch.stack([ - xk_reshaped[..., 0] * cos_theta - xk_reshaped[..., 1] * sin_theta, - xk_reshaped[..., 0] * sin_theta + xk_reshaped[..., 1] * cos_theta - ], dim=-1) - - # Restore original shape - xq_out = xq_out.flatten(-2) # [batch, seq_len, n_heads, head_dim] - xk_out = xk_out.flatten(-2) - - # Maintain original dtype + orig_device= xq.device + xq = xq.to("cpu") + xk = xk.to("cpu") + xq = xq.view(freqs_cis.shape[0], -1, *xq.shape[-2:]) + xk = xk.view(freqs_cis.shape[0], -1, *xk.shape[-2:]) + #log_master(logger.debug, f"xq: {xq.shape}, xk: {xk.shape}") + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + #log_master(logger.debug, f"xq_: {xq_.shape}, xk_: {xk_.shape}") + #log_master(logger.debug, f"freqs_cis: {freqs_cis.shape}") + xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3) + xq_out = xq_out.view(-1, *xq_out.shape[-2:]).to(orig_device) + xk_out = xk_out.view(-1, *xk_out.shape[-2:]).to(orig_device) + xq = xq.to(orig_device) + xk = xk.to(orig_device) return xq_out.type_as(xq), xk_out.type_as(xk) + +# class Llama4TextRotaryEmbedding(nn.Module): +# def __init__(self, config: 'Llama4TextConfig', device=None): +# super().__init__() +# self.rope_type = "llama3" if config.rope_scaling is not None else "default" +# self.max_seq_len_cached = config.max_position_embeddings +# self.original_max_seq_len = config.max_position_embeddings +# self.config = config +# self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + +# inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) +# self.register_buffer("inv_freq", inv_freq, persistent=False) +# self.original_inv_freq = self.inv_freq + +# @torch.no_grad() +# @dynamic_rope_update +# def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor: +# """ +# Args: +# x: Input tensor of shape [batch, seq_len, heads, dim] +# position_ids: Position indices of shape [batch, seq_len] +# Returns: +# Rotary embeddings as float tensors [batch, seq_len, heads, dim] +# """ +# 
# Expand inv_freq and position_ids for broadcasting +# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) +# position_ids_expanded = position_ids[:, None, :].float() + +# # Compute frequencies (replaces complex phase) +# freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) # [batch, seq_len, dim//2] + +# # Generate cos/sin components directly (replaces torch.polar) +# cos_vals = torch.cos(freqs) * self.attention_scaling +# sin_vals = torch.sin(freqs) * self.attention_scaling + +# # Interleave cos/sin values to match original complex format +# dim = x.size(-1) +# if dim % 2 != 0: +# raise ValueError(f"Feature dimension {dim} must be even for Rotary Embedding") + +# # Stack and reshape to [batch, seq_len, dim] format +# freqs_cis = torch.stack([cos_vals, sin_vals], dim=-1) # [batch, seq_len, dim//2, 2] +# freqs_cis = freqs_cis.reshape(*freqs_cis.shape[:-2], dim) # [batch, seq_len, dim] + +# return freqs_cis + +# def apply_rotary_emb( +# xq: torch.Tensor, +# xk: torch.Tensor, +# freqs_cis: torch.Tensor, # Should be [cosθ, sinθ] instead of complex numbers +# ) -> Tuple[torch.Tensor, torch.Tensor]: +# """ +# Apply rotary position embedding to query and key tensors using floating-point operations only. + +# Args: +# xq: Query tensor of shape (batch, seq_len, n_heads, head_dim) +# xk: Key tensor of shape (batch, seq_len, n_heads, head_dim) +# freqs_cis: Precomputed rotation frequencies as [cosθ, sinθ] +# of shape (batch, seq_len, head_dim//2, 2) +# Returns: +# Rotated query and key tensors with same shape as input +# """ +# # Verify head_dim is even +# assert xq.size(-1) % 2 == 0, "Feature dimension must be even for rotary embedding" + +# # Reshape to separate real and imaginary components (pairs of adjacent elements) +# xq_reshaped = xq.float().reshape(*xq.shape[:-1], -1, 2) # [..., head_dim//2, 2] +# xk_reshaped = xk.float().reshape(*xk.shape[:-1], -1, 2) # [..., head_dim//2, 2] + +# # Extract cosθ and sinθ (assuming freqs_cis is already in [cosθ, sinθ] format) +# cos_theta = freqs_cis[..., 0] # [batch, seq_len, head_dim//2] +# sin_theta = freqs_cis[..., 1] # [batch, seq_len, head_dim//2] + +# # Expand dimensions for broadcasting [batch, seq_len, n_heads, head_dim//2] +# cos_theta = cos_theta.unsqueeze(2) # Add n_heads dimension +# sin_theta = sin_theta.unsqueeze(2) + +# # Rotary transformation (mathematically equivalent to complex multiplication) +# # xq_rotated = [xq0*cosθ - xq1*sinθ, xq0*sinθ + xq1*cosθ] +# xq_out = torch.stack([ +# xq_reshaped[..., 0] * cos_theta - xq_reshaped[..., 1] * sin_theta, +# xq_reshaped[..., 0] * sin_theta + xq_reshaped[..., 1] * cos_theta +# ], dim=-1) + +# xk_out = torch.stack([ +# xk_reshaped[..., 0] * cos_theta - xk_reshaped[..., 1] * sin_theta, +# xk_reshaped[..., 0] * sin_theta + xk_reshaped[..., 1] * cos_theta +# ], dim=-1) + +# # Restore original shape +# xq_out = xq_out.flatten(-2) # [batch, seq_len, n_heads, head_dim] +# xk_out = xk_out.flatten(-2) + +# # Maintain original dtype +# return xq_out.type_as(xq), xk_out.type_as(xk) + class Llama4TextAttention(FlashLlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, prefix, config, weights, layer_idx): super().__init__(layer_idx, prefix, config, weights) self.config = config - # self.layer_idx = layer_idx - #self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - # self.num_attention_heads = config.num_attention_heads - # self.num_key_value_groups = 
config.num_attention_heads // config.num_key_value_heads - # self.num_key_value_heads = config.num_key_value_heads - # self.scaling = self.head_dim**-0.5 - # self.attn_scale = config.attn_scale - # self.floor_scale = config.floor_scale - # self.attn_temperature_tuning = config.attn_temperature_tuning - # self.attention_dropout = config.attention_dropout - # self.is_causal = True + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_attention_heads = config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attn_scale = config.attn_scale + self.floor_scale = config.floor_scale + self.attn_temperature_tuning = config.attn_temperature_tuning + self.attention_dropout = config.attention_dropout + self.is_causal = True self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers - # # `config.attention_multiplier` is used in Granite - # self.softmax_scale = getattr( - # config, "attention_multiplier", self.head_dim**-0.5 - # ) + # `config.attention_multiplier` is used in Granite + self.softmax_scale = getattr( + config, "attention_multiplier", self.head_dim**-0.5 + ) - # if self.num_attention_heads % weights.process_group.size() != 0: - # raise ValueError( - # f"`num_attention_heads` must be divisible by `num_shards` (got `num_attention_heads`: {self.num_attention_heads} " - # f"and `num_shards`: {weights.process_group.size()}" - # ) - # if config.num_key_value_heads % weights.process_group.size() != 0: - # raise ValueError( - # f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} " - # f"and `num_shards`: {weights.process_group.size()}" - # ) - # self.num_heads = self.num_attention_heads // weights.process_group.size() - # self.num_key_value_heads = ( - # config.num_key_value_heads // weights.process_group.size() - # ) + if self.num_attention_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_attention_heads` must be divisible by `num_shards` (got `num_attention_heads`: {self.num_attention_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + if config.num_key_value_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_attention_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) - # self.query_key_value = load_attention(config, prefix, weights, layer_idx) + #self.query_key_value = load_attention(config, prefix, weights, layer_idx) - # self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.q_proj = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=getattr(config, "attention_bias", False), + ) + self.k_proj = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.k_proj", + weights=weights, + bias=getattr(config, "attention_bias", False), + ) + self.v_proj = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.v_proj", + weights=weights, + bias=getattr(config, "attention_bias", False), + ) - # o_proj = 
TensorParallelRowLinear.load( - # config, - # prefix=f"{prefix}.o_proj", - # weights=weights, - # bias=getattr(config, "attention_bias", False), - # ) + self.o_proj = TensorParallelRowLinear.load( + config=config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=getattr(config, "attention_bias", False), + ) # self.o_proj = TensorParallelAdapterRowLinear.load( # o_proj, @@ -423,10 +627,10 @@ class Llama4TextAttention(FlashLlamaAttention): # process_group=weights.process_group, # ) - # self.num_groups = self.num_heads // self.num_key_value_heads - # self.kv_head_mapping = torch.arange( - # 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device - # ).repeat_interleave(self.num_groups) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) # self.q_proj = nn.Linear( @@ -447,19 +651,21 @@ class Llama4TextAttention(FlashLlamaAttention): def forward( self, hidden_states: torch.Tensor, - cos, - sin, + freqs_ci, cu_seqlen_prefill, kv_cache: KVCache, slots, seqlen, adapter_data, run_index, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bs = seqlen.input_lengths.shape[0] input_shape = hidden_states.shape[:-1] - #hidden_shape = (*input_shape, -1, self.head_dim) - qkv = self.query_key_value(hidden_states, adapter_data) + hidden_shape = (*input_shape, -1, self.head_dim) + #qkv = self.query_key_value(hidden_states, adapter_data) # query_states, kv_states = qkv.split( # [ # self.head_size * self.num_heads, @@ -467,31 +673,35 @@ class Llama4TextAttention(FlashLlamaAttention): # ], # dim=-1, # ) - query_states, key_states, value_states = qkv.split( - [ - self.head_size * self.num_heads, - self.head_size * self.num_key_value_heads, - self.head_size * self.num_key_value_heads, - ], - dim=-1, - ) + # query_states, key_states, value_states = qkv.split( + # [ + # self.head_size * self.num_heads, + # self.head_size * self.num_key_value_heads, + # self.head_size * self.num_key_value_heads, + # ], + # dim=-1, + # ) + query_states = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim) - query_states = query_states.view(-1, self.num_heads, self.head_size) - key_states = key_states.view(-1, self.num_key_value_heads, self.head_size) - value_states = value_states.view(-1, self.num_key_value_heads, self.head_size) + # query_states = query_states.view(-1, self.num_heads, self.head_size) + # key_states = key_states.view(-1, self.num_key_value_heads, self.head_size) + # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size) if run_index != -1: torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt") torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt") torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt") - # query_states = self.q_proj(hidden_states).view(hidden_shape) - # key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim) - # value_states = 
self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if self.use_rope: # the 16E model skips rope for long context on certain layers #self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin) - self.rotary_emb(query_states, key_states, cos, sin) + #self.rotary_emb(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_emb( + query_states, key_states, freqs_ci + ) + + if run_index != -1: torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt") @@ -507,6 +717,9 @@ class Llama4TextAttention(FlashLlamaAttention): torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt") + # query_states = query_states.view(-1, self.num_heads, self.head_size) + # key_states = key_states.view(-1, self.num_key_value_heads, self.head_size) + # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size) # query_states = query_states.transpose(1, 2) # key_states = key_states.transpose(1, 2) @@ -516,29 +729,72 @@ class Llama4TextAttention(FlashLlamaAttention): slots=slots, kv_scales=self.kv_scales, ) + # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers + if self.attn_temperature_tuning and not self.use_rope: + #indice = torch.tensor([0]).to(query_states.device) + #cache_position = position_ids + #log_master(logger.debug, f"cache_position: {cache_position.shape}") + + + attn_scales = ( + torch.log(torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0 + ) + #seq_len = input_shape / bs + attn_scales = attn_scales.view(*input_shape, 1, 1) + query_states = (query_states * attn_scales).to(query_states.dtype) + if run_index != -1: + torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attn_scales.query_states.pt") + torch_save(attention_mask, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_mask.pt") + # Prefill if cu_seqlen_prefill is not None: - log_master( - logger.debug, - f"Prefill: {cu_seqlen_prefill} {seqlen} {slots} {self.kv_head_mapping}" - ) # sdpa - attn_output = attention( - query=query_states, - key=key_states, - value=value_states, - kv_scales=self.kv_scales, - kv_cache=kv_cache, - seqlen=seqlen, - softmax_scale=self.softmax_scale, + # log_master(logger.debug, f"self.softmax_scale: {self.softmax_scale}") + # attn_output = attention( + # query=query_states, + # key=key_states, + # value=value_states, + # kv_scales=self.kv_scales, + # kv_cache=kv_cache, + # seqlen=seqlen, + # softmax_scale=self.softmax_scale, + # causal=True + # ) + query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(1, 2) + key = key_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value = value_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + print_0(f"self.num_key_value_groups={self.num_key_value_groups}") + key = repeat_kv(key, self.num_key_value_groups) + value = repeat_kv(value, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None and causal_mask.ndim == 4: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + is_causal = query.shape[2] > 1 and causal_mask is None + # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
+ query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + print_0(f"query.shape={query.shape}, query={query}") + print_0(f"key.shape={key.shape}, key={key}") + print_0(f"value.shape={value.shape}, value={value}") + print_0(f"attention_mask.shape={causal_mask.shape}, attention_mask={causal_mask}") + print_0(f"scaling={self.scaling}, is_causal={is_causal}") + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=0, + scale=self.scaling, + is_causal=is_causal, ) + attn_output = attn_output.transpose(1, 2).contiguous() # Decode else: - log_master( - logger.debug, - f"Decode: {cu_seqlen_prefill} {seqlen} {slots} {self.kv_head_mapping}" - ) attn_output = paged_attention( query_states, kv_cache, @@ -549,9 +805,12 @@ class Llama4TextAttention(FlashLlamaAttention): hpu_attention_meta=hpu_attention_meta, ) - return self.o_proj( - attn_output.view(-1, self.num_heads * self.head_size), adapter_data - ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + if run_index != -1: + torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.reshape.attn_output.pt") + attn_output = self.o_proj(attn_output) + return attn_output + class Llama4TextDecoderLayer(nn.Module): def __init__(self, prefix, config, weights, layer_idx): @@ -569,8 +828,16 @@ class Llama4TextDecoderLayer(nn.Module): self.use_chunked_attention = int((layer_idx + 1) % 4 != 0) # <=> use rope self.is_moe_layer = layer_idx in config.moe_layers + log_master(logger.debug, f"self.is_moe_layer: {self.is_moe_layer}, layer_idx:{layer_idx}") + log_master(logger.debug, f"moe_layers:{config.moe_layers}") if self.is_moe_layer: # the 128E model interleaves dense / sparse - self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights, layer_idx) + moe_layer_cls = ( + SparseMoELayer + if SparseMoELayer.is_supported(weights) + else DenseMoELayer + ) + + self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights, layer_idx, moe_layer_cls) else: self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights, layer_idx) @@ -593,16 +860,17 @@ class Llama4TextDecoderLayer(nn.Module): def forward( self, hidden_states, - residual, - cos, - sin, + freqs_ci, cu_seqlen_prefill, kv_cache, slots, seqlen, adapter_data, - hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - run_index + attention_mask: Optional[torch.Tensor] = None, + chunk_causal_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, + run_index: int = 0, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states if run_index != -1: @@ -611,16 +879,21 @@ class Llama4TextDecoderLayer(nn.Module): if run_index != -1: torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt") + # use local attention mask for ROPE layers + if self.use_chunked_attention and chunk_causal_mask is not None: + attention_mask = chunk_causal_mask + attention_states = self.self_attn( hidden_states, - cos, - sin, + freqs_ci, cu_seqlen_prefill, kv_cache, slots, seqlen, adapter_data, run_index, + attention_mask=attention_mask, + position_ids=position_ids, hpu_attention_meta=hpu_attention_meta, ) if run_index != -1: @@ -635,7 +908,7 @@ class Llama4TextDecoderLayer(nn.Module): hidden_states = self.post_attention_layernorm(hidden_states) if 
run_index != -1: torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt") - hidden_states = self.feed_forward(hidden_states, adapter_data) + hidden_states = self.feed_forward(hidden_states, adapter_data, run_index) if run_index != -1: torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt") hidden_states = residual + hidden_states.view(residual.shape) @@ -698,7 +971,7 @@ class Llama4TextModel(nn.Module): ) self.run_index = -1 - #self.rotary_emb = Llama4TextRotaryEmbedding(config=config) + self.rotary_emb = Llama4TextRotaryEmbedding(config=config) self.gradient_checkpointing = False def forward( @@ -711,6 +984,7 @@ class Llama4TextModel(nn.Module): seqlen: Seqlen, adapter_data, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -719,20 +993,39 @@ class Llama4TextModel(nn.Module): log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}") # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) + #cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) + log_master(logger.debug, f"position_ids.shape={position_ids.shape}, position_ids={position_ids}") + bs = seqlen.input_lengths.shape[0] + seq_len = inputs_embeds.shape[0] / bs + cache_position = torch.arange(0, seq_len, device=inputs_embeds.device) - residual = None + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + log_master(logger.debug, f"cache_position={cache_position}") + log_master(logger.debug, f"position_ids={position_ids}") + causal_mask, chunk_causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds.view(bs, int(seq_len), -1), cache_position, None, output_attentions=False, use_cache=False + ) + log_master(logger.debug, f"causal_mask={causal_mask}") + log_master(logger.debug, f"causal_mask={causal_mask.shape}") + log_master(logger.debug, f"chunk_causal_mask={chunk_causal_mask}") + + + + + freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1)) for i, layer in enumerate(self.layers): hidden_states = layer( hidden_states, - residual, - cos, - sin, + freqs_ci, cu_seqlen_prefill, kv_cache[i], slots, seqlen, adapter_data, + attention_mask=causal_mask, + chunk_causal_mask=chunk_causal_mask, + position_ids=position_ids, hpu_attention_meta=hpu_attention_meta, run_index=self.run_index, ) @@ -747,6 +1040,198 @@ class Llama4TextModel(nn.Module): self.run_index += 1 return hidden_states + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + chunked_attention_mask=None, + use_cache=True, + ): + print(f"update 11111111111111111") + print(f"self.config._attn_implementation={self.config._attn_implementation}") + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask, attention_mask # flash does not support chunked attn TODO support flash + return None, None + + if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]: + return None, None + + print(f"update 222222222222222222") + sequence_length = input_tensor.shape[1] + attention_chunk_size = self.config.attention_chunk_size + + 
first_cache_position = cache_position[0] + + if past_key_values is not None: + full_cache_length = past_key_values.get_max_cache_shape() or sequence_length + else: + full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length + + cond1 = first_cache_position >= attention_chunk_size + cond2 = (first_cache_position < attention_chunk_size) & ( + first_cache_position + sequence_length > attention_chunk_size + ) + key_length = ( + torch.where( + cond1, + attention_chunk_size + sequence_length - 1, + torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size), + ) + if use_cache + else full_cache_length + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + dtype, device = input_tensor.dtype, input_tensor.device + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=max(full_cache_length, attention_chunk_size), + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + device=device + ) + if full_cache_length > self.config.attention_chunk_size: + start_idx = max(first_cache_position - attention_chunk_size + 1, 0) + end_idx = start_idx + key_length + chunked_attention_mask = self.create_chunked_attention_mask( + self.config.attention_chunk_size, + start=start_idx, # same offset as with flex + end=end_idx, + device=device, + ) + + local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well + # It may be smaller than attention_chunk_size -> pad it + requires_padding = local_attention_mask.shape[-1] < attention_chunk_size + if requires_padding: + local_attention_mask = nn.functional.pad( + local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1]) + ) + # Depending on the padding, take the query tokens from the end or the cache_position + if not requires_padding: + chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :] + else: + chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :] + + chunked_attention_mask = chunked_attention_mask.expand(input_tensor.shape[0], -1, -1, -1) + chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :] + if self.config._attn_implementation == "eager": + min_dtype = torch.finfo(dtype).min + chunked_attention_mask = torch.where(chunked_attention_mask == 0, min_dtype, 0.0).to(dtype) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and attention_mask.ndim == 4 + and not output_attentions # Only unmask for 4d masks + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None: + chunked_attention_mask = chunked_attention_mask.bool() + causal_mask = causal_mask.bool() + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=first_cache_position, + is_training=self.training, + ): + causal_mask = None + return causal_mask, chunked_attention_mask + + def create_chunked_attention_mask( + self, attention_chunk_size: int, start: int, end: int, device: torch.device + ) -> torch.Tensor: + """ + Generate the following: + + 'What' : 0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ | + '▁is' : 1 ■ ■ ⬚ ⬚ ⬚ ⬚ | + '▁ch' : 2 ■ ■ ■ ⬚ ⬚ ⬚ | + 'unked' : 3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ | + '▁attention': 4 ⬚ ⬚ ⬚ ■ ■ ⬚ | + '?' : 5 ⬚ ⬚ ⬚ ■ ■ ■ | + + If the chunk size is 3. + This can just be applied over the already created attention mask + """ + arange_vector = torch.arange(start, end, device=device) + block_pos = torch.abs( + arange_vector.unsqueeze(0) // attention_chunk_size - arange_vector.unsqueeze(1) // attention_chunk_size + ) + token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1) + mask = (block_pos == 0) & (token_pos <= 0) + return mask.to(device) + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.to(device).reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(device)
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+
 class Llama4ForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
@@ -772,6 +1257,7 @@ class Llama4ForCausalLM(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         adapter_data: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -784,6 +1270,7 @@ class Llama4ForCausalLM(nn.Module):
             seqlen,
             adapter_data=adapter_data,
             hpu_attention_meta=hpu_attention_meta,
+            attention_mask=attention_mask,
         )
         print(f"lm_head_indices={lm_head_indices}")
         if lm_head_indices is not None:
@@ -1457,7 +1944,15 @@ class Llama4ForConditionalGeneration(nn.Module):
         log_master(
             logger.debug,
             f"input_ids: {input_ids}, shape = {input_ids.shape}, input_ids={input_ids[-20:]}"
-        )
+        )
+
+        def _get_padding_mask(input_ids, pad_token_id=0):
+            return (input_ids != pad_token_id).long()  # 1 for non-padding positions, 0 for padding positions
+
+        # Example
+        attention_mask = _get_padding_mask(input_ids)
+        attention_mask = attention_mask.view(seqlen.input_lengths.shape[0], -1)
+        log_master(logger.debug, f"attention_mask={attention_mask}")
         inputs_embeds = self.text_model.model.embed_tokens(input_ids)
         vision_feature_layer = (
             vision_feature_layer
@@ -1478,7 +1973,6 @@ class Llama4ForConditionalGeneration(nn.Module):
         #     "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
         # )
         if pixel_values is not None:
-            print(f"pixel_values!!!!!!!!!!!!!!!!!")
             image_features = self.get_image_features(
                 pixel_values=pixel_values,
                 vision_feature_layer=vision_feature_layer,
@@ -1517,6 +2011,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             hpu_attention_meta,
             adapter_data,
             lm_head_indices,
+            attention_mask
         )
         return logits, speculative_logits
\ No newline at end of file
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 1b7e1052..7a6e561f 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -61,6 +61,11 @@ from text_generation_server.utils.weights import (
 )
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

+def torch_save(tensor, name):
+    # Only save on the main process (rank 0) when using distributed training
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        torch.save(tensor, name)
+
 def
load_attention(config, prefix: str, weights, layer_id): # Only defined in granite. @@ -377,7 +382,7 @@ class LlamaMLP(nn.Module): class FlashLlamaLayer(nn.Module): def __init__(self, index, prefix, config, weights): super().__init__() - + self.index = index with no_fp8(weights): self.self_attn = FlashLlamaAttention( index=index, @@ -438,6 +443,7 @@ class FlashLlamaLayer(nn.Module): seqlen, adapter_data, cross_attention_states, + run_index, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -454,6 +460,10 @@ class FlashLlamaLayer(nn.Module): adapter_data, hpu_attention_meta=hpu_attention_meta, ) + + if run_index != -1: + torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_states.pt") + if self.residual_multiplier is not None: attn_output *= self.residual_multiplier @@ -462,6 +472,10 @@ class FlashLlamaLayer(nn.Module): ) mlp_output = self.mlp(normed_attn_res_output, adapter_data) + if run_index != -1: + torch_save(mlp_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.mlp.pt") + + if self.residual_multiplier is not None: mlp_output *= self.residual_multiplier @@ -471,7 +485,7 @@ class FlashLlamaLayer(nn.Module): class FlashLlamaModel(torch.nn.Module): def __init__(self, prefix, config, weights): super().__init__() - + self.run_index = -1 process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() @@ -568,11 +582,12 @@ class FlashLlamaModel(torch.nn.Module): seqlen, adapter_data, cross_attention_states, + self.run_index, hpu_attention_meta=hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) - + self.run_index += 1 return hidden_states
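
Reviewer note, not part of the patch: below is a minimal standalone sketch of the block-diagonal causal pattern that create_chunked_attention_mask builds in the diff above. The toy sizes (6 query tokens, chunk size 3) are chosen only so the output matches the diagram in that method's docstring; the helper name chunked_attention_mask is local to this sketch.

# Sketch only: same block_pos / token_pos construction as create_chunked_attention_mask.
import torch

def chunked_attention_mask(attention_chunk_size: int, start: int, end: int) -> torch.Tensor:
    # A query may attend to a key only if both fall in the same chunk (block_pos == 0)
    # and the key is at the same or an earlier position (token_pos <= 0).
    arange_vector = torch.arange(start, end)
    block_pos = torch.abs(
        arange_vector.unsqueeze(0) // attention_chunk_size
        - arange_vector.unsqueeze(1) // attention_chunk_size
    )
    token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1)
    return (block_pos == 0) & (token_pos <= 0)

if __name__ == "__main__":
    mask = chunked_attention_mask(attention_chunk_size=3, start=0, end=6)
    print(mask.int())
    # Rows are queries, columns are keys, 1 = may attend:
    # 1 0 0 0 0 0
    # 1 1 0 0 0 0
    # 1 1 1 0 0 0
    # 0 0 0 1 0 0
    # 0 0 0 1 1 0
    # 0 0 0 1 1 1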
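
Reviewer note, not part of the patch: a toy sketch of the router gather -> per-expert bmm -> scatter_add_ flow that Llama4TextMoe.forward relies on. The dimensions, the single square expert_weights matrix per expert, and the omission of the shared expert are illustrative assumptions; they do not mirror the real gate_up_proj/down_proj projections or tensor-parallel reduction in the patch.

# Sketch only: every expert sees every token; non-selected tokens are driven to ~0 by sigmoid(-inf) scores.
import torch

tokens, hidden_dim, num_experts, top_k = 4, 8, 3, 1
hidden_states = torch.randn(tokens, hidden_dim)
router_logits = torch.randn(tokens, num_experts)

top_value, top_index = torch.topk(router_logits, top_k, dim=1)
# (num_experts, tokens): -inf everywhere except the selected expert slot for each token
scores = torch.full_like(router_logits, float("-inf")).scatter_(1, top_index, top_value).transpose(0, 1)
scores = torch.sigmoid(scores)  # masked entries become ~0 after the sigmoid

token_index = torch.arange(tokens).view(1, -1).expand(num_experts, -1)
gather_index = token_index.reshape(-1, 1).expand(-1, hidden_dim)
# Duplicate every token once per expert, scaled by its routing score
routed_in = torch.gather(hidden_states, dim=0, index=gather_index) * scores.reshape(-1, 1)

expert_weights = torch.randn(num_experts, hidden_dim, hidden_dim)  # stand-in for the per-expert MLP
routed_out = torch.bmm(routed_in.view(num_experts, tokens, hidden_dim), expert_weights)

# Scatter the per-expert outputs back onto the token axis
out = torch.zeros_like(hidden_states)
out.scatter_add_(dim=0, index=gather_index, src=routed_out.reshape(-1, hidden_dim))
print(out.shape)  # (tokens, hidden_dim)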