From f0dac1dec8545b3aa4add2b6e063835feda54841 Mon Sep 17 00:00:00 2001 From: yuanwu Date: Sun, 11 May 2025 16:44:53 +0000 Subject: [PATCH] Clean the code Signed-off-by: yuanwu --- .../custom_modeling/flash_llama4_modeling.py | 783 ++---------------- .../models/flash_vlm_causal_lm.py | 3 - 2 files changed, 74 insertions(+), 712 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py index 4ac2ec5d..de8b8955 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -22,13 +22,11 @@ from torch import nn import torch.nn.functional as F from transformers import Llama4TextConfig -from transformers.cache_utils import Cache, HybridChunkedCache +from transformers.cache_utils import Cache from transformers.activations import ACT2FN from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_outputs import ( BaseModelOutput, - BaseModelOutputWithPast, - ModelOutput, ) import habana_frameworks.torch as htorch @@ -107,45 +105,31 @@ def apply_rotary_emb( key: torch.Tensor, freqs_ci: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - # Adjust the dimensions of cos and sin to match for broadcasting - print_0(f"freqs_ci: {freqs_ci.shape}") - print_0(f"query: {query.shape}, key: {key.shape}") query_shape = query.shape key_shape = key.shape cos_emb,sin_emb = freqs_ci.split(1, dim=-1) - print_0(f"cos_emb: {cos_emb.shape}, sin_emb: {sin_emb.shape}") - # Split the last dimension of query and key into 2D vectors + if len(query.shape) == 3: - #query = query.view(freqs_ci.shape[0], -1, *query.shape[-2:]) query = query.unsqueeze(0) key = key.unsqueeze(0) - #key = key.view(freqs_ci.shape[0], -1, *key.shape[-2:]) + query_reshaped = query.float().reshape(*query.shape[:-1], -1, 2) - print_0(f"query_reshaped: {query_reshaped.shape}") key_reshaped = key.float().reshape(*key.shape[:-1], -1, 2) - print_0(f"key_reshaped: {key_reshaped.shape}") q_shape = query_reshaped.shape[:-1] - print_0(f"q_shape: {q_shape}") cos_emb = reshape_for_broadcast(cos_emb, q_shape) sin_emb = reshape_for_broadcast(sin_emb, q_shape) - print_0(f"cos_emb: {cos_emb.shape}, sin_emb: {sin_emb.shape}") - # Separate the x and y components x_q, y_q = query_reshaped.unbind(-1) - print_0(f"x_q: {x_q.shape}, y_q: {y_q.shape}") x_k, y_k = key_reshaped.unbind(-1) - print_0(f"x_k: {x_k.shape}, y_k: {y_k.shape}") - # Apply the rotation matrix + x_q_rot = x_q * cos_emb - y_q * sin_emb y_q_rot = x_q * sin_emb + y_q * cos_emb x_k_rot = x_k * cos_emb - y_k * sin_emb y_k_rot = x_k * sin_emb + y_k * cos_emb - # Merge the results and restore the original shape query_out = torch.stack([x_q_rot, y_q_rot], dim=-1).flatten(-2) key_out = torch.stack([x_k_rot, y_k_rot], dim=-1).flatten(-2) query_out = query_out.view(*query_shape) key_out = key_out.view(*key_shape) - print_0(f"query_out: {query_out.shape}, key_out: {key_out.shape}") return query_out.type_as(query), key_out.type_as(key) @@ -162,7 +146,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class Llama4TextExperts(nn.Module): - def __init__(self, prefix, config: Llama4TextConfig, weights, layer_idx): + def __init__(self, prefix, config: Llama4TextConfig, weights): super().__init__() self.process_group = weights.process_group self.num_experts = config.num_local_experts @@ -170,26 +154,10 @@ class Llama4TextExperts(nn.Module): self.hidden_size = config.hidden_size self.expert_dim = 
self.intermediate_size self.gate_up_proj = nn.Parameter(weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), requires_grad=False) - # synchronize(weights.device) - # real_free_memory = get_free_memory(weights.device, 1) - # log_master( - # logger.debug, - # f"textExperts1 Free memory real: {real_free_memory / 1e9:.2f}GB" - # ) - - self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False) - # synchronize(weights.device) - # real_free_memory = get_free_memory(weights.device, 1) - # log_master( - # logger.debug, - # f"textExperts2 Free memory real: {real_free_memory / 1e9:.2f}GB" - # ) - - self.layer_idx = layer_idx self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_states: torch.Tensor, run_index) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ This should really not be run on a single machine, as we are reaching compute bound: - the inputs are expected to be "sorted" per expert already. @@ -203,33 +171,12 @@ class Llama4TextExperts(nn.Module): torch.Tensor """ gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2*self.expert_dim) - # if run_index == 0: - # torch_save(gate_up_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate_up_proj.pt") - - down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1) - # if run_index == 0: - # torch_save(down_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.down_proj.pt") - - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - # if run_index == 0: - # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.hidden_states.pt") - - gate_up = torch.bmm(hidden_states, gate_up_proj) gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors - # if run_index == 0: - # torch_save(gate, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate.pt") - # torch_save(up, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.up.pt") - - next_states = torch.bmm((up * self.act_fn(gate)), down_proj) - - next_states = next_states.view(-1, self.hidden_size) - # if run_index == 0: - # torch_save(next_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.next_states.pt") # Reduce sum if self.process_group.size() > 1: @@ -240,7 +187,7 @@ class Llama4TextExperts(nn.Module): # Phi3MLP class Llama4TextMLP(nn.Module): - def __init__(self, prefix, config, weights, layer_idx): + def __init__(self, prefix, config, weights): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -260,41 +207,14 @@ class Llama4TextMLP(nn.Module): weights=weights, bias=False, ) - self.layer_idx = layer_idx + self.act_fn = ACT2FN[config.hidden_act] - # self.intermediate_size = ( - # config.intermediate_size // weights.process_group.size() - # ) - # self.config = config - # # self.gate_up_proj = TensorParallelColumnLinear.load_multi( - # # config, - # # prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], - # # weights=weights, - # # dim=0, - # # bias=False, - # # ) - # self.gate_proj = TensorParallelColumnLinear.load(config=config, prefix=f"{prefix}.gate_proj", weights=weights, bias=False) - # self.up_proj = TensorParallelColumnLinear.load(config=config, prefix=f"{prefix}.up_proj", weights=weights, bias=False) - # self.down_proj = TensorParallelRowLinear.load(config=config, prefix=f"{prefix}.down_proj", weights=weights, bias=False) - # self.activation_fn = 
ACT2FN[config.hidden_act] - - - - def forward(self, x, run_index, reduce=True): - shape = x.shape - # gate_up_states = self.gate_up_proj(x) - # gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size) - # result = self.down_proj( - # self.activation_fn(gate_up_states[:, 0]) * gate_up_states[:, 1] - # ) - # return result - # down_proj = self.activation_fn(self.gate_proj(x)) * self.up_proj(x) - # return self.down_proj(down_proj) + def forward(self, x): gate_up_states = self.gate_up_proj(x) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) return self.down_proj( - self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=True + self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1] ) @@ -341,59 +261,21 @@ class Llama4TextMoe(nn.Module): prefix, config, weights, - layer_idx, - moe_layer_cls: Type[MoELayer], ): super().__init__() self.top_k = config.num_experts_per_tok self.hidden_dim = config.hidden_size self.num_experts = config.num_local_experts - # self.experts = moe_layer_cls( - # prefix=f"{prefix}.experts", - # n_experts=config.num_local_experts, - # n_expert_group=None, - # renormalize=True, - # topk=config.num_experts_per_tok, - # topk_group=None, - # weights=weights, - # scoring_func="sigmoid", - # ) - # assert isinstance(self.experts, MoELayer) - - - self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights, layer_idx=layer_idx) - # synchronize(weights.device) - # real_free_memory = get_free_memory(weights.device, 1) - # log_master( - # logger.debug, - # f"TextMode1 Free memory real: {real_free_memory / 1e9:.2f}GB" - # ) - - + self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights) self.router = FastLinear.load(config=config, prefix=f"{prefix}.router", weights=weights, bias=False) - # synchronize(weights.device) - # real_free_memory = get_free_memory(weights.device, 1) - # log_master( - # logger.debug, - # f"TextMode2 Free memory real: {real_free_memory / 1e9:.2f}GB" - # ) - self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights, layer_idx=layer_idx) - # synchronize(weights.device) - # real_free_memory = get_free_memory(weights.device, 1) - # log_master( - # logger.debug, - # f"TextMode3 Free memory real: {real_free_memory / 1e9:.2f}GB" - # ) + self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights) self.process_group = weights.process_group - self.layer_idx = layer_idx - def forward(self, hidden_states, adapter_data, run_index): + def forward(self, hidden_states, adapter_data): seq_len, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_dim) tokens_per_expert = hidden_states.shape[0] router_logits = self.router(hidden_states) - #if run_index != -1: - # torch_save(router_logits, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_logits.pt") router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) router_scores = ( @@ -405,8 +287,6 @@ class Llama4TextMoe(nn.Module): torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1) ) router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) - #if run_index != -1: - # torch_save(router_scores, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.router_scores.pt") router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim) @@ -415,49 +295,20 @@ class 
Llama4TextMoe(nn.Module): dim=0, index=router_indices, ).to(hidden_states.device) - #if run_index != -1: - # torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.gather.pt") # we gather inputs corresponding to each expert based on the router indices routed_in = routed_in * router_scores.reshape(-1, 1) - #if run_index != -1: - # torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_in.pt") - routed_out = self.experts(routed_in, run_index) - #if run_index != -1: - # torch_save(routed_out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_out.pt") - out = self.shared_expert(hidden_states, run_index, reduce=False) - #if run_index != -1: - # torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.out.pt") + routed_out = self.experts(routed_in) + out = self.shared_expert(hidden_states) + # now that we finished expert computation -> we scatter add because we gathered previously # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound # this scales a lot better if you do EP! out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim)) - # if run_index != -1: - # torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.add.out.pt") - #Reduce sum - # if self.process_group.size() > 1: - # torch.distributed.all_reduce(out, group=self.process_group) - - #if run_index != -1: - # torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.add.out.pt") - return out - - # shared_output = self.shared_expert(hidden_states, reduce=False) - # router_logits = self.router(hidden_states) - # out = self.experts(hidden_states, gating_output=router_logits) - - # if shared_output is not None: - # out = out + shared_output - - # # Reduce sum - # if self.process_group.size() > 1: - # torch.distributed.all_reduce(out, group=self.process_group) - - # return out.view(*hidden_states.shape) class Llama4TextRotaryEmbedding(nn.Module): def __init__(self, config: Llama4TextConfig, device=None): super().__init__() @@ -473,8 +324,7 @@ class Llama4TextRotaryEmbedding(nn.Module): inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() @@ -484,168 +334,14 @@ class Llama4TextRotaryEmbedding(nn.Module): position_ids_expanded = position_ids_expanded.to(device_type) with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) - # Use concatenated cos and sin instead of complex numbers cos = torch.cos(freqs) * self.attention_scaling sin = torch.sin(freqs) * self.attention_scaling cos = cos.reshape(-1, 1, cos.shape[-1]) sin = sin.reshape(-1, 1, sin.shape[-1]) - log_master(logger.debug, f"cos: {cos.shape}, sin: {sin.shape}") freqs_cis = torch.cat([cos, sin], dim=-1) * self.attention_scaling freqs_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) return freqs_cis -# class Llama4TextRotaryEmbedding(nn.Module): -# def __init__(self, config: Llama4TextConfig, device=None): -# super().__init__() -# # BC: "rope_type" was originally "type" -# self.rope_type = "llama3" if config.rope_scaling is not None else "default" - -# self.max_seq_len_cached = config.max_position_embeddings -# self.original_max_seq_len = config.max_position_embeddings - -# self.config = config -# self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - -# inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) -# self.register_buffer("inv_freq", inv_freq, persistent=False) -# self.original_inv_freq = self.inv_freq - -# @torch.no_grad() -# @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) -# def forward(self, x, position_ids): -# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) -# position_ids_expanded = position_ids[:, None, :].float() -# origin_device = x.device -# device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" and x.device.type != "hpu" else "cpu" -# inv_freq_expanded = inv_freq_expanded.to(device_type) -# position_ids_expanded = position_ids_expanded.to(device_type) -# with torch.autocast(device_type=device_type, enabled=False): # Force float32 -# freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) -# freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation -# freqs_cis = freqs_cis * self.attention_scaling -# return freqs_cis - - -# def apply_rotary_emb( -# xq: torch.Tensor, -# xk: torch.Tensor, -# freqs_cis: torch.Tensor, -# ) -> Tuple[torch.Tensor, torch.Tensor]: -# orig_device= xq.device -# xq = xq.to("cpu") -# xk = xk.to("cpu") -# log_master(logger.debug,f"freqs_cis: {freqs_cis.shape}") -# log_master(logger.debug, f"xq: {xq.shape}, xk: {xk.shape}") -# xq = xq.view(freqs_cis.shape[0], -1, *xq.shape[-2:]) -# xk = xk.view(freqs_cis.shape[0], -1, *xk.shape[-2:]) -# log_master(logger.debug, f"xq: {xq.shape}, xk: {xk.shape}") -# xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) -# xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) -# #log_master(logger.debug, f"xq_: {xq_.shape}, xk_: {xk_.shape}") -# #log_master(logger.debug, f"freqs_cis: {freqs_cis.shape}") -# xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3) -# xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3) -# xq_out = xq_out.view(-1, *xq_out.shape[-2:]).to(orig_device) -# xk_out = xk_out.view(-1, *xk_out.shape[-2:]).to(orig_device) -# xq = xq.to(orig_device) -# xk = 
xk.to(orig_device) -# return xq_out.type_as(xq), xk_out.type_as(xk) - - -# class Llama4TextRotaryEmbedding(nn.Module): -# def __init__(self, config: 'Llama4TextConfig', device=None): -# super().__init__() -# self.rope_type = "llama3" if config.rope_scaling is not None else "default" -# self.max_seq_len_cached = config.max_position_embeddings -# self.original_max_seq_len = config.max_position_embeddings -# self.config = config -# self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - -# inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) -# self.register_buffer("inv_freq", inv_freq, persistent=False) -# self.original_inv_freq = self.inv_freq - -# @torch.no_grad() -# @dynamic_rope_update -# def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor: -# """ -# Args: -# x: Input tensor of shape [batch, seq_len, heads, dim] -# position_ids: Position indices of shape [batch, seq_len] -# Returns: -# Rotary embeddings as float tensors [batch, seq_len, heads, dim] -# """ -# # Expand inv_freq and position_ids for broadcasting -# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) -# position_ids_expanded = position_ids[:, None, :].float() - -# # Compute frequencies (replaces complex phase) -# freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) # [batch, seq_len, dim//2] - -# # Generate cos/sin components directly (replaces torch.polar) -# cos_vals = torch.cos(freqs) * self.attention_scaling -# sin_vals = torch.sin(freqs) * self.attention_scaling - -# # Interleave cos/sin values to match original complex format -# dim = x.size(-1) -# if dim % 2 != 0: -# raise ValueError(f"Feature dimension {dim} must be even for Rotary Embedding") - -# # Stack and reshape to [batch, seq_len, dim] format -# freqs_cis = torch.stack([cos_vals, sin_vals], dim=-1) # [batch, seq_len, dim//2, 2] -# freqs_cis = freqs_cis.reshape(*freqs_cis.shape[:-2], dim) # [batch, seq_len, dim] - -# return freqs_cis - -# def apply_rotary_emb( -# xq: torch.Tensor, -# xk: torch.Tensor, -# freqs_cis: torch.Tensor, # Should be [cosθ, sinθ] instead of complex numbers -# ) -> Tuple[torch.Tensor, torch.Tensor]: -# """ -# Apply rotary position embedding to query and key tensors using floating-point operations only. 
- -# Args: -# xq: Query tensor of shape (batch, seq_len, n_heads, head_dim) -# xk: Key tensor of shape (batch, seq_len, n_heads, head_dim) -# freqs_cis: Precomputed rotation frequencies as [cosθ, sinθ] -# of shape (batch, seq_len, head_dim//2, 2) -# Returns: -# Rotated query and key tensors with same shape as input -# """ -# # Verify head_dim is even -# assert xq.size(-1) % 2 == 0, "Feature dimension must be even for rotary embedding" - -# # Reshape to separate real and imaginary components (pairs of adjacent elements) -# xq_reshaped = xq.float().reshape(*xq.shape[:-1], -1, 2) # [..., head_dim//2, 2] -# xk_reshaped = xk.float().reshape(*xk.shape[:-1], -1, 2) # [..., head_dim//2, 2] - -# # Extract cosθ and sinθ (assuming freqs_cis is already in [cosθ, sinθ] format) -# cos_theta = freqs_cis[..., 0] # [batch, seq_len, head_dim//2] -# sin_theta = freqs_cis[..., 1] # [batch, seq_len, head_dim//2] - -# # Expand dimensions for broadcasting [batch, seq_len, n_heads, head_dim//2] -# cos_theta = cos_theta.unsqueeze(2) # Add n_heads dimension -# sin_theta = sin_theta.unsqueeze(2) - -# # Rotary transformation (mathematically equivalent to complex multiplication) -# # xq_rotated = [xq0*cosθ - xq1*sinθ, xq0*sinθ + xq1*cosθ] -# xq_out = torch.stack([ -# xq_reshaped[..., 0] * cos_theta - xq_reshaped[..., 1] * sin_theta, -# xq_reshaped[..., 0] * sin_theta + xq_reshaped[..., 1] * cos_theta -# ], dim=-1) - -# xk_out = torch.stack([ -# xk_reshaped[..., 0] * cos_theta - xk_reshaped[..., 1] * sin_theta, -# xk_reshaped[..., 0] * sin_theta + xk_reshaped[..., 1] * cos_theta -# ], dim=-1) - -# # Restore original shape -# xq_out = xq_out.flatten(-2) # [batch, seq_len, n_heads, head_dim] -# xk_out = xk_out.flatten(-2) - -# # Maintain original dtype -# return xq_out.type_as(xq), xk_out.type_as(xk) class Llama4TextAttention(FlashLlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -666,13 +362,6 @@ class Llama4TextAttention(FlashLlamaAttention): self.is_causal = True self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers - self.rotary_emb = PositionRotaryEmbedding.static( - config=config, - dim=self.head_size, - base=config.rope_theta, - device=weights.device, - ) - # `config.attention_multiplier` is used in Granite self.softmax_scale = getattr( config, "attention_multiplier", self.head_dim**-0.5 @@ -734,19 +423,6 @@ class Llama4TextAttention(FlashLlamaAttention): 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device ).repeat_interleave(self.num_groups) - - # self.q_proj = nn.Linear( - # config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - # ) - # self.k_proj = nn.Linear( - # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - # ) - # self.v_proj = nn.Linear( - # config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - # ) - # self.o_proj = nn.Linear( - # config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - # ) if self.config.use_qk_norm and self.use_rope: self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps) @@ -759,7 +435,6 @@ class Llama4TextAttention(FlashLlamaAttention): slots, seqlen, adapter_data, - run_index, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, @@ -791,35 +466,15 @@ class Llama4TextAttention(FlashLlamaAttention): # key_states = key_states.view(-1, 
self.num_key_value_heads, self.head_size) # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size) - #if run_index != -1: - # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt") - # torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt") - # torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt") - if self.use_rope: # the 16E model skips rope for long context on certain layers - #self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin) - #cos, sin = freqs_ci - #log_master(logger.debug, f"cos: {cos.shape}, sin: {sin.shape}") - log_master(logger.debug, f"query_states: {query_states.shape}, key_states: {key_states.shape}") - #self.rotary_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_emb( query_states, key_states, freqs_ci ) - - - #if run_index != -1: - # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt") - # torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.key_states.pt") - - if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm query_states = self.qk_norm(query_states) key_states = self.qk_norm(key_states) - #if run_index != -1: - # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.query_states.pt") - # torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt") # query_states = query_states.view(-1, self.num_heads, self.head_size) @@ -836,26 +491,15 @@ class Llama4TextAttention(FlashLlamaAttention): ) # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers if self.attn_temperature_tuning and not self.use_rope: - #indice = torch.tensor([0]).to(query_states.device) - #cache_position = position_ids - #log_master(logger.debug, f"cache_position: {cache_position.shape}") - - attn_scales = ( torch.log(torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0 ) - #seq_len = input_shape / bs attn_scales = attn_scales.view(*input_shape, 1, 1) query_states = (query_states * attn_scales).to(query_states.dtype) - #if run_index != -1: - # torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attn_scales.query_states.pt") - # torch_save(attention_mask, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_mask.pt") - # Prefill if cu_seqlen_prefill is not None: # sdpa - # log_master(logger.debug, f"self.softmax_scale: {self.softmax_scale}") # attn_output = attention( # query=query_states, # key=key_states, @@ -905,8 +549,6 @@ class Llama4TextAttention(FlashLlamaAttention): ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() - #if run_index != -1: - # torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.reshape.attn_output.pt") attn_output = self.o_proj(attn_output) return attn_output @@ -916,29 +558,12 @@ class Llama4TextDecoderLayer(nn.Module): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Llama4TextAttention(f"{prefix}.self_attn", config, weights, layer_idx) - synchronize(weights.device) - real_free_memory = get_free_memory(weights.device, 1) - # log_master( - # logger.debug, - # f"layer_idx: {layer_idx} Free memory real: {real_free_memory / 1e9:.2f}GB" - # 
) - - - self.use_chunked_attention = int((layer_idx + 1) % 4 != 0) # <=> use rope self.is_moe_layer = layer_idx in config.moe_layers - # log_master(logger.debug, f"self.is_moe_layer: {self.is_moe_layer}, layer_idx:{layer_idx}") - # log_master(logger.debug, f"moe_layers:{config.moe_layers}") if self.is_moe_layer: # the 128E model interleaves dense / sparse - moe_layer_cls = ( - SparseMoELayer - if SparseMoELayer.is_supported(weights) - else DenseMoELayer - ) - - self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights, layer_idx, moe_layer_cls) + self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights) else: - self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights, layer_idx) + self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights) self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights) self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights) @@ -953,9 +578,6 @@ class Llama4TextDecoderLayer(nn.Module): # eps=config.rms_norm_eps, # ) - - self.layer_idx = layer_idx - def forward( self, hidden_states, @@ -969,14 +591,9 @@ class Llama4TextDecoderLayer(nn.Module): chunk_causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, - run_index: int = 0, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states - #if run_index != -1: - # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input.hidden_states.pt") hidden_states = self.input_layernorm(hidden_states) - #if run_index != -1: - # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt") # use local attention mask for ROPE layers if self.use_chunked_attention and chunk_causal_mask is not None: @@ -990,54 +607,20 @@ class Llama4TextDecoderLayer(nn.Module): slots, seqlen, adapter_data, - run_index, attention_mask=attention_mask, position_ids=position_ids, hpu_attention_meta=hpu_attention_meta, ) - #if run_index != -1: - # torch_save(attention_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.attention_states.pt") + hidden_states = residual + attention_states - #if run_index != -1: - # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.hidden_states.pt") # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - #if run_index != -1: - # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt") - hidden_states = self.feed_forward(hidden_states, adapter_data, run_index) - #if run_index != -1: - # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt") + hidden_states = self.feed_forward(hidden_states, adapter_data) hidden_states = residual + hidden_states.view(residual.shape) - #if run_index != -1: - # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.output.hidden_states.pt") - #outputs = (hidden_states,) return hidden_states - # if residual is None: - # residual = hidden_states - # hidden_states, _ = self.input_layernorm(hidden_states) - # else: - # hidden_states, residual = self.input_layernorm( - # hidden_states, residual) 
- # hidden_states = self.self_attn( - # hidden_states, - # cos, - # sin, - # cu_seqlen_prefill, - # kv_cache, - # slots, - # seqlen, - # adapter_data, - # hpu_attention_meta=hpu_attention_meta, - # ) - - # # Fully Connected - # hidden_states, residual = self.post_attention_layernorm( - # hidden_states, residual) - # hidden_states = self.feed_forward(hidden_states, adapter_data) - # return hidden_states, residual class Llama4TextModel(nn.Module): @@ -1048,16 +631,6 @@ class Llama4TextModel(nn.Module): self.vocab_size = config.vocab_size self.embed_tokens = TensorParallelEmbedding(prefix=f"{prefix}.embed_tokens", weights=weights) - # synchronize(weights.device) - # real_free_memory = get_free_memory(weights.device, 1) - # log_master( - # logger.debug, - # f"textModel Free memory real: {real_free_memory / 1e9:.2f}GB" - # ) - # log_master( - # logger.debug, - # f"config.num_hidden_layers: {config.num_hidden_layers} " - # ) self.layers = nn.ModuleList( [Llama4TextDecoderLayer(prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)] ) @@ -1068,7 +641,6 @@ class Llama4TextModel(nn.Module): weights=weights, eps=config.rms_norm_eps, ) - self.run_index = -1 self.rotary_emb = Llama4TextRotaryEmbedding(config=config) self.gradient_checkpointing = False @@ -1087,29 +659,17 @@ class Llama4TextModel(nn.Module): ) -> torch.Tensor: hidden_states = inputs_embeds - #if self.run_index != -1: - # torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt") - #log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}") - # Get rotary cos and sin for this forward - # Avoid to index in each layer - #cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) - #log_master(logger.debug, f"position_ids.shape={position_ids.shape}, position_ids={position_ids}") bs = seqlen.input_lengths.shape[0] seq_len = inputs_embeds.shape[0] / bs cache_position = torch.arange(0, seq_len, device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) - # log_master(logger.debug, f"cache_position={cache_position}") - # log_master(logger.debug, f"position_ids={position_ids}") + causal_mask, chunk_causal_mask = self._update_causal_mask( attention_mask, inputs_embeds.view(bs, int(seq_len), -1), cache_position, None, output_attentions=False, use_cache=False ) - # log_master(logger.debug, f"causal_mask={causal_mask}") - # log_master(logger.debug, f"causal_mask={causal_mask.shape}") - # log_master(logger.debug, f"chunk_causal_mask={chunk_causal_mask}") - #freqs_ci = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1)) for i, layer in enumerate(self.layers): @@ -1125,17 +685,11 @@ class Llama4TextModel(nn.Module): chunk_causal_mask=chunk_causal_mask, position_ids=position_ids, hpu_attention_meta=hpu_attention_meta, - run_index=self.run_index, ) - # if self.run_index == 0: - # torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.layers.hidden_states.pt") hidden_states, _ = self.norm(hidden_states) - # if self.run_index == 0: - # torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.norm.hidden_states.pt") - self.run_index += 1 return hidden_states def _update_causal_mask( @@ -1417,7 +971,6 @@ def pixel_shuffle(input_tensor, shuffle_ratio): batch_size, num_patches, channels = input_tensor.shape patch_size = int(math.sqrt(num_patches)) - print_0(f"pixel_shuffle: 
{input_tensor.shape}, patch_size: {patch_size}, shuffle_ratio: {shuffle_ratio}") input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) batch_size, height, width, channels = input_tensor.size() torch_save(input_tensor, f"pixel_shuffle.input_tensor.pt") @@ -1455,51 +1008,6 @@ def vision_reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor): return freqs_ci.view(*shape) - -# def vision_apply_rotary_emb( -# query: torch.Tensor, -# key: torch.Tensor, -# freqs_ci: torch.Tensor, -# ) -> Tuple[torch.Tensor, torch.Tensor]: -# cos_cache, sin_cache = freqs_ci.chunk(2, dim=-1) -# # shape: [577, 1, 44] -# #print(f"[DENBUG] cos_cache.shape: {cos_cache.shape}, sin_cache.shape: {sin_cache.shape}") - -# query_2d = query.float().reshape(*query.shape[:-1], -1, 2) -# key_2d = key.float().reshape(*key.shape[:-1], -1, 2) -# # e.g., [17, 577, 8, 44, 2] -# #print(f'[DEBUG] query_2d.shape: {query_2d.shape}, key_2d.shape: {key_2d.shape}') - -# # Reshape cos_cache and sin_cache to broadcast properly. -# # We want them to have shape [1, 577, 1, 44] to match the query dimensions (except for the last two dims). -# cos_cache = cos_cache.view(1, cos_cache.shape[0], 1, cos_cache.shape[-1]) -# sin_cache = sin_cache.view(1, sin_cache.shape[0], 1, sin_cache.shape[-1]) -# # e.g., [1, 577, 1, 44] - -# # Separate the real and imaginary parts. -# q_real, q_imag = query_2d.unbind(-1) # each: [17, 577, 8, 44] -# k_real, k_imag = key_2d.unbind(-1) # each: [17, 577, 8, 44] - -# # Manually apply the complex multiplication (rotation) using the trigonometric identities. -# # For a complex multiplication: (a+ib)*(c+id) = (ac - bd) + i(ad + bc) -# q_rotated_real = q_real * cos_cache - q_imag * sin_cache -# q_rotated_imag = q_real * sin_cache + q_imag * cos_cache - -# k_rotated_real = k_real * cos_cache - k_imag * sin_cache -# k_rotated_imag = k_real * sin_cache + k_imag * cos_cache - -# # Re-stack the rotated components into a last dimension of size 2. -# q_rotated = torch.stack([q_rotated_real, q_rotated_imag], dim=-1) # shape: [17, 577, 8, 44, 2] -# k_rotated = torch.stack([k_rotated_real, k_rotated_imag], dim=-1) # shape: [17, 577, 8, 44, 2] - -# # Flatten the last two dimensions to match the original output shape. -# # Flatten back to the desired shape (e.g., collapse the last two dimensions). 
-# query_out = q_rotated.flatten(3) -# key_out = k_rotated.flatten(3) - -# return query_out.type_as(query), key_out.type_as(key) - - class Llama4VisionAttention(nn.Module): def __init__(self, prefix, config, weights): super().__init__() @@ -1517,105 +1025,81 @@ class Llama4VisionAttention(nn.Module): weights=weights, bias=True, ) - self.k_proj = TensorParallelColumnLinear.load( - config=config, - prefix=f"{prefix}.k_proj", - weights=weights, - bias=True, - ) - self.v_proj = TensorParallelColumnLinear.load( - config=config, - prefix=f"{prefix}.v_proj", - weights=weights, - bias=True, - ) - self.o_proj = TensorParallelRowLinear.load( - config=config, - prefix=f"{prefix}.o_proj", - weights=weights, - bias=True, - ) - # self.qkv_proj = TensorParallelColumnLinear.load_multi( - # config, - # prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - # dim=0, + # self.k_proj = TensorParallelColumnLinear.load( + # config=config, + # prefix=f"{prefix}.k_proj", + # weights=weights, + # bias=True, + # ) + # self.v_proj = TensorParallelColumnLinear.load( + # config=config, + # prefix=f"{prefix}.v_proj", # weights=weights, # bias=True, # ) # self.o_proj = TensorParallelRowLinear.load( - # config, + # config=config, # prefix=f"{prefix}.o_proj", # weights=weights, # bias=True, # ) + self.qkv_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=True, + ) def forward( self, hidden_states: torch.Tensor, freqs_ci: torch.Tensor, # Now takes (cos_theta, sin_theta) instead of complex attention_mask: Optional[torch.Tensor] = None, - run_index: Optional[int] = None, - layer_idx: Optional[int] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - if run_index != -1: - torch_save(hidden_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.input.pt") - query_states = self.q_proj(hidden_states).view(hidden_shape) - key_states = self.k_proj(hidden_states).view(hidden_shape) - value_states = self.v_proj(hidden_states).view(hidden_shape) - # qkv = self.qkv_proj(hidden_states) + # query_states = self.q_proj(hidden_states).view(hidden_shape) + # key_states = self.k_proj(hidden_states).view(hidden_shape) + # value_states = self.v_proj(hidden_states).view(hidden_shape) + qkv = self.qkv_proj(hidden_states) - # query_states, key_states, value_states = qkv.split( - # [ - # self.head_dim * self.num_heads, - # self.head_dim * self.num_heads, - # self.head_dim * self.num_heads, - # ], - # dim=2, - # ) - # query_states = query_states.view(hidden_shape) - # key_states = key_states.view(hidden_shape) - # value_states = value_states.view(hidden_shape) - #if run_index != -1: - # torch_save(query_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.query_states.pt") - # torch_save(key_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.key_states.pt") - # torch_save(value_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.value_states.pt") - #query_states = torch_load(f"trans.{run_index}.layer.{layer_idx}.self_attn.query_states.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) - #key_states = torch_load(f"trans.{run_index}.layer.{layer_idx}.self_attn.key_states.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) - #value_states = 
torch_load(f"trans.{run_index}.layer.{layer_idx}.self_attn.value_states.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) - - log_master( - logger.debug, - f"vision query_states.shape: {query_states.shape}, key_states.shape: {key_states.shape}, freqs_ci.shape: {freqs_ci.shape}" + query_states, key_states, value_states = qkv.split( + [ + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + self.head_dim * self.num_heads, + ], + dim=2, ) + query_states = query_states.view(hidden_shape) + key_states = key_states.view(hidden_shape) + value_states = value_states.view(hidden_shape) + query_states, key_states = apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci) - #if run_index != -1: - #torch_save(query_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.rotary.query_states.pt") - #torch_save(key_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.rotary.key_states.pt") - #query_states = torch_load(f"trans.{run_index}.layer.{layer_idx}.self_attn.rotary.query_states.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) - #key_states = torch_load(f"trans.{run_index}.layer.{layer_idx}.self_attn.rotary.key_states.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) - - + query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - #print(f"attention_mask shape: {attention_mask.shape}") - print(f"attention_mask: {attention_mask}") - if hasattr(self, "num_key_value_groups"): - print_0(f"module.num_key_value_groups={self.num_key_value_groups}") - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + + # if hasattr(self, "num_key_value_groups"): + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) attn_output = F.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False, dropout_p=0 ) - attn_output = attn_output.transpose(1, 2).contiguous() - #attn_output = torch.load(f"trans.{run_index}.layer.{layer_idx}.self_attn.attn_output.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) + attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output @@ -1665,28 +1149,18 @@ class Llama4VisionEncoderLayer(nn.Module): hidden_state: torch.Tensor, freqs_ci: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - run_index: Optional[int] = None, - layer_idx: Optional[int] = None, ): # Self Attention residual = hidden_state - if run_index != -1: - torch_save(hidden_state, f"trans.{run_index}.encoder.layer.{layer_idx}.input.pt") + hidden_state = self.input_layernorm(hidden_state) - if run_index != -1: - torch_save(hidden_state, f"trans.{run_index}.encoder.layer.{layer_idx}.input_norm.pt") - torch_save(attention_mask, f"trans.{run_index}.encoder.layer.{layer_idx}.attention_mask.pt") - torch_save(freqs_ci, f"trans.{run_index}.encoder.layer.{layer_idx}.freqs_ci.pt") + hidden_state = self.self_attn( hidden_state, freqs_ci=freqs_ci, attention_mask=attention_mask, - run_index=run_index, - layer_idx=layer_idx, ) - #if run_index != -1: - #torch_save(hidden_state, f"trans.{run_index}.encoder.layer.{layer_idx}.atten.pt") - #hidden_state = torch.load(f"trans.{run_index}.encoder.layer.{layer_idx}.atten.pt").to(device=hidden_state.device,dtype=hidden_state.dtype) + 
hidden_state = residual + hidden_state # Feed forward @@ -1694,12 +1168,7 @@ class Llama4VisionEncoderLayer(nn.Module): hidden_state = self.post_attention_layernorm(hidden_state) hidden_state = self.mlp(hidden_state) hidden_state = residual + hidden_state - - if run_index != -1: - torch_save(hidden_state, f"trans.{run_index}.encoder.layer.{layer_idx}.output.pt") outputs = (hidden_state,) - - return outputs @@ -1721,7 +1190,6 @@ class Llama4VisionEncoder(nn.Module): ]) self.gradient_checkpointing = False self.config = config - self.run_index = -1 def forward( self, @@ -1730,21 +1198,15 @@ class Llama4VisionEncoder(nn.Module): attention_mask: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutput]: - for layer_idx, encoder_layer in enumerate(self.layers): + for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_state=hidden_states, attention_mask=attention_mask, freqs_ci=freqs_ci, - run_index=self.run_index, - layer_idx=layer_idx, ) - hidden_states = layer_outputs[0] - if self.run_index != -1: - torch_save(hidden_states, f"trans.{self.run_index}.encoder.output.pt") - #hidden_states = torch.load(f"trans.{self.run_index}.encoder.output.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) - self.run_index += 1 + return hidden_states @@ -1769,62 +1231,21 @@ class Llama4UnfoldConvolution(nn.Module): hidden_states = self.linear(hidden_states) return hidden_states -# class Llama4VisionRotaryEmbedding(nn.Module): -# def __init__(self, config, weights): -# super().__init__() -# idx = config.image_size // config.patch_size -# print_0(f"VisionRotaryEmbedding idx: {idx}") -# img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1) -# img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) -# print_0(f"VisionRotaryEmbedding img_idx: {img_idx.shape}") -# torch_save(img_idx, f"trans.vision.img_idx.pt") -# img_idx[-1, -1] = -2 # ID_CLS_TOKEN -# print_0(f"VisionRotaryEmbedding img_idx: {img_idx}, img_idx.dtype: {img_idx.dtype}") -# frequencies_x = img_idx % idx # get the coordinates of the 2d matrix along x -# torch_save(frequencies_x, f"trans.vision.frequencies_x.pt") -# frequencies_y = img_idx // idx # get the coordinates of the 2d matrix along y -# print_0(f"VisionRotaryEmbedding frequencies_y: {frequencies_y}") -# torch_save(frequencies_y, f"trans.vision.frequencies_y.pt") -# freq_dim = config.hidden_size // config.num_attention_heads // 2 -# rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim)) -# torch_save(rope_freq, f"trans.vision.rope_freq.pt") -# freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]) -# torch_save(freqs_x, f"trans.vision.freqs_x.pt") -# freqs_x = freqs_x.repeat_interleave(2, dim=-1) -# torch_save(freqs_x, f"trans.vision.repeat.freqs_x.pt") -# freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]) -# torch_save(freqs_y, f"trans.vision.freqs_y.pt") -# freqs_y = freqs_y.repeat_interleave(2, dim=-1) -# torch_save(freqs_y, f"trans.vision.repeat.freqs_y.pt") - -# freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2] -# freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) -# torch_save(freqs, f"trans.vision.freqs.pt") -# #freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) -# freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) -# self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2 - -# def forward(self, hidden_states): -# return self.freqs_ci - class 
Llama4VisionRotaryEmbedding(nn.Module): def __init__(self, config, weights): super().__init__() # Calculate image grid indices idx = config.image_size // config.patch_size - print_0(f"VisionRotaryEmbedding idx: {idx}") img_idx = torch.arange(idx**2, dtype=torch.int32, device=weights.device).reshape(idx**2, 1) img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) torch_save(img_idx, f"trans.vision.img_idx.pt") img_idx[-1, -1] = -2 # ID_CLS_TOKEN - print_0(f"VisionRotaryEmbedding img_idx: {img_idx}, img_idx.dtype: {img_idx.dtype}") # Calculate x and y coordinates frequencies_x = img_idx % idx # x coordinates torch_save(frequencies_x, f"trans.vision.frequencies_x.pt") frequencies_y = torch.div(img_idx, idx, rounding_mode='floor') # y coordinates - print_0(f"VisionRotaryEmbedding frequencies_y: {frequencies_y}") torch_save(frequencies_y, f"trans.vision.frequencies_y.pt") # Calculate frequency components freq_dim = config.hidden_size // config.num_attention_heads // 2 @@ -1904,7 +1325,6 @@ class Llama4VisionModel(nn.Module): self.vision_adapter = Llama4VisionPixelShuffleMLP( prefix=f"{prefix}.vision_adapter", config=config, weights=weights ) - self.run_index = -1 def forward( self, @@ -1912,18 +1332,12 @@ class Llama4VisionModel(nn.Module): attention_mask: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, ): - if self.run_index != -1: - torch_save(pixel_values, f"trans.{self.run_index}.vision.pixel_values.pt") - # num_concurrent_media and num_chunks are both currently 1 batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape num_concurrent_media = 1 num_chunks = 1 hidden_state = self.patch_embedding(pixel_values) _, num_patches, hidden_dim = hidden_state.shape - if self.run_index != -1: - torch_save(hidden_state, f"trans.{self.run_index}.vision.patch.pt") - # Add cls token hidden_state = hidden_state.reshape( @@ -1932,48 +1346,29 @@ class Llama4VisionModel(nn.Module): class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1]) hidden_state = torch.cat([hidden_state, class_embedding], dim=1) num_patches += 1 - if self.run_index != -1: - torch_save(hidden_state, f"trans.{self.run_index}.vision.class.pt") + # Position embeddings hidden_state = hidden_state.reshape( batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim ) positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device) hidden_state = hidden_state + positional_embedding - if self.run_index != -1: - torch_save(hidden_state, f"trans.{self.run_index}.vision.position.pt") - hidden_state = self.layernorm_pre(hidden_state) - if self.run_index != -1: - torch_save(hidden_state, f"trans.{self.run_index}.vision.layernorm_pre.pt") - hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim) freqs_ci = self.rotary_embedding(pixel_values) - if self.run_index != -1: - torch_save(freqs_ci, f"trans.{self.run_index}.vision.freqs_ci.pt") hidden_state = self.model( hidden_state, attention_mask=None, freqs_ci=freqs_ci, ) - if self.run_index != -1: - torch_save(hidden_state, f"trans.{self.run_index}.vision.model.pt") - - hidden_state = self.layernorm_post(hidden_state) - if self.run_index != -1: - torch_save(hidden_state, f"trans.{self.run_index}.vision.post.pt") hidden_state = hidden_state[:, :-1, :] # now, we use Llama4VisionPixelShuffle + mlp to project embeddings hidden_state = self.vision_adapter(hidden_state) - #if self.run_index != -1: - #hidden_state = 
torch.load(f"trans.{self.run_index}.vision.hidden_states.pt").to(device=hidden_state.device,dtype=hidden_state.dtype) - #torch_save(hidden_state, f"trans.{self.run_index}.vision.hidden_states.pt") - self.run_index += 1 return hidden_state class Llama4ForConditionalGeneration(nn.Module): @@ -1986,32 +1381,15 @@ class Llama4ForConditionalGeneration(nn.Module): config.text_config.quantize = config.quantize config.text_config.speculator = config.speculator config.text_config._attn_implementation = None - log_master( - logger.debug, - f"init Llama4ForConditionalGeneration with config!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" - ) + self.vision_model = Llama4VisionModel( prefix="vision_model", config=config.vision_config, weights=weights ) - # synchronize(weights.device) - # real_free_memory = get_free_memory(weights.device, 1) - # log_master( - # logger.debug, - # f"Free memory real: {real_free_memory / 1e9:.2f}GB" - # ) - self.multi_modal_projector = Llama4MultiModalProjector( prefix="multi_modal_projector", config=config, weights=weights ) - # synchronize(weights.device) - # real_free_memory = get_free_memory(weights.device, 1) - # log_master( - # logger.debug, - # f"Free memory real: {real_free_memory / 1e9:.2f}GB" - # ) - self.text_model = Llama4ForCausalLM( prefix="language_model", config=config.text_config, weights=weights ) @@ -2069,18 +1447,12 @@ class Llama4ForConditionalGeneration(nn.Module): adapter_data: Optional[torch.Tensor] = None, **lm_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - # log_master( - # logger.debug, - # f"input_ids: {input_ids}, shape = {input_ids.shape}, input_ids={input_ids[-20:]}" - # ) def _get_padding_mask(input_ids, pad_token_id=0): - return (input_ids != pad_token_id).long() # 1 for non-padding positions, 0 for padding positions + return (input_ids != pad_token_id).long() - # Example attention_mask = _get_padding_mask(input_ids) attention_mask = attention_mask.view(seqlen.input_lengths.shape[0], -1) - #log_master(logger.debug,f"attention_mask={attention_mask}") inputs_embeds = self.text_model.model.embed_tokens(input_ids) vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_config.vision_feature_layer ) vision_feature_select_strategy = ( vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_config.vision_feature_select_strategy ) - # if (input_ids is None) ^ (inputs_embeds is not None): - # raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - # if pixel_values is not None and inputs_embeds is not None: - # raise ValueError( - # "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - # ) if pixel_values is not None: image_features = self.get_image_features( pixel_values=pixel_values, diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index b99fea31..a8f3591f 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -338,9 +338,6 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch): image_id += 1 full_text = image_text_replacement_fixup(config, full_text) - log_master( - logger.debug, f"full_text: {full_text}" - ) input_ids = tokenizer( full_text, truncation=True,