Clean the code

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-11 17:53:26 +00:00
parent f0dac1dec8
commit 3aa882337e


@@ -1,5 +1,6 @@
 # coding=utf-8
-# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
+#
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,16 +25,12 @@ import torch.nn.functional as F
 from transformers import Llama4TextConfig
 from transformers.cache_utils import Cache
 from transformers.activations import ACT2FN
-from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from transformers.modeling_outputs import (
     BaseModelOutput,
 )
-import habana_frameworks.torch as htorch
-from transformers.processing_utils import Unpack
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
@@ -41,58 +38,19 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     SpeculativeHead,
     FastLinear,
-    TensorParallelAdapterRowLinear
 )
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.attention import (
     KVCache,
-    get_kv_scales,
     paged_attention,
-    attention,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     load_attention,
     FlashLlamaAttention,
-    FlashLlamaForCausalLM,
     LlamaMLP,
 )
-from habana_frameworks.torch.hpex.kernels import FusedSDPA
-from vllm_hpu_extension.utils import ModuleFusedSDPA
-from text_generation_server.utils.import_utils import (
-    synchronize,
-    get_free_memory,
-)
-from loguru import logger
-from text_generation_server.utils.log import log_master
-from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
-
-_CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B"
-_CONFIG_FOR_DOC = "Llama4Config"
-
-
-def print_0(*args, **kwargs):
-    """
-    Only print on rank 0 in distributed training.
-    Works like built-in print() function but only executes on rank 0.
-    """
-    # Check whether we are running in a distributed environment
-    if torch.distributed.is_initialized():
-        # Get the current rank
-        if torch.distributed.get_rank() == 0:
-            print(*args, **kwargs)
-    else:
-        # Not a distributed environment, print normally
-        print(*args, **kwargs, flush=True)
-
-
-def torch_save(tensor, name):
-    pass
-    # Only save on the main process (rank 0) when using distributed training
-    # if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-    #     torch.save(tensor, name)
-
-
-def torch_load(name):
-    rank = torch.distributed.get_rank()
-    return torch.load(f"{name}.{rank}")
-
-
 def reshape_for_broadcast(freqs: torch.Tensor, target):
@@ -218,8 +176,6 @@ class Llama4TextMLP(nn.Module):
         )
 
 
-
-
 class Llama4TextL2Norm(torch.nn.Module):
     def __init__(self, eps: float = 1e-6):
         super().__init__()
@@ -235,26 +191,6 @@ class Llama4TextL2Norm(torch.nn.Module):
         return f"eps={self.eps}"
 
 
-class Llama4TextRMSNorm(nn.Module):
-    def __init__(self, prefix, config, weights):
-        """
-        Llama4RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.eps = config.rms_norm_eps
-        self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"), requires_grad=False)
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.eps}"
-
-
 class Llama4TextMoe(nn.Module):
     def __init__(
         self,
@@ -351,78 +287,14 @@ class Llama4TextAttention(FlashLlamaAttention):
         self.config = config
         self.layer_idx = layer_idx
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_attention_heads = config.num_attention_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-        self.num_key_value_heads = config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attn_scale = config.attn_scale
         self.floor_scale = config.floor_scale
         self.attn_temperature_tuning = config.attn_temperature_tuning
         self.attention_dropout = config.attention_dropout
-        self.is_causal = True
         self.use_rope = int((layer_idx + 1) % 4 != 0)  # rope unused for dense layers
-
-        # `config.attention_multiplier` is used in Granite
-        self.softmax_scale = getattr(
-            config, "attention_multiplier", self.head_dim**-0.5
-        )
-
-        if self.num_attention_heads % weights.process_group.size() != 0:
-            raise ValueError(
-                f"`num_attention_heads` must be divisible by `num_shards` (got `num_attention_heads`: {self.num_attention_heads} "
-                f"and `num_shards`: {weights.process_group.size()}"
-            )
-        if config.num_key_value_heads % weights.process_group.size() != 0:
-            raise ValueError(
-                f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
-                f"and `num_shards`: {weights.process_group.size()}"
-            )
-        self.num_heads = self.num_attention_heads // weights.process_group.size()
-        self.num_key_value_heads = (
-            config.num_key_value_heads // weights.process_group.size()
-        )
-
-        #self.query_key_value = load_attention(config, prefix, weights, layer_idx)
-        self.kv_scales = get_kv_scales(weights, f"{prefix}")
-        self.q_proj = TensorParallelColumnLinear.load(
-            config=config,
-            prefix=f"{prefix}.q_proj",
-            weights=weights,
-            bias=getattr(config, "attention_bias", False),
-        )
-        self.k_proj = TensorParallelColumnLinear.load(
-            config=config,
-            prefix=f"{prefix}.k_proj",
-            weights=weights,
-            bias=getattr(config, "attention_bias", False),
-        )
-        self.v_proj = TensorParallelColumnLinear.load(
-            config=config,
-            prefix=f"{prefix}.v_proj",
-            weights=weights,
-            bias=getattr(config, "attention_bias", False),
-        )
-        self.o_proj = TensorParallelRowLinear.load(
-            config=config,
-            prefix=f"{prefix}.o_proj",
-            weights=weights,
-            bias=getattr(config, "attention_bias", False),
-        )
-        # self.o_proj = TensorParallelAdapterRowLinear.load(
-        #     o_proj,
-        #     layer_idx,
-        #     "o_proj",
-        #     process_group=weights.process_group,
-        # )
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
-
         if self.config.use_qk_norm and self.use_rope:
             self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps)
@@ -442,29 +314,19 @@ class Llama4TextAttention(FlashLlamaAttention):
         bs = seqlen.input_lengths.shape[0]
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
-        #qkv = self.query_key_value(hidden_states, adapter_data)
-        # query_states, kv_states = qkv.split(
-        #     [
-        #         self.head_size * self.num_heads,
-        #         2 * self.head_size * self.num_key_value_heads,
-        #     ],
-        #     dim=-1,
-        # )
-        # query_states, key_states, value_states = qkv.split(
-        #     [
-        #         self.head_size * self.num_heads,
-        #         self.head_size * self.num_key_value_heads,
-        #         self.head_size * self.num_key_value_heads,
-        #     ],
-        #     dim=-1,
-        # )
-        query_states = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
-        key_states = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
-        value_states = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
-        # query_states = query_states.view(-1, self.num_heads, self.head_size)
-        # key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
-        # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
+        qkv = self.query_key_value(hidden_states, adapter_data)
+        query_states, key_states, value_states = qkv.split(
+            [
+                self.head_dim * self.num_heads,
+                self.head_dim * self.num_key_value_heads,
+                self.head_dim * self.num_key_value_heads,
+            ],
+            dim=-1,
+        )
+        query_states = query_states.view(hidden_shape)
+        key_states = key_states.view(hidden_shape)
+        value_states = value_states.view(hidden_shape)
 
         if self.use_rope:  # the 16E model skips rope for long context on certain layers
             query_states, key_states = apply_rotary_emb(
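Note: the hunk above drops the per-projection q/k/v matmuls in favour of the fused query_key_value projection and a single split on the last dimension. A minimal sketch of the pattern, with toy sizes standing in for the config values and for the tensor-parallel layer that load_attention builds in the real code:

import torch

# Toy sizes; the real projection is a sharded TensorParallelColumnLinear.
num_heads, num_kv_heads, head_dim, hidden = 8, 2, 64, 512
qkv_proj = torch.nn.Linear(hidden, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)

hidden_states = torch.randn(4, hidden)          # (tokens, hidden)
qkv = qkv_proj(hidden_states)                   # one matmul instead of three
query, key, value = qkv.split(
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
    dim=-1,
)
query = query.view(-1, num_heads, head_dim)     # (tokens, heads, head_dim)
key = key.view(-1, num_kv_heads, head_dim)
value = value.view(-1, num_kv_heads, head_dim)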
@@ -475,20 +337,13 @@ class Llama4TextAttention(FlashLlamaAttention):
             query_states = self.qk_norm(query_states)
             key_states = self.qk_norm(key_states)
 
-        # query_states = query_states.view(-1, self.num_heads, self.head_size)
-        # key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
-        # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
-        # query_states = query_states.transpose(1, 2)
-        # key_states = key_states.transpose(1, 2)
         kv_cache.store(
             key=key_states,
             value=value_states,
             slots=slots,
             kv_scales=self.kv_scales,
         )
 
         # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers
         if self.attn_temperature_tuning and not self.use_rope:
             attn_scales = (
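Note: the attn_scales expression is cut off at the end of this hunk. It is the length-dependent temperature tuning for the no-RoPE layers referenced in the comment; in the upstream Transformers Llama4 attention it takes roughly the following form. Treat this as a sketch only, with position_ids standing in for whatever position tensor the server actually passes:

import torch

def temperature_scale(query_states, position_ids, attn_scale, floor_scale):
    # The scale grows logarithmically with absolute position, keeping attention
    # on the no-RoPE layers calibrated at long context lengths.
    attn_scales = (
        torch.log(torch.floor((position_ids.float() + 1.0) / floor_scale) + 1.0) * attn_scale
        + 1.0
    )
    # query_states: (tokens, heads, head_dim); apply one scale per token.
    return query_states * attn_scales.view(-1, 1, 1)

q = temperature_scale(torch.randn(6, 8, 64), torch.arange(6), attn_scale=0.1, floor_scale=8192.0)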
@@ -500,16 +355,6 @@ class Llama4TextAttention(FlashLlamaAttention):
         # Prefill
         if cu_seqlen_prefill is not None:
             # sdpa
-            # attn_output = attention(
-            #     query=query_states,
-            #     key=key_states,
-            #     value=value_states,
-            #     kv_scales=self.kv_scales,
-            #     kv_cache=kv_cache,
-            #     seqlen=seqlen,
-            #     softmax_scale=self.softmax_scale,
-            #     causal=True
-            # )
             query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(1, 2)
             key = key_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
             value = value_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -549,7 +394,7 @@ class Llama4TextAttention(FlashLlamaAttention):
             )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output, adapter_data)
         return attn_output
@@ -565,18 +410,16 @@ class Llama4TextDecoderLayer(nn.Module):
         else:
             self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights)
 
-        self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
-        self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
-        # self.input_layernorm = FastRMSNorm.load(
-        #     prefix=f"{prefix}.input_layernorm",
-        #     weights=weights,
-        #     eps=config.rms_norm_eps,
-        # )
-        # self.post_attention_layernorm = FastRMSNorm.load(
-        #     prefix=f"{prefix}.post_attention_layernorm",
-        #     weights=weights,
-        #     eps=config.rms_norm_eps,
-        # )
+        self.input_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.input_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
+        self.post_attention_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.post_attention_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
 
     def forward(
         self,
@@ -593,7 +436,7 @@ class Llama4TextDecoderLayer(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states, _ = self.input_layernorm(hidden_states)
 
         # use local attention mask for ROPE layers
         if self.use_chunked_attention and chunk_causal_mask is not None:
@@ -617,7 +460,7 @@ class Llama4TextDecoderLayer(nn.Module):
         # Fully Connected
         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, _ = self.post_attention_layernorm(hidden_states)
         hidden_states = self.feed_forward(hidden_states, adapter_data)
         hidden_states = residual + hidden_states.view(residual.shape)
         return hidden_states
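Note: the ", _" unpacking in the two layernorm call sites reflects the switch from the removed Llama4TextRMSNorm (which returned a single tensor) to FastRMSNorm, whose forward returns a pair. A rough functional sketch of the assumed (normed, residual) convention, not the actual FastRMSNorm implementation:

import torch

class TupleRMSNorm(torch.nn.Module):
    # Minimal RMSNorm whose forward returns a pair, mimicking the assumed
    # (normed_hidden_states, residual) convention of FastRMSNorm.
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * normed, hidden_states  # second value is the pre-norm input

norm = TupleRMSNorm(16)
hidden_states, _ = norm(torch.randn(2, 16))  # the decoder layer discards the second value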
@@ -945,13 +788,9 @@ class Llama4VisionMLP2(torch.nn.Module):
     def forward(self, hidden_states):
         hidden_states = self.fc1(hidden_states)
-        torch_save(hidden_states, f"trans.mlp.fc1.hidden_states.pt")
         hidden_states = self.activation_fn(hidden_states)
-        torch_save(hidden_states, f"trans.mlp.activation_fn.hidden_states.pt")
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
-        torch_save(hidden_states, f"trans.mlp.dropout.hidden_states.pt")
         hidden_states = self.fc2(hidden_states)
-        torch_save(hidden_states, f"trans.mlp.fc2.hidden_states.pt")
         return self.activation_fn(hidden_states)  # TODO: check if we need to apply activation again
 
 
 class Llama4MultiModalProjector(nn.Module):
@@ -973,19 +812,14 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
     input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
     batch_size, height, width, channels = input_tensor.size()
-    torch_save(input_tensor, f"pixel_shuffle.input_tensor.pt")
     reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))
-    torch_save(reshaped_tensor, f"pixel_shuffle.reshaped_tensor.pt")
     reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
-    torch_save(reshaped_tensor, f"pixel_shuffle.permute.reshaped_tensor.pt")
     reshaped_tensor = reshaped_tensor.view(
         batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2))
     )
-    torch_save(reshaped_tensor, f"pixel_shuffle.final_viewed_tensor.pt")
     reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
     output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
-    torch_save(output_tensor, f"pixel_shuffle.output_tensor.pt")
     return output_tensor
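Note: as a quick shape sanity check (illustrative only, assuming the usual [batch, num_patches, channels] input and the pixel_shuffle defined in this file), the function trades spatial positions for channels:

import torch

# Hypothetical Llama4-vision sizes: 24x24 = 576 patches with 1408 channels.
# shuffle_ratio = 0.5 folds a 2x2 neighbourhood of patches into the channel dim.
x = torch.randn(1, 576, 1408)
out = pixel_shuffle(x, 0.5)
assert out.shape == (1, 144, 5632)  # 4x fewer positions, 4x more channels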
@@ -1019,30 +853,6 @@ class Llama4VisionAttention(nn.Module):
         self.head_dim = config.hidden_size // config.num_attention_heads
         self.num_key_value_groups = 1
         self.attention_dropout = config.attention_dropout
-        self.q_proj = TensorParallelColumnLinear.load(
-            config=config,
-            prefix=f"{prefix}.q_proj",
-            weights=weights,
-            bias=True,
-        )
-        # self.k_proj = TensorParallelColumnLinear.load(
-        #     config=config,
-        #     prefix=f"{prefix}.k_proj",
-        #     weights=weights,
-        #     bias=True,
-        # )
-        # self.v_proj = TensorParallelColumnLinear.load(
-        #     config=config,
-        #     prefix=f"{prefix}.v_proj",
-        #     weights=weights,
-        #     bias=True,
-        # )
-        # self.o_proj = TensorParallelRowLinear.load(
-        #     config=config,
-        #     prefix=f"{prefix}.o_proj",
-        #     weights=weights,
-        #     bias=True,
-        # )
         self.qkv_proj = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -1066,9 +876,6 @@ class Llama4VisionAttention(nn.Module):
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
-        # query_states = self.q_proj(hidden_states).view(hidden_shape)
-        # key_states = self.k_proj(hidden_states).view(hidden_shape)
-        # value_states = self.v_proj(hidden_states).view(hidden_shape)
         qkv = self.qkv_proj(hidden_states)
@@ -1090,10 +897,6 @@ class Llama4VisionAttention(nn.Module):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        # if hasattr(self, "num_key_value_groups"):
-        #     key_states = repeat_kv(key_states, self.num_key_value_groups)
-        #     value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output = F.scaled_dot_product_attention(
             query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False, dropout_p=0
@@ -1217,9 +1020,6 @@ class Llama4UnfoldConvolution(nn.Module):
         if isinstance(kernel_size, int):
             kernel_size = (kernel_size, kernel_size)
         self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
-        # self.linear = TensorParallelColumnLinear.load(
-        #     config=config, prefix=f"{prefix}.linear", weights=weights, bias=False
-        # )
         self.linear = FastLinear.load(
             config=config, prefix=f"{prefix}.linear", weights=weights, bias=False
         )
@@ -1239,37 +1039,26 @@ class Llama4VisionRotaryEmbedding(nn.Module):
         idx = config.image_size // config.patch_size
         img_idx = torch.arange(idx**2, dtype=torch.int32, device=weights.device).reshape(idx**2, 1)
         img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
-        torch_save(img_idx, f"trans.vision.img_idx.pt")
         img_idx[-1, -1] = -2  # ID_CLS_TOKEN
 
         # Calculate x and y coordinates
         frequencies_x = img_idx % idx  # x coordinates
-        torch_save(frequencies_x, f"trans.vision.frequencies_x.pt")
         frequencies_y = torch.div(img_idx, idx, rounding_mode='floor')  # y coordinates
-        torch_save(frequencies_y, f"trans.vision.frequencies_y.pt")
 
         # Calculate frequency components
         freq_dim = config.hidden_size // config.num_attention_heads // 2
         rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2, device=weights.device)[: (freq_dim // 2)].float() / freq_dim))
-        torch_save(rope_freq, f"trans.vision.rope_freq.pt")
 
         # Compute frequencies for x and y directions
         freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :])
-        torch_save(freqs_x, f"trans.vision.freqs_x.pt")
         freqs_x = freqs_x.repeat_interleave(2, dim=-1)
-        torch_save(freqs_x, f"trans.vision.repeat.freqs_x.pt")
         freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :])
-        torch_save(freqs_y, f"trans.vision.freqs_y.pt")
         freqs_y = freqs_y.repeat_interleave(2, dim=-1)
-        torch_save(freqs_y, f"trans.vision.repeat.freqs_y.pt")
 
         # Combine frequencies and mask special tokens
         freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
         freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
-        torch_save(freqs, f"trans.vision.freqs.pt")
-        #freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
         freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
-        #freq_cis = torch.concat([torch.cos(freqs), torch.sin(freqs)], dim=-1)
         self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2
 
     def forward(self, hidden_states):
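Note: the retained freq_cis keeps explicit (cos, sin) pairs in the last dimension rather than the commented-out torch.view_as_complex variant. A generic sketch of how such a table rotates interleaved feature pairs; this is only the idea, with hypothetical shapes, not the file's own apply_rotary_emb:

import torch

def rotate_pairs(x, cos, sin):
    # Rotate interleaved (even, odd) feature pairs by the given angles.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

freqs = torch.randn(10, 4)                                          # hypothetical (positions, dim/2) angles
table = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)   # last dim stacks (cos, sin), like freqs_ci
rotated = rotate_pairs(torch.randn(10, 8), table[..., 0], table[..., 1])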
@@ -1304,10 +1093,6 @@ class Llama4VisionModel(nn.Module):
             weights.get_tensor(f"{prefix}.positional_embedding_vlm"), requires_grad=False
         )
 
-        log_master(
-            logger.debug,
-            f"vision positional_embedding_vlm.shape: {self.positional_embedding_vlm.shape}"
-        )
         self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights)
 
         # layer norms
@@ -1507,4 +1292,4 @@ class Llama4ForConditionalGeneration(nn.Module):
             attention_mask
         )
 
         return logits, speculative_logits