diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index f7bef8c3..4ac2ec5d 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -30,9 +30,12 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     ModelOutput,
 )
+
+import habana_frameworks.torch as htorch
 from transformers.processing_utils import Unpack
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from text_generation_server.layers.rotary import PositionRotaryEmbedding
 
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
@@ -82,12 +85,69 @@ def print_0(*args, **kwargs):
         print(*args, **kwargs)
     else:
         # 如果不是分布式环境,正常打印
-        print(*args, **kwargs)
+        print(*args, **kwargs, flush=True)
 
 
 def torch_save(tensor, name):
+    pass
     # Only save on the main process (rank 0) when using distributed training
-    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-        torch.save(tensor, name)
+    # if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+    #     torch.save(tensor, name)
+def torch_load(name):
+    rank = torch.distributed.get_rank()
+    return torch.load(f"{name}.{rank}")
+
+
+def reshape_for_broadcast(freqs: torch.Tensor, target):
+    ndim = len(target)
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target)]
+    return freqs.view(*shape)
+
+def apply_rotary_emb(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    freqs_ci: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Adjust the cos and sin dimensions so they broadcast correctly
+    print_0(f"freqs_ci: {freqs_ci.shape}")
+    print_0(f"query: {query.shape}, key: {key.shape}")
+    query_shape = query.shape
+    key_shape = key.shape
+    cos_emb, sin_emb = freqs_ci.split(1, dim=-1)
+    print_0(f"cos_emb: {cos_emb.shape}, sin_emb: {sin_emb.shape}")
+    # Split the last dimension of query and key into 2D (x, y) vectors
+    if len(query.shape) == 3:
+        #query = query.view(freqs_ci.shape[0], -1, *query.shape[-2:])
+        query = query.unsqueeze(0)
+        key = key.unsqueeze(0)
+        #key = key.view(freqs_ci.shape[0], -1, *key.shape[-2:])
+    query_reshaped = query.float().reshape(*query.shape[:-1], -1, 2)
+    print_0(f"query_reshaped: {query_reshaped.shape}")
+    key_reshaped = key.float().reshape(*key.shape[:-1], -1, 2)
+    print_0(f"key_reshaped: {key_reshaped.shape}")
+    q_shape = query_reshaped.shape[:-1]
+    print_0(f"q_shape: {q_shape}")
+    cos_emb = reshape_for_broadcast(cos_emb, q_shape)
+    sin_emb = reshape_for_broadcast(sin_emb, q_shape)
+    print_0(f"cos_emb: {cos_emb.shape}, sin_emb: {sin_emb.shape}")
+    # Separate the x and y components
+    x_q, y_q = query_reshaped.unbind(-1)
+    print_0(f"x_q: {x_q.shape}, y_q: {y_q.shape}")
+    x_k, y_k = key_reshaped.unbind(-1)
+    print_0(f"x_k: {x_k.shape}, y_k: {y_k.shape}")
+    # Apply the rotation matrix
+    x_q_rot = x_q * cos_emb - y_q * sin_emb
+    y_q_rot = x_q * sin_emb + y_q * cos_emb
+    x_k_rot = x_k * cos_emb - y_k * sin_emb
+    y_k_rot = x_k * sin_emb + y_k * cos_emb
+
+    # Merge the results and restore the original shapes
+    query_out = torch.stack([x_q_rot, y_q_rot], dim=-1).flatten(-2)
+    key_out = torch.stack([x_k_rot, y_k_rot], dim=-1).flatten(-2)
+    query_out = query_out.view(*query_shape)
+    key_out = key_out.view(*key_shape)
+    print_0(f"query_out: {query_out.shape}, key_out: {key_out.shape}")
+    return query_out.type_as(query), key_out.type_as(key)
+
 def
repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ @@ -95,7 +155,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape - print_0(f"batch={batch}, num_key_value_heads={num_key_value_heads}, slen={slen}, head_dim={head_dim}") if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) @@ -111,21 +170,21 @@ class Llama4TextExperts(nn.Module): self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size self.gate_up_proj = nn.Parameter(weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), requires_grad=False) - synchronize(weights.device) - real_free_memory = get_free_memory(weights.device, 1) - log_master( - logger.debug, - f"textExperts1 Free memory real: {real_free_memory / 1e9:.2f}GB" - ) + # synchronize(weights.device) + # real_free_memory = get_free_memory(weights.device, 1) + # log_master( + # logger.debug, + # f"textExperts1 Free memory real: {real_free_memory / 1e9:.2f}GB" + # ) self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False) - synchronize(weights.device) - real_free_memory = get_free_memory(weights.device, 1) - log_master( - logger.debug, - f"textExperts2 Free memory real: {real_free_memory / 1e9:.2f}GB" - ) + # synchronize(weights.device) + # real_free_memory = get_free_memory(weights.device, 1) + # log_master( + # logger.debug, + # f"textExperts2 Free memory real: {real_free_memory / 1e9:.2f}GB" + # ) self.layer_idx = layer_idx self.act_fn = ACT2FN[config.hidden_act] @@ -144,33 +203,33 @@ class Llama4TextExperts(nn.Module): torch.Tensor """ gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2*self.expert_dim) - if run_index == 0: - torch_save(gate_up_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate_up_proj.pt") + # if run_index == 0: + # torch_save(gate_up_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate_up_proj.pt") down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1) - if run_index == 0: - torch_save(down_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.down_proj.pt") + # if run_index == 0: + # torch_save(down_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.down_proj.pt") hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - if run_index == 0: - torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.hidden_states.pt") + # if run_index == 0: + # torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.hidden_states.pt") gate_up = torch.bmm(hidden_states, gate_up_proj) gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors - if run_index == 0: - torch_save(gate, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate.pt") - torch_save(up, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.up.pt") + # if run_index == 0: + # torch_save(gate, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate.pt") + # torch_save(up, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.up.pt") next_states = torch.bmm((up * self.act_fn(gate)), down_proj) next_states = next_states.view(-1, self.hidden_size) - if run_index == 0: - 
torch_save(next_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.next_states.pt") + # if run_index == 0: + # torch_save(next_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.next_states.pt") # Reduce sum if self.process_group.size() > 1: @@ -289,7 +348,6 @@ class Llama4TextMoe(nn.Module): self.top_k = config.num_experts_per_tok self.hidden_dim = config.hidden_size self.num_experts = config.num_local_experts - log_master(logger.debug, f"weights.load: {weights.loader}") # self.experts = moe_layer_cls( # prefix=f"{prefix}.experts", # n_experts=config.num_local_experts, @@ -304,28 +362,28 @@ class Llama4TextMoe(nn.Module): self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights, layer_idx=layer_idx) - synchronize(weights.device) - real_free_memory = get_free_memory(weights.device, 1) - log_master( - logger.debug, - f"TextMode1 Free memory real: {real_free_memory / 1e9:.2f}GB" - ) + # synchronize(weights.device) + # real_free_memory = get_free_memory(weights.device, 1) + # log_master( + # logger.debug, + # f"TextMode1 Free memory real: {real_free_memory / 1e9:.2f}GB" + # ) self.router = FastLinear.load(config=config, prefix=f"{prefix}.router", weights=weights, bias=False) - synchronize(weights.device) - real_free_memory = get_free_memory(weights.device, 1) - log_master( - logger.debug, - f"TextMode2 Free memory real: {real_free_memory / 1e9:.2f}GB" - ) + # synchronize(weights.device) + # real_free_memory = get_free_memory(weights.device, 1) + # log_master( + # logger.debug, + # f"TextMode2 Free memory real: {real_free_memory / 1e9:.2f}GB" + # ) self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights, layer_idx=layer_idx) - synchronize(weights.device) - real_free_memory = get_free_memory(weights.device, 1) - log_master( - logger.debug, - f"TextMode3 Free memory real: {real_free_memory / 1e9:.2f}GB" - ) + # synchronize(weights.device) + # real_free_memory = get_free_memory(weights.device, 1) + # log_master( + # logger.debug, + # f"TextMode3 Free memory real: {real_free_memory / 1e9:.2f}GB" + # ) self.process_group = weights.process_group self.layer_idx = layer_idx @@ -400,7 +458,6 @@ class Llama4TextMoe(nn.Module): # torch.distributed.all_reduce(out, group=self.process_group) # return out.view(*hidden_states.shape) - class Llama4TextRotaryEmbedding(nn.Module): def __init__(self, config: Llama4TextConfig, device=None): super().__init__() @@ -416,45 +473,83 @@ class Llama4TextRotaryEmbedding(nn.Module): inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq - @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope)
     def forward(self, x, position_ids):
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         origin_device = x.device
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" and x.device.type != "hpu" else "cpu"
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         inv_freq_expanded = inv_freq_expanded.to(device_type)
         position_ids_expanded = position_ids_expanded.to(device_type)
         with torch.autocast(device_type=device_type, enabled=False): # Force float32
             freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
-            freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation
-            freqs_cis = freqs_cis * self.attention_scaling
+            # Use a stacked cos/sin pair instead of a complex representation
+            cos = torch.cos(freqs) * self.attention_scaling
+            sin = torch.sin(freqs) * self.attention_scaling
+            cos = cos.reshape(-1, 1, cos.shape[-1])
+            sin = sin.reshape(-1, 1, sin.shape[-1])
+            log_master(logger.debug, f"cos: {cos.shape}, sin: {sin.shape}")
+            freqs_cis = torch.cat([cos, sin], dim=-1) * self.attention_scaling
+            freqs_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
         return freqs_cis
+# class Llama4TextRotaryEmbedding(nn.Module):
+# def __init__(self, config: Llama4TextConfig, device=None):
+# super().__init__()
+# # BC: "rope_type" was originally "type"
+# self.rope_type = "llama3" if config.rope_scaling is not None else "default"
-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    orig_device= xq.device
-    xq = xq.to("cpu")
-    xk = xk.to("cpu")
-    xq = xq.view(freqs_cis.shape[0], -1, *xq.shape[-2:])
-    xk = xk.view(freqs_cis.shape[0], -1, *xk.shape[-2:])
-    #log_master(logger.debug, f"xq: {xq.shape}, xk: {xk.shape}")
-    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-    #log_master(logger.debug, f"xq_: {xq_.shape}, xk_: {xk_.shape}")
-    #log_master(logger.debug, f"freqs_cis: {freqs_cis.shape}")
-    xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
-    xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
-    xq_out = xq_out.view(-1, *xq_out.shape[-2:]).to(orig_device)
-    xk_out = xk_out.view(-1, *xk_out.shape[-2:]).to(orig_device)
-    xq = xq.to(orig_device)
-    xk = xk.to(orig_device)
-    return xq_out.type_as(xq), xk_out.type_as(xk)
+# self.max_seq_len_cached = config.max_position_embeddings
+# self.original_max_seq_len = config.max_position_embeddings
+
+# self.config = config
+# self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+# inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+# self.register_buffer("inv_freq", inv_freq, persistent=False)
+# self.original_inv_freq = self.inv_freq
+
+# @torch.no_grad()
+# @dynamic_rope_update # power user: used with advanced RoPE types (e.g.
dynamic rope) +# def forward(self, x, position_ids): +# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) +# position_ids_expanded = position_ids[:, None, :].float() +# origin_device = x.device +# device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" and x.device.type != "hpu" else "cpu" +# inv_freq_expanded = inv_freq_expanded.to(device_type) +# position_ids_expanded = position_ids_expanded.to(device_type) +# with torch.autocast(device_type=device_type, enabled=False): # Force float32 +# freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) +# freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation +# freqs_cis = freqs_cis * self.attention_scaling +# return freqs_cis + + +# def apply_rotary_emb( +# xq: torch.Tensor, +# xk: torch.Tensor, +# freqs_cis: torch.Tensor, +# ) -> Tuple[torch.Tensor, torch.Tensor]: +# orig_device= xq.device +# xq = xq.to("cpu") +# xk = xk.to("cpu") +# log_master(logger.debug,f"freqs_cis: {freqs_cis.shape}") +# log_master(logger.debug, f"xq: {xq.shape}, xk: {xk.shape}") +# xq = xq.view(freqs_cis.shape[0], -1, *xq.shape[-2:]) +# xk = xk.view(freqs_cis.shape[0], -1, *xk.shape[-2:]) +# log_master(logger.debug, f"xq: {xq.shape}, xk: {xk.shape}") +# xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) +# xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) +# #log_master(logger.debug, f"xq_: {xq_.shape}, xk_: {xk_.shape}") +# #log_master(logger.debug, f"freqs_cis: {freqs_cis.shape}") +# xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3) +# xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3) +# xq_out = xq_out.view(-1, *xq_out.shape[-2:]).to(orig_device) +# xk_out = xk_out.view(-1, *xk_out.shape[-2:]).to(orig_device) +# xq = xq.to(orig_device) +# xk = xk.to(orig_device) +# return xq_out.type_as(xq), xk_out.type_as(xk) # class Llama4TextRotaryEmbedding(nn.Module): @@ -571,6 +666,13 @@ class Llama4TextAttention(FlashLlamaAttention): self.is_causal = True self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + # `config.attention_multiplier` is used in Granite self.softmax_scale = getattr( config, "attention_multiplier", self.head_dim**-0.5 @@ -696,6 +798,9 @@ class Llama4TextAttention(FlashLlamaAttention): if self.use_rope: # the 16E model skips rope for long context on certain layers #self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin) + #cos, sin = freqs_ci + #log_master(logger.debug, f"cos: {cos.shape}, sin: {sin.shape}") + log_master(logger.debug, f"query_states: {query_states.shape}, key_states: {key_states.shape}") #self.rotary_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_emb( query_states, key_states, freqs_ci @@ -764,7 +869,6 @@ class Llama4TextAttention(FlashLlamaAttention): query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(1, 2) key = key_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value = value_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - print_0(f"self.num_key_value_groups={self.num_key_value_groups}") key = repeat_kv(key, self.num_key_value_groups) value = repeat_kv(value, self.num_key_value_groups) @@ -777,11 +881,6 @@ class 
Llama4TextAttention(FlashLlamaAttention): query = query.contiguous() key = key.contiguous() value = value.contiguous() - print_0(f"query.shape={query.shape}, query={query}") - print_0(f"key.shape={key.shape}, key={key}") - print_0(f"value.shape={value.shape}, value={value}") - print_0(f"attention_mask.shape={causal_mask.shape}, attention_mask={causal_mask}") - print_0(f"scaling={self.scaling}, is_causal={is_causal}") attn_output = torch.nn.functional.scaled_dot_product_attention( query, @@ -819,17 +918,17 @@ class Llama4TextDecoderLayer(nn.Module): self.self_attn = Llama4TextAttention(f"{prefix}.self_attn", config, weights, layer_idx) synchronize(weights.device) real_free_memory = get_free_memory(weights.device, 1) - log_master( - logger.debug, - f"layer_idx: {layer_idx} Free memory real: {real_free_memory / 1e9:.2f}GB" - ) + # log_master( + # logger.debug, + # f"layer_idx: {layer_idx} Free memory real: {real_free_memory / 1e9:.2f}GB" + # ) self.use_chunked_attention = int((layer_idx + 1) % 4 != 0) # <=> use rope self.is_moe_layer = layer_idx in config.moe_layers - log_master(logger.debug, f"self.is_moe_layer: {self.is_moe_layer}, layer_idx:{layer_idx}") - log_master(logger.debug, f"moe_layers:{config.moe_layers}") + # log_master(logger.debug, f"self.is_moe_layer: {self.is_moe_layer}, layer_idx:{layer_idx}") + # log_master(logger.debug, f"moe_layers:{config.moe_layers}") if self.is_moe_layer: # the 128E model interleaves dense / sparse moe_layer_cls = ( SparseMoELayer @@ -949,16 +1048,16 @@ class Llama4TextModel(nn.Module): self.vocab_size = config.vocab_size self.embed_tokens = TensorParallelEmbedding(prefix=f"{prefix}.embed_tokens", weights=weights) - synchronize(weights.device) - real_free_memory = get_free_memory(weights.device, 1) - log_master( - logger.debug, - f"textModel Free memory real: {real_free_memory / 1e9:.2f}GB" - ) - log_master( - logger.debug, - f"config.num_hidden_layers: {config.num_hidden_layers} " - ) + # synchronize(weights.device) + # real_free_memory = get_free_memory(weights.device, 1) + # log_master( + # logger.debug, + # f"textModel Free memory real: {real_free_memory / 1e9:.2f}GB" + # ) + # log_master( + # logger.debug, + # f"config.num_hidden_layers: {config.num_hidden_layers} " + # ) self.layers = nn.ModuleList( [Llama4TextDecoderLayer(prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)] ) @@ -990,30 +1089,29 @@ class Llama4TextModel(nn.Module): hidden_states = inputs_embeds #if self.run_index != -1: # torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt") - log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}") + #log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}") # Get rotary cos and sin for this forward # Avoid to index in each layer #cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) - log_master(logger.debug, f"position_ids.shape={position_ids.shape}, position_ids={position_ids}") + #log_master(logger.debug, f"position_ids.shape={position_ids.shape}, position_ids={position_ids}") bs = seqlen.input_lengths.shape[0] seq_len = inputs_embeds.shape[0] / bs cache_position = torch.arange(0, seq_len, device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) - log_master(logger.debug, f"cache_position={cache_position}") - log_master(logger.debug, f"position_ids={position_ids}") + # log_master(logger.debug, f"cache_position={cache_position}") + 
# log_master(logger.debug, f"position_ids={position_ids}") causal_mask, chunk_causal_mask = self._update_causal_mask( attention_mask, inputs_embeds.view(bs, int(seq_len), -1), cache_position, None, output_attentions=False, use_cache=False ) - log_master(logger.debug, f"causal_mask={causal_mask}") - log_master(logger.debug, f"causal_mask={causal_mask.shape}") - log_master(logger.debug, f"chunk_causal_mask={chunk_causal_mask}") - - - - + # log_master(logger.debug, f"causal_mask={causal_mask}") + # log_master(logger.debug, f"causal_mask={causal_mask.shape}") + # log_master(logger.debug, f"chunk_causal_mask={chunk_causal_mask}") + + #freqs_ci = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1)) + for i, layer in enumerate(self.layers): hidden_states = layer( hidden_states, @@ -1030,13 +1128,13 @@ class Llama4TextModel(nn.Module): run_index=self.run_index, ) - if self.run_index == 0: - torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.layers.hidden_states.pt") - log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}") + # if self.run_index == 0: + # torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.layers.hidden_states.pt") + hidden_states, _ = self.norm(hidden_states) - if self.run_index == 0: - torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.norm.hidden_states.pt") - log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}") + # if self.run_index == 0: + # torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.norm.hidden_states.pt") + self.run_index += 1 return hidden_states @@ -1050,8 +1148,6 @@ class Llama4TextModel(nn.Module): chunked_attention_mask=None, use_cache=True, ): - print(f"update 11111111111111111") - print(f"self.config._attn_implementation={self.config._attn_implementation}") if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask, attention_mask # flash does not support chunked attn TODO support flash @@ -1060,7 +1156,6 @@ class Llama4TextModel(nn.Module): if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]: return None, None - print(f"update 222222222222222222") sequence_length = input_tensor.shape[1] attention_chunk_size = self.config.attention_chunk_size @@ -1272,7 +1367,7 @@ class Llama4ForCausalLM(nn.Module): hpu_attention_meta=hpu_attention_meta, attention_mask=attention_mask, ) - print(f"lm_head_indices={lm_head_indices}") + if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] @@ -1285,10 +1380,10 @@ class Llama4VisionMLP2(torch.nn.Module): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.fc1 = FastLinear.load( + self.fc1 = TensorParallelColumnLinear.load( config=config, prefix=f"{prefix}.fc1", weights=weights, bias=False ) - self.fc2 = FastLinear.load( + self.fc2 = TensorParallelRowLinear.load( config=config, prefix=f"{prefix}.fc2", weights=weights, bias=False ) self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act] @@ -1296,15 +1391,19 @@ class Llama4VisionMLP2(torch.nn.Module): def forward(self, hidden_states): hidden_states = self.fc1(hidden_states) + torch_save(hidden_states, f"trans.mlp.fc1.hidden_states.pt") hidden_states = self.activation_fn(hidden_states) + torch_save(hidden_states, f"trans.mlp.activation_fn.hidden_states.pt") hidden_states = F.dropout(hidden_states, 
p=self.dropout, training=self.training) - return self.activation_fn(self.fc2(hidden_states)) - + torch_save(hidden_states, f"trans.mlp.dropout.hidden_states.pt") + hidden_states = self.fc2(hidden_states) + torch_save(hidden_states, f"trans.mlp.fc2.hidden_states.pt") + return self.activation_fn(hidden_states) # TODO: check if we need to apply activation again class Llama4MultiModalProjector(nn.Module): def __init__(self, prefix, config, weights): super().__init__() - self.linear_1 = TensorParallelColumnLinear.load( + self.linear_1 = FastLinear.load( config=config, prefix=f"{prefix}.linear_1", weights=weights, bias=False ) @@ -1318,18 +1417,22 @@ def pixel_shuffle(input_tensor, shuffle_ratio): batch_size, num_patches, channels = input_tensor.shape patch_size = int(math.sqrt(num_patches)) + print_0(f"pixel_shuffle: {input_tensor.shape}, patch_size: {patch_size}, shuffle_ratio: {shuffle_ratio}") input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) batch_size, height, width, channels = input_tensor.size() - + torch_save(input_tensor, f"pixel_shuffle.input_tensor.pt") reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)) + torch_save(reshaped_tensor, f"pixel_shuffle.reshaped_tensor.pt") reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() - + torch_save(reshaped_tensor, f"pixel_shuffle.permute.reshaped_tensor.pt") reshaped_tensor = reshaped_tensor.view( batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2)) ) + torch_save(reshaped_tensor, f"pixel_shuffle.final_viewed_tensor.pt") reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1]) + torch_save(output_tensor, f"pixel_shuffle.output_tensor.pt") return output_tensor @@ -1345,167 +1448,12 @@ class Llama4VisionPixelShuffleMLP(nn.Module): encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio) return self.mlp(encoded_patches) +# TODO there is a different RoPE for vision encoder, defined as below +def vision_reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor): + ndim = query.ndim + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)] + return freqs_ci.view(*shape) -LLAVA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlavaConfig`] or [`LlavaVisionConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" -# def reshape_for_broadcast(freqs: torch.Tensor, target: torch.Tensor): -# """Reshape frequency tensor for broadcasting to target tensor.""" -# ndim = target.ndim -# shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target.shape)] -# return freqs.view(*shape) -# def reshape_for_broadcast(freqs: torch.Tensor, target: torch.Tensor): -# ndim = target.ndim -# shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target.shape)] -# return freqs.view(*shape) - -def reshape_for_broadcast(freqs: torch.Tensor, target): - ndim = len(target) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target)] - return freqs.view(*shape) - -def vision_apply_rotary_emb( - query: torch.Tensor, - key: torch.Tensor, - freqs_ci: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - # 调整cos和sin的维度以匹配广播 - cos_emb,sin_emb = freqs_ci.split(1, dim=-1) - # 将query和key的最后一维拆分为二维向量 - query_reshaped = query.float().reshape(*query.shape[:-1], -1, 2) - key_reshaped = key.float().reshape(*key.shape[:-1], -1, 2) - q_shape = query_reshaped.shape[:-1] - cos_emb = reshape_for_broadcast(cos_emb, q_shape) - sin_emb = reshape_for_broadcast(sin_emb, q_shape) - - # 分离x和y分量 - x_q, y_q = query_reshaped.unbind(-1) - x_k, y_k = key_reshaped.unbind(-1) - # 应用旋转矩阵 - x_q_rot = x_q * cos_emb - y_q * sin_emb - y_q_rot = x_q * sin_emb + y_q * cos_emb - x_k_rot = x_k * cos_emb - y_k * sin_emb - y_k_rot = x_k * sin_emb + y_k * cos_emb - - # 合并结果并恢复形状 - query_out = torch.stack([x_q_rot, y_q_rot], dim=-1).flatten(-2) - key_out = torch.stack([x_k_rot, y_k_rot], dim=-1).flatten(-2) - return query_out.type_as(query), key_out.type_as(key) - - -# def vision_apply_rotary_emb( -# query: torch.Tensor, -# key: torch.Tensor, -# rotary_emb: torch.Tensor, # Now takes (cos_theta, sin_theta) instead of complex -# ) -> Tuple[torch.Tensor, torch.Tensor]: -# """ -# Apply rotary position embedding to query and key tensors using floating-point operations. 
- -# Args: -# query: Query tensor of shape (batch, seq_len, n_heads, head_dim) -# key: Key tensor of shape (batch, seq_len, n_heads, head_dim) -# rotary_emb: Tuple of (cos_theta, sin_theta) tensors from Llama4VisionRotaryEmbedding -# Returns: -# Rotated query and key tensors -# """ -# from habana_frameworks.torch.hpex.kernels import ( -# RotaryPosEmbeddingMode, -# apply_rotary_pos_emb, -# ) -# cos, sin = rotary_emb.split(1, dim=-1) # Unpack cos and sin components -# # # cos_emb = reshape_for_broadcast(cos_theta, query) -# # # sin_emb = reshape_for_broadcast(sin_theta, query) - -# # # 将query和key的最后一维拆分为二维向量 -# # query_reshaped = query.float().reshape(*query.shape[:-1], -1, 2) -# # key_reshaped = key.float().reshape(*key.shape[:-1], -1, 2) - -# # # 分离x和y分量 -# # x_q, y_q = query_reshaped.unbind(-1) -# # x_k, y_k = key_reshaped.unbind(-1) - -# # # 应用旋转矩阵 -# # x_q_rot = x_q * cos_emb - y_q * sin_emb -# # y_q_rot = x_q * sin_emb + y_q * cos_emb -# # x_k_rot = x_k * cos_emb - y_k * sin_emb -# # y_k_rot = x_k * sin_emb + y_k * cos_emb - -# # # 合并结果并恢复形状 -# # query_out = torch.stack([x_q_rot, y_q_rot], dim=-1).flatten(-2) -# # key_out = torch.stack([x_k_rot, y_k_rot], dim=-1).flatten(-2) - -# # return query_out.type_as(query), key_out.type_as(key) -# num_tokens = query.shape[0] -# head_size = query.shape[-1] -# # HPU RoPE kernel requires hidden dimension for cos and sin to be equal -# # to query hidden dimension, so the original tensors need to be -# # expanded -# # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE -# # and expansion of cos/sin tensors via concatenation -# print(f"query.shape: {query.shape}, key.shape: {key.shape}") -# print(f"cos.shape: {cos.shape}, sin.shape: {sin.shape}") -# rope_mode = RotaryPosEmbeddingMode.BLOCKWISE -# cos = torch.cat((cos, cos), dim=-1) -# sin = torch.cat((sin, sin), dim=-1) -# rotary_dim = cos.shape[-1] -# query_shape = query.shape -# query = query.reshape(num_tokens, -1, head_size) -# query_rot = query[..., :rotary_dim] -# query_pass = query[..., rotary_dim:] -# query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) -# query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)) - -# key_shape = key.shape -# key = key.reshape(num_tokens, -1, head_size) -# key_rot = key[..., :rotary_dim] -# key_pass = key[..., rotary_dim:] -# key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) -# key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)) -# return query, key - # # Reshape query/key to separate real and imaginary components - # query_reshaped = query.float().reshape(*query.shape[:-1], -1, 2) # [..., head_dim//2, 2] - # key_reshaped = key.float().reshape(*key.shape[:-1], -1, 2) # [..., head_dim//2, 2] - - # # Reshape cos/sin for broadcasting - # # cos_theta = reshape_for_broadcast(cos_theta, query_reshaped) - # # sin_theta = reshape_for_broadcast(sin_theta, query_reshaped) - - # # Apply rotary transformation (equivalent to complex multiplication) - # # For each pair (x0, x1): [x0*cosθ - x1*sinθ, x0*sinθ + x1*cosθ] - # query_out = torch.stack([ - # query_reshaped[..., 0] * cos_theta - query_reshaped[..., 1] * sin_theta, - # query_reshaped[..., 0] * sin_theta + query_reshaped[..., 1] * cos_theta - # ], dim=-1) - - # key_out = torch.stack([ - # key_reshaped[..., 0] * cos_theta - key_reshaped[..., 1] * sin_theta, - # key_reshaped[..., 0] * sin_theta + key_reshaped[..., 1] * cos_theta - # ], dim=-1) - - # # Restore original shape - # query_out = query_out.flatten(-2) # [batch, 
seq_len, n_heads, head_dim] - # key_out = key_out.flatten(-2) - - # # Maintain original dtype - # return query_out.type_as(query), key_out.type_as(key) - -# # TODO there is a different RoPE for vision encoder, defined as below -# def reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor): -# ndim = query.ndim -# shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)] -# return freqs_ci.view(*shape) # def vision_apply_rotary_emb( @@ -1513,55 +1461,112 @@ def vision_apply_rotary_emb( # key: torch.Tensor, # freqs_ci: torch.Tensor, # ) -> Tuple[torch.Tensor, torch.Tensor]: -# query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2)) -# key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2)) -# freqs_ci = reshape_for_broadcast(freqs_ci=freqs_ci, query=query_) # freqs_ci[:,:,None,:] -# freqs_ci = freqs_ci.to(query_.device) -# query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) -# key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) -# return query_out.type_as(query), key_out.type_as(key) # but this drops to 8e-3 +# cos_cache, sin_cache = freqs_ci.chunk(2, dim=-1) +# # shape: [577, 1, 44] +# #print(f"[DENBUG] cos_cache.shape: {cos_cache.shape}, sin_cache.shape: {sin_cache.shape}") +# query_2d = query.float().reshape(*query.shape[:-1], -1, 2) +# key_2d = key.float().reshape(*key.shape[:-1], -1, 2) +# # e.g., [17, 577, 8, 44, 2] +# #print(f'[DEBUG] query_2d.shape: {query_2d.shape}, key_2d.shape: {key_2d.shape}') + +# # Reshape cos_cache and sin_cache to broadcast properly. +# # We want them to have shape [1, 577, 1, 44] to match the query dimensions (except for the last two dims). +# cos_cache = cos_cache.view(1, cos_cache.shape[0], 1, cos_cache.shape[-1]) +# sin_cache = sin_cache.view(1, sin_cache.shape[0], 1, sin_cache.shape[-1]) +# # e.g., [1, 577, 1, 44] + +# # Separate the real and imaginary parts. +# q_real, q_imag = query_2d.unbind(-1) # each: [17, 577, 8, 44] +# k_real, k_imag = key_2d.unbind(-1) # each: [17, 577, 8, 44] + +# # Manually apply the complex multiplication (rotation) using the trigonometric identities. +# # For a complex multiplication: (a+ib)*(c+id) = (ac - bd) + i(ad + bc) +# q_rotated_real = q_real * cos_cache - q_imag * sin_cache +# q_rotated_imag = q_real * sin_cache + q_imag * cos_cache + +# k_rotated_real = k_real * cos_cache - k_imag * sin_cache +# k_rotated_imag = k_real * sin_cache + k_imag * cos_cache + +# # Re-stack the rotated components into a last dimension of size 2. +# q_rotated = torch.stack([q_rotated_real, q_rotated_imag], dim=-1) # shape: [17, 577, 8, 44, 2] +# k_rotated = torch.stack([k_rotated_real, k_rotated_imag], dim=-1) # shape: [17, 577, 8, 44, 2] + +# # Flatten the last two dimensions to match the original output shape. +# # Flatten back to the desired shape (e.g., collapse the last two dimensions). 
+# query_out = q_rotated.flatten(3) +# key_out = k_rotated.flatten(3) + +# return query_out.type_as(query), key_out.type_as(key) + class Llama4VisionAttention(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads #// weights.process_group.size() + self.num_heads = config.num_attention_heads // weights.process_group.size() self.progress_group = weights.process_group self.head_dim = config.hidden_size // config.num_attention_heads self.num_key_value_groups = 1 self.attention_dropout = config.attention_dropout - self.q_proj = FastLinear.load( - prefix=f"{prefix}.q_proj", weights=weights, config=config, bias=True + self.q_proj = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=True, ) - self.k_proj = FastLinear.load( - prefix=f"{prefix}.k_proj", weights=weights, config=config, bias=True + self.k_proj = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.k_proj", + weights=weights, + bias=True, ) - self.v_proj = FastLinear.load( - prefix=f"{prefix}.v_proj", weights=weights, config=config, bias=True + self.v_proj = TensorParallelColumnLinear.load( + config=config, + prefix=f"{prefix}.v_proj", + weights=weights, + bias=True, ) - self.o_proj = FastLinear.load( - prefix=f"{prefix}.o_proj", weights=weights, config=config, bias=True + self.o_proj = TensorParallelRowLinear.load( + config=config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=True, ) + # self.qkv_proj = TensorParallelColumnLinear.load_multi( + # config, + # prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + # dim=0, + # weights=weights, + # bias=True, + # ) + # self.o_proj = TensorParallelRowLinear.load( + # config, + # prefix=f"{prefix}.o_proj", + # weights=weights, + # bias=True, + # ) def forward( self, hidden_states: torch.Tensor, freqs_ci: torch.Tensor, # Now takes (cos_theta, sin_theta) instead of complex attention_mask: Optional[torch.Tensor] = None, + run_index: Optional[int] = None, + layer_idx: Optional[int] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) + if run_index != -1: + torch_save(hidden_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.input.pt") + query_states = self.q_proj(hidden_states).view(hidden_shape) key_states = self.k_proj(hidden_states).view(hidden_shape) value_states = self.v_proj(hidden_states).view(hidden_shape) - #qkv = self.qkv_proj(hidden_states) - #print(f"qkv shape: {qkv.shape}") + # qkv = self.qkv_proj(hidden_states) - # if self.process_group.size() > 1: - # torch.distributed.all_reduce(qkv, group=self.process_group) # query_states, key_states, value_states = qkv.split( # [ @@ -1574,17 +1579,43 @@ class Llama4VisionAttention(nn.Module): # query_states = query_states.view(hidden_shape) # key_states = key_states.view(hidden_shape) # value_states = value_states.view(hidden_shape) + #if run_index != -1: + # torch_save(query_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.query_states.pt") + # torch_save(key_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.key_states.pt") + # torch_save(value_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.value_states.pt") + #query_states = torch_load(f"trans.{run_index}.layer.{layer_idx}.self_attn.query_states.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) + #key_states 
= torch_load(f"trans.{run_index}.layer.{layer_idx}.self_attn.key_states.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) + #value_states = torch_load(f"trans.{run_index}.layer.{layer_idx}.self_attn.value_states.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) + + log_master( + logger.debug, + f"vision query_states.shape: {query_states.shape}, key_states.shape: {key_states.shape}, freqs_ci.shape: {freqs_ci.shape}" + ) + query_states, key_states = apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci) + #if run_index != -1: + #torch_save(query_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.rotary.query_states.pt") + #torch_save(key_states, f"trans.{run_index}.layer.{layer_idx}.self_attn.rotary.key_states.pt") + #query_states = torch_load(f"trans.{run_index}.layer.{layer_idx}.self_attn.rotary.query_states.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) + #key_states = torch_load(f"trans.{run_index}.layer.{layer_idx}.self_attn.rotary.key_states.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) - query_states, key_states = vision_apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) #print(f"attention_mask shape: {attention_mask.shape}") - #print(f"attention_mask: {attention_mask}") + print(f"attention_mask: {attention_mask}") + if hasattr(self, "num_key_value_groups"): + print_0(f"module.num_key_value_groups={self.num_key_value_groups}") + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = F.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=attention_mask + query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False, dropout_p=0 ) + + attn_output = attn_output.transpose(1, 2).contiguous() + #attn_output = torch.load(f"trans.{run_index}.layer.{layer_idx}.self_attn.attn_output.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output @@ -1595,10 +1626,10 @@ class Llama4VisionMLP(nn.Module): super().__init__() self.config = config self.activation_fn = nn.GELU() # ACT2FN[config.hidden_act] - self.fc1 = FastLinear.load( + self.fc1 = TensorParallelColumnLinear.load( prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True ) - self.fc2 = FastLinear.load( + self.fc2 = TensorParallelRowLinear.load( prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True ) @@ -1634,17 +1665,28 @@ class Llama4VisionEncoderLayer(nn.Module): hidden_state: torch.Tensor, freqs_ci: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, + run_index: Optional[int] = None, + layer_idx: Optional[int] = None, ): # Self Attention residual = hidden_state - + if run_index != -1: + torch_save(hidden_state, f"trans.{run_index}.encoder.layer.{layer_idx}.input.pt") hidden_state = self.input_layernorm(hidden_state) - + if run_index != -1: + torch_save(hidden_state, f"trans.{run_index}.encoder.layer.{layer_idx}.input_norm.pt") + torch_save(attention_mask, f"trans.{run_index}.encoder.layer.{layer_idx}.attention_mask.pt") + torch_save(freqs_ci, f"trans.{run_index}.encoder.layer.{layer_idx}.freqs_ci.pt") hidden_state = self.self_attn( hidden_state, freqs_ci=freqs_ci, attention_mask=attention_mask, + run_index=run_index, + 
layer_idx=layer_idx, ) + #if run_index != -1: + #torch_save(hidden_state, f"trans.{run_index}.encoder.layer.{layer_idx}.atten.pt") + #hidden_state = torch.load(f"trans.{run_index}.encoder.layer.{layer_idx}.atten.pt").to(device=hidden_state.device,dtype=hidden_state.dtype) hidden_state = residual + hidden_state # Feed forward @@ -1653,6 +1695,8 @@ class Llama4VisionEncoderLayer(nn.Module): hidden_state = self.mlp(hidden_state) hidden_state = residual + hidden_state + if run_index != -1: + torch_save(hidden_state, f"trans.{run_index}.encoder.layer.{layer_idx}.output.pt") outputs = (hidden_state,) @@ -1677,6 +1721,7 @@ class Llama4VisionEncoder(nn.Module): ]) self.gradient_checkpointing = False self.config = config + self.run_index = -1 def forward( self, @@ -1685,16 +1730,21 @@ class Llama4VisionEncoder(nn.Module): attention_mask: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutput]: - for encoder_layer in self.layers: + for layer_idx, encoder_layer in enumerate(self.layers): layer_outputs = encoder_layer( hidden_state=hidden_states, attention_mask=attention_mask, freqs_ci=freqs_ci, + run_index=self.run_index, + layer_idx=layer_idx, ) hidden_states = layer_outputs[0] - + if self.run_index != -1: + torch_save(hidden_states, f"trans.{self.run_index}.encoder.output.pt") + #hidden_states = torch.load(f"trans.{self.run_index}.encoder.output.pt").to(device=hidden_states.device,dtype=hidden_states.dtype) + self.run_index += 1 return hidden_states @@ -1719,43 +1769,93 @@ class Llama4UnfoldConvolution(nn.Module): hidden_states = self.linear(hidden_states) return hidden_states +# class Llama4VisionRotaryEmbedding(nn.Module): +# def __init__(self, config, weights): +# super().__init__() +# idx = config.image_size // config.patch_size +# print_0(f"VisionRotaryEmbedding idx: {idx}") +# img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1) +# img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) +# print_0(f"VisionRotaryEmbedding img_idx: {img_idx.shape}") +# torch_save(img_idx, f"trans.vision.img_idx.pt") +# img_idx[-1, -1] = -2 # ID_CLS_TOKEN +# print_0(f"VisionRotaryEmbedding img_idx: {img_idx}, img_idx.dtype: {img_idx.dtype}") +# frequencies_x = img_idx % idx # get the coordinates of the 2d matrix along x +# torch_save(frequencies_x, f"trans.vision.frequencies_x.pt") +# frequencies_y = img_idx // idx # get the coordinates of the 2d matrix along y +# print_0(f"VisionRotaryEmbedding frequencies_y: {frequencies_y}") +# torch_save(frequencies_y, f"trans.vision.frequencies_y.pt") +# freq_dim = config.hidden_size // config.num_attention_heads // 2 +# rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim)) +# torch_save(rope_freq, f"trans.vision.rope_freq.pt") +# freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]) +# torch_save(freqs_x, f"trans.vision.freqs_x.pt") +# freqs_x = freqs_x.repeat_interleave(2, dim=-1) +# torch_save(freqs_x, f"trans.vision.repeat.freqs_x.pt") +# freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]) +# torch_save(freqs_y, f"trans.vision.freqs_y.pt") +# freqs_y = freqs_y.repeat_interleave(2, dim=-1) +# torch_save(freqs_y, f"trans.vision.repeat.freqs_y.pt") + +# freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2] +# freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) +# torch_save(freqs, f"trans.vision.freqs.pt") +# #freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) +# freq_cis = 
torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) +# self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2 + +# def forward(self, hidden_states): +# return self.freqs_ci + + class Llama4VisionRotaryEmbedding(nn.Module): def __init__(self, config, weights): super().__init__() # Calculate image grid indices idx = config.image_size // config.patch_size + print_0(f"VisionRotaryEmbedding idx: {idx}") img_idx = torch.arange(idx**2, dtype=torch.int32, device=weights.device).reshape(idx**2, 1) img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) + torch_save(img_idx, f"trans.vision.img_idx.pt") + img_idx[-1, -1] = -2 # ID_CLS_TOKEN - + print_0(f"VisionRotaryEmbedding img_idx: {img_idx}, img_idx.dtype: {img_idx.dtype}") # Calculate x and y coordinates frequencies_x = img_idx % idx # x coordinates - frequencies_y = img_idx // idx # y coordinates + torch_save(frequencies_x, f"trans.vision.frequencies_x.pt") + frequencies_y = torch.div(img_idx, idx, rounding_mode='floor') # y coordinates + print_0(f"VisionRotaryEmbedding frequencies_y: {frequencies_y}") + torch_save(frequencies_y, f"trans.vision.frequencies_y.pt") # Calculate frequency components freq_dim = config.hidden_size // config.num_attention_heads // 2 rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2, device=weights.device)[: (freq_dim // 2)].float() / freq_dim)) + torch_save(rope_freq, f"trans.vision.rope_freq.pt") # Compute frequencies for x and y directions - freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1) - freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1) + freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]) + torch_save(freqs_x, f"trans.vision.freqs_x.pt") + freqs_x = freqs_x.repeat_interleave(2, dim=-1) + torch_save(freqs_x, f"trans.vision.repeat.freqs_x.pt") + freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]) + torch_save(freqs_y, f"trans.vision.freqs_y.pt") + freqs_y = freqs_y.repeat_interleave(2, dim=-1) + torch_save(freqs_y, f"trans.vision.repeat.freqs_y.pt") # Combine frequencies and mask special tokens freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2] freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) + torch_save(freqs, f"trans.vision.freqs.pt") - # Store cosθ and sinθ separately instead of complex numbers - cos_freq = torch.cos(freqs) - sin_freq = torch.sin(freqs) - self.freqs_ci = torch.stack([cos_freq, sin_freq], dim=-1).to(weights.dtype) - # # Store sequence length for validation - # self.seq_len = idx**2 + 1 # +1 for CLS token - # print(f"self.seq_len: {self.seq_len}, freqs shape: {freqs.shape}") + #freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) + freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) + #freq_cis = torch.concat([torch.cos(freqs), torch.sin(freqs)], dim=-1) + self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2 def forward(self, hidden_states): """ Returns the rotary embedding components (cosθ, sinθ) for the given hidden states """ - return self.freqs_ci + return self.freqs_ci.to(dtype=hidden_states.dtype, device=hidden_states.device) class Llama4VisionModel(nn.Module): @@ -1782,7 +1882,11 @@ class Llama4VisionModel(nn.Module): self.positional_embedding_vlm = nn.Parameter( weights.get_tensor(f"{prefix}.positional_embedding_vlm"), requires_grad=False ) - + + log_master( + logger.debug, + f"vision positional_embedding_vlm.shape: 
{self.positional_embedding_vlm.shape}" + ) self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights) # layer norms @@ -1800,6 +1904,7 @@ class Llama4VisionModel(nn.Module): self.vision_adapter = Llama4VisionPixelShuffleMLP( prefix=f"{prefix}.vision_adapter", config=config, weights=weights ) + self.run_index = -1 def forward( self, @@ -1807,13 +1912,18 @@ class Llama4VisionModel(nn.Module): attention_mask: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, ): - + if self.run_index != -1: + torch_save(pixel_values, f"trans.{self.run_index}.vision.pixel_values.pt") + # num_concurrent_media and num_chunks are both currently 1 batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape num_concurrent_media = 1 num_chunks = 1 hidden_state = self.patch_embedding(pixel_values) _, num_patches, hidden_dim = hidden_state.shape + if self.run_index != -1: + torch_save(hidden_state, f"trans.{self.run_index}.vision.patch.pt") + # Add cls token hidden_state = hidden_state.reshape( @@ -1822,33 +1932,48 @@ class Llama4VisionModel(nn.Module): class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1]) hidden_state = torch.cat([hidden_state, class_embedding], dim=1) num_patches += 1 - + if self.run_index != -1: + torch_save(hidden_state, f"trans.{self.run_index}.vision.class.pt") # Position embeddings hidden_state = hidden_state.reshape( batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim ) positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device) hidden_state = hidden_state + positional_embedding + if self.run_index != -1: + torch_save(hidden_state, f"trans.{self.run_index}.vision.position.pt") hidden_state = self.layernorm_pre(hidden_state) + if self.run_index != -1: + torch_save(hidden_state, f"trans.{self.run_index}.vision.layernorm_pre.pt") hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim) freqs_ci = self.rotary_embedding(pixel_values) - + if self.run_index != -1: + torch_save(freqs_ci, f"trans.{self.run_index}.vision.freqs_ci.pt") + hidden_state = self.model( hidden_state, attention_mask=None, freqs_ci=freqs_ci, ) + if self.run_index != -1: + torch_save(hidden_state, f"trans.{self.run_index}.vision.model.pt") + hidden_state = self.layernorm_post(hidden_state) + if self.run_index != -1: + torch_save(hidden_state, f"trans.{self.run_index}.vision.post.pt") hidden_state = hidden_state[:, :-1, :] # now, we use Llama4VisionPixelShuffle + mlp to project embeddings hidden_state = self.vision_adapter(hidden_state) - + #if self.run_index != -1: + #hidden_state = torch.load(f"trans.{self.run_index}.vision.hidden_states.pt").to(device=hidden_state.device,dtype=hidden_state.dtype) + #torch_save(hidden_state, f"trans.{self.run_index}.vision.hidden_states.pt") + self.run_index += 1 return hidden_state class Llama4ForConditionalGeneration(nn.Module): @@ -1861,28 +1986,31 @@ class Llama4ForConditionalGeneration(nn.Module): config.text_config.quantize = config.quantize config.text_config.speculator = config.speculator config.text_config._attn_implementation = None - + log_master( + logger.debug, + f"init Llama4ForConditionalGeneration with config!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" 
+ ) self.vision_model = Llama4VisionModel( prefix="vision_model", config=config.vision_config, weights=weights ) - synchronize(weights.device) - real_free_memory = get_free_memory(weights.device, 1) - log_master( - logger.debug, - f"Free memory real: {real_free_memory / 1e9:.2f}GB" - ) + # synchronize(weights.device) + # real_free_memory = get_free_memory(weights.device, 1) + # log_master( + # logger.debug, + # f"Free memory real: {real_free_memory / 1e9:.2f}GB" + # ) self.multi_modal_projector = Llama4MultiModalProjector( prefix="multi_modal_projector", config=config, weights=weights ) - synchronize(weights.device) - real_free_memory = get_free_memory(weights.device, 1) - log_master( - logger.debug, - f"Free memory real: {real_free_memory / 1e9:.2f}GB" - ) + # synchronize(weights.device) + # real_free_memory = get_free_memory(weights.device, 1) + # log_master( + # logger.debug, + # f"Free memory real: {real_free_memory / 1e9:.2f}GB" + # ) self.text_model = Llama4ForCausalLM( prefix="language_model", config=config.text_config, weights=weights @@ -1941,10 +2069,10 @@ class Llama4ForConditionalGeneration(nn.Module): adapter_data: Optional[torch.Tensor] = None, **lm_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - log_master( - logger.debug, - f"input_ids: {input_ids}, shape = {input_ids.shape}, input_ids={input_ids[-20:]}" - ) + # log_master( + # logger.debug, + # f"input_ids: {input_ids}, shape = {input_ids.shape}, input_ids={input_ids[-20:]}" + # ) def _get_padding_mask(input_ids, pad_token_id=0): return (input_ids != pad_token_id).long() # 非填充位置为1,填充位置为0 @@ -1952,7 +2080,7 @@ class Llama4ForConditionalGeneration(nn.Module): # 示例 attention_mask = _get_padding_mask(input_ids) attention_mask = attention_mask.view(seqlen.input_lengths.shape[0], -1) - log_master(logger.debug,f"attention_mask={attention_mask}") + #log_master(logger.debug,f"attention_mask={attention_mask}") inputs_embeds = self.text_model.model.embed_tokens(input_ids) vision_feature_layer = ( vision_feature_layer diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index a8f3591f..b99fea31 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -338,6 +338,9 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch): image_id += 1 full_text = image_text_replacement_fixup(config, full_text) + log_master( + logger.debug, f"full_text: {full_text}" + ) input_ids = tokenizer( full_text, truncation=True,
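
Note on the rotary-embedding change in this patch (illustrative, not part of the diff): apply_rotary_emb now applies RoPE as a real-valued rotation on (x, y) pairs driven by a stacked cos/sin tensor, since the complex-number path it replaces (torch.polar / torch.view_as_complex) is not available on HPU. The sketch below checks that equivalence in plain PyTorch; the helper name rotate_pairs and all shapes are invented for the example and do not appear in the repository.

    import torch

    def rotate_pairs(x, cos, sin):
        # View the last dim as (x, y) pairs and rotate each pair by the angle
        # whose cosine/sine are given; mirrors the x_*_rot / y_*_rot math in the patch.
        xr = x.float().reshape(*x.shape[:-1], -1, 2)
        a, b = xr.unbind(-1)
        rot = torch.stack([a * cos - b * sin, a * sin + b * cos], dim=-1)
        return rot.flatten(-2).type_as(x)

    torch.manual_seed(0)
    q = torch.randn(2, 8, 4, 64)      # (batch, seq, heads, head_dim) -- illustrative shapes
    freqs = torch.randn(2, 8, 32)     # one rotation angle per (position, pair)
    cos, sin = torch.cos(freqs), torch.sin(freqs)

    # Reference: the complex formulation removed by the patch (unsupported on HPU).
    q_c = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    ref = torch.view_as_real(q_c * torch.polar(torch.ones_like(freqs), freqs)[:, :, None, :]).flatten(3)

    out = rotate_pairs(q, cos[:, :, None, :], sin[:, :, None, :])
    print(torch.allclose(out, ref, atol=1e-6))  # expected: True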