Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)
Clean the code
Signed-off-by: yuanwu <yuan.wu@intel.com>
parent f0dac1dec8
commit 3aa882337e
@@ -1,5 +1,6 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,16 +25,12 @@ import torch.nn.functional as F
from transformers import Llama4TextConfig
from transformers.cache_utils import Cache
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_outputs import (
    BaseModelOutput,
)

import habana_frameworks.torch as htorch
from transformers.processing_utils import Unpack
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from text_generation_server.layers.rotary import PositionRotaryEmbedding

from text_generation_server.layers import (
    TensorParallelColumnLinear,
@@ -41,58 +38,19 @@ from text_generation_server.layers import (
    TensorParallelRowLinear,
    SpeculativeHead,
    FastLinear,
    TensorParallelAdapterRowLinear
)
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.attention import (
    KVCache,
    get_kv_scales,
    paged_attention,
    attention,
    Seqlen,
    HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    load_attention,
    FlashLlamaAttention,
    FlashLlamaForCausalLM,
    LlamaMLP,
)
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
from text_generation_server.utils.import_utils import (
    synchronize,
    get_free_memory,
)

from loguru import logger
from text_generation_server.utils.log import log_master
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer

_CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B"
_CONFIG_FOR_DOC = "Llama4Config"


def print_0(*args, **kwargs):
    """
    Only print on rank 0 in distributed training.
    Works like the built-in print() function but only executes on rank 0.
    """
    # Check whether we are running in a distributed environment
    if torch.distributed.is_initialized():
        # Get the current rank and only print on rank 0
        if torch.distributed.get_rank() == 0:
            print(*args, **kwargs)
    else:
        # Not a distributed run: print normally
        print(*args, **kwargs, flush=True)


def torch_save(tensor, name):
    pass
    # Only save on the main process (rank 0) when using distributed training
    # if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    #     torch.save(tensor, name)


def torch_load(name):
    rank = torch.distributed.get_rank()
    return torch.load(f"{name}.{rank}")


def reshape_for_broadcast(freqs: torch.Tensor, target):
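Aside on the rank-0 printing helper above: the same guard can be collapsed into a single condition. A minimal illustrative sketch (the name debug_log is hypothetical and not part of this file):

import torch.distributed as dist

def debug_log(*args, **kwargs):
    # Print only from rank 0 when distributed; print from the single process otherwise.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, flush=True, **kwargs)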
@@ -218,8 +176,6 @@ class Llama4TextMLP(nn.Module):
        )


class Llama4TextL2Norm(torch.nn.Module):
    def __init__(self, eps: float = 1e-6):
        super().__init__()
@@ -235,26 +191,6 @@ class Llama4TextL2Norm(torch.nn.Module):
        return f"eps={self.eps}"


class Llama4TextRMSNorm(nn.Module):
    def __init__(self, prefix, config, weights):
        """
        Llama4RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.eps = config.rms_norm_eps
        self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"), requires_grad=False)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


class Llama4TextMoe(nn.Module):
    def __init__(
        self,
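A quick sanity check on the RMSNorm math used by Llama4TextRMSNorm above, as a standalone sketch (plain torch; the hidden size, epsilon, and unit weight are illustrative, not taken from the model config):

import torch

hidden_size, eps = 8, 1e-5
x = torch.randn(2, 4, hidden_size)
weight = torch.ones(hidden_size)

# RMSNorm: rescale each hidden vector by the reciprocal of its root mean square, then apply the learned gain.
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

# Each normalized vector now has (approximately) unit root mean square.
print(y.pow(2).mean(-1).sqrt())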
@@ -351,78 +287,14 @@ class Llama4TextAttention(FlashLlamaAttention):
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attn_scale = config.attn_scale
        self.floor_scale = config.floor_scale
        self.attn_temperature_tuning = config.attn_temperature_tuning
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.use_rope = int((layer_idx + 1) % 4 != 0)  # rope unused for dense layers

        # `config.attention_multiplier` is used in Granite
        self.softmax_scale = getattr(
            config, "attention_multiplier", self.head_dim**-0.5
        )

        if self.num_attention_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_attention_heads` must be divisible by `num_shards` (got `num_attention_heads`: {self.num_attention_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        if config.num_key_value_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_attention_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        # self.query_key_value = load_attention(config, prefix, weights, layer_idx)

        self.kv_scales = get_kv_scales(weights, f"{prefix}")
        self.q_proj = TensorParallelColumnLinear.load(
            config=config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=getattr(config, "attention_bias", False),
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config=config,
            prefix=f"{prefix}.k_proj",
            weights=weights,
            bias=getattr(config, "attention_bias", False),
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config=config,
            prefix=f"{prefix}.v_proj",
            weights=weights,
            bias=getattr(config, "attention_bias", False),
        )

        self.o_proj = TensorParallelRowLinear.load(
            config=config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=getattr(config, "attention_bias", False),
        )

        # self.o_proj = TensorParallelAdapterRowLinear.load(
        #     o_proj,
        #     layer_idx,
        #     "o_proj",
        #     process_group=weights.process_group,
        # )

        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

        if self.config.use_qk_norm and self.use_rope:
            self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps)
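The kv_head_mapping built above assigns each query head on a shard to the key/value head it reads from (grouped-query attention). A small illustration with hypothetical per-shard head counts (8 query heads, 2 KV heads):

import torch

num_heads, num_key_value_heads = 8, 2
num_groups = num_heads // num_key_value_heads

# Query heads 0-3 map to KV head 0, query heads 4-7 map to KV head 1.
kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(num_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)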
@@ -442,29 +314,19 @@ class Llama4TextAttention(FlashLlamaAttention):
        bs = seqlen.input_lengths.shape[0]
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # qkv = self.query_key_value(hidden_states, adapter_data)
        # query_states, kv_states = qkv.split(
        #     [
        #         self.head_size * self.num_heads,
        #         2 * self.head_size * self.num_key_value_heads,
        #     ],
        #     dim=-1,
        # )
        # query_states, key_states, value_states = qkv.split(
        #     [
        #         self.head_size * self.num_heads,
        #         self.head_size * self.num_key_value_heads,
        #         self.head_size * self.num_key_value_heads,
        #     ],
        #     dim=-1,
        # )
        query_states = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
        qkv = self.query_key_value(hidden_states, adapter_data)
        query_states, key_states, value_states = qkv.split(
            [
                self.head_dim * self.num_heads,
                self.head_dim * self.num_key_value_heads,
                self.head_dim * self.num_key_value_heads,
            ],
            dim=-1,
        )

        # query_states = query_states.view(-1, self.num_heads, self.head_size)
        # key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
        # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
        query_states = query_states.view(hidden_shape)
        key_states = key_states.view(hidden_shape)
        value_states = value_states.view(hidden_shape)

        if self.use_rope:  # the 16E model skips rope for long context on certain layers
            query_states, key_states = apply_rotary_emb(
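The fused query_key_value output above is split back into query, key, and value by width. A toy illustration with hypothetical sizes (head_dim 4, 8 query heads, 2 KV heads, 3 tokens):

import torch

head_dim, num_heads, num_key_value_heads, tokens = 4, 8, 2, 3
qkv = torch.randn(tokens, head_dim * (num_heads + 2 * num_key_value_heads))

# Split along the hidden dimension: one slice per role, sized by its head count.
query, key, value = qkv.split(
    [head_dim * num_heads, head_dim * num_key_value_heads, head_dim * num_key_value_heads],
    dim=-1,
)
print(query.shape, key.shape, value.shape)  # torch.Size([3, 32]) torch.Size([3, 8]) torch.Size([3, 8])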
@@ -475,20 +337,13 @@ class Llama4TextAttention(FlashLlamaAttention):
            query_states = self.qk_norm(query_states)
            key_states = self.qk_norm(key_states)

        # query_states = query_states.view(-1, self.num_heads, self.head_size)
        # key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
        # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)

        # query_states = query_states.transpose(1, 2)
        # key_states = key_states.transpose(1, 2)
        kv_cache.store(
            key=key_states,
            value=value_states,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Use temperature tuning from https://arxiv.org/abs/2501.19399 on NoROPE layers
        if self.attn_temperature_tuning and not self.use_rope:
            attn_scales = (
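The hunk is cut off right after attn_scales; based on the upstream transformers Llama4 implementation, the scale grows logarithmically with token position, roughly log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1. A hedged sketch with illustrative values (not the model's real config):

import torch

attn_scale, floor_scale = 0.1, 8192.0  # illustrative values
position_ids = torch.arange(0, 65536, 8192)

# Attention temperature on NoROPE layers increases slowly with position
# (see https://arxiv.org/abs/2501.19399).
attn_scales = torch.log(torch.floor((position_ids.float() + 1.0) / floor_scale) + 1.0) * attn_scale + 1.0
print(attn_scales)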
@@ -500,16 +355,6 @@ class Llama4TextAttention(FlashLlamaAttention):
        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            # attn_output = attention(
            #     query=query_states,
            #     key=key_states,
            #     value=value_states,
            #     kv_scales=self.kv_scales,
            #     kv_cache=kv_cache,
            #     seqlen=seqlen,
            #     softmax_scale=self.softmax_scale,
            #     causal=True
            # )
            query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(1, 2)
            key = key_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value = value_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -549,7 +394,7 @@ class Llama4TextAttention(FlashLlamaAttention):
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        attn_output = self.o_proj(attn_output, adapter_data)
        return attn_output

@@ -565,18 +410,16 @@ class Llama4TextDecoderLayer(nn.Module):
        else:
            self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights)

        self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
        self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
        # self.input_layernorm = FastRMSNorm.load(
        #     prefix=f"{prefix}.input_layernorm",
        #     weights=weights,
        #     eps=config.rms_norm_eps,
        # )
        # self.post_attention_layernorm = FastRMSNorm.load(
        #     prefix=f"{prefix}.post_attention_layernorm",
        #     weights=weights,
        #     eps=config.rms_norm_eps,
        # )
        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
@@ -593,7 +436,7 @@ class Llama4TextDecoderLayer(nn.Module):
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.input_layernorm(hidden_states)

        # use local attention mask for ROPE layers
        if self.use_chunked_attention and chunk_causal_mask is not None:
@@ -617,7 +460,7 @@ class Llama4TextDecoderLayer(nn.Module):
        # Fully Connected
        residual = hidden_states

        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, _ = self.post_attention_layernorm(hidden_states)
        hidden_states = self.feed_forward(hidden_states, adapter_data)
        hidden_states = residual + hidden_states.view(residual.shape)
        return hidden_states
@@ -945,13 +788,9 @@ class Llama4VisionMLP2(torch.nn.Module):

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        torch_save(hidden_states, f"trans.mlp.fc1.hidden_states.pt")
        hidden_states = self.activation_fn(hidden_states)
        torch_save(hidden_states, f"trans.mlp.activation_fn.hidden_states.pt")
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
        torch_save(hidden_states, f"trans.mlp.dropout.hidden_states.pt")
        hidden_states = self.fc2(hidden_states)
        torch_save(hidden_states, f"trans.mlp.fc2.hidden_states.pt")
        return self.activation_fn(hidden_states)  # TODO: check if we need to apply activation again

class Llama4MultiModalProjector(nn.Module):
@@ -973,19 +812,14 @@ def pixel_shuffle(input_tensor, shuffle_ratio):

    input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
    batch_size, height, width, channels = input_tensor.size()
    torch_save(input_tensor, f"pixel_shuffle.input_tensor.pt")
    reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))
    torch_save(reshaped_tensor, f"pixel_shuffle.reshaped_tensor.pt")
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
    torch_save(reshaped_tensor, f"pixel_shuffle.permute.reshaped_tensor.pt")
    reshaped_tensor = reshaped_tensor.view(
        batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2))
    )
    torch_save(reshaped_tensor, f"pixel_shuffle.final_viewed_tensor.pt")
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
    torch_save(output_tensor, f"pixel_shuffle.output_tensor.pt")
    return output_tensor

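To make the shape bookkeeping in pixel_shuffle above concrete, a standalone sketch with an illustrative 16x16 patch grid, 32 channels, and shuffle_ratio 0.5 (256 patches collapse to 64 patches with 4x the channels):

import torch

batch_size, num_patches, channels = 1, 256, 32  # 16 x 16 patch grid
shuffle_ratio = 0.5
x = torch.randn(batch_size, num_patches, channels)

patch_size = int(num_patches ** 0.5)
x = x.view(batch_size, patch_size, patch_size, -1)
# Fold half of the width into the channel dimension, then swap height and width.
x = x.view(batch_size, patch_size, int(patch_size * shuffle_ratio), int(channels / shuffle_ratio))
x = x.permute(0, 2, 1, 3).contiguous()
# Fold half of the (former) height as well, then swap back.
x = x.view(batch_size, int(patch_size * shuffle_ratio), int(patch_size * shuffle_ratio), int(channels / shuffle_ratio ** 2))
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(batch_size, -1, x.shape[-1])
print(x.shape)  # torch.Size([1, 64, 128])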
@@ -1019,30 +853,6 @@ class Llama4VisionAttention(nn.Module):
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = 1
        self.attention_dropout = config.attention_dropout
        self.q_proj = TensorParallelColumnLinear.load(
            config=config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=True,
        )
        # self.k_proj = TensorParallelColumnLinear.load(
        #     config=config,
        #     prefix=f"{prefix}.k_proj",
        #     weights=weights,
        #     bias=True,
        # )
        # self.v_proj = TensorParallelColumnLinear.load(
        #     config=config,
        #     prefix=f"{prefix}.v_proj",
        #     weights=weights,
        #     bias=True,
        # )
        # self.o_proj = TensorParallelRowLinear.load(
        #     config=config,
        #     prefix=f"{prefix}.o_proj",
        #     weights=weights,
        #     bias=True,
        # )
        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -1066,9 +876,6 @@ class Llama4VisionAttention(nn.Module):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # query_states = self.q_proj(hidden_states).view(hidden_shape)
        # key_states = self.k_proj(hidden_states).view(hidden_shape)
        # value_states = self.v_proj(hidden_states).view(hidden_shape)
        qkv = self.qkv_proj(hidden_states)

@@ -1090,10 +897,6 @@ class Llama4VisionAttention(nn.Module):
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # if hasattr(self, "num_key_value_groups"):
        #     key_states = repeat_kv(key_states, self.num_key_value_groups)
        #     value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False, dropout_p=0
@@ -1217,9 +1020,6 @@ class Llama4UnfoldConvolution(nn.Module):
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
        # self.linear = TensorParallelColumnLinear.load(
        #     config=config, prefix=f"{prefix}.linear", weights=weights, bias=False
        # )
        self.linear = FastLinear.load(
            config=config, prefix=f"{prefix}.linear", weights=weights, bias=False
        )
@@ -1239,37 +1039,26 @@ class Llama4VisionRotaryEmbedding(nn.Module):
        idx = config.image_size // config.patch_size
        img_idx = torch.arange(idx**2, dtype=torch.int32, device=weights.device).reshape(idx**2, 1)
        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
        torch_save(img_idx, f"trans.vision.img_idx.pt")

        img_idx[-1, -1] = -2  # ID_CLS_TOKEN
        # Calculate x and y coordinates
        frequencies_x = img_idx % idx  # x coordinates
        torch_save(frequencies_x, f"trans.vision.frequencies_x.pt")
        frequencies_y = torch.div(img_idx, idx, rounding_mode='floor')  # y coordinates
        torch_save(frequencies_y, f"trans.vision.frequencies_y.pt")
        # Calculate frequency components
        freq_dim = config.hidden_size // config.num_attention_heads // 2
        rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2, device=weights.device)[: (freq_dim // 2)].float() / freq_dim))
        torch_save(rope_freq, f"trans.vision.rope_freq.pt")

        # Compute frequencies for x and y directions
        freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :])
        torch_save(freqs_x, f"trans.vision.freqs_x.pt")
        freqs_x = freqs_x.repeat_interleave(2, dim=-1)
        torch_save(freqs_x, f"trans.vision.repeat.freqs_x.pt")
        freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :])
        torch_save(freqs_y, f"trans.vision.freqs_y.pt")
        freqs_y = freqs_y.repeat_interleave(2, dim=-1)
        torch_save(freqs_y, f"trans.vision.repeat.freqs_y.pt")

        # Combine frequencies and mask special tokens
        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
        torch_save(freqs, f"trans.vision.freqs.pt")

        # freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
        freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
        # freq_cis = torch.concat([torch.cos(freqs), torch.sin(freqs)], dim=-1)
        self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2

    def forward(self, hidden_states):
@@ -1304,10 +1093,6 @@ class Llama4VisionModel(nn.Module):
            weights.get_tensor(f"{prefix}.positional_embedding_vlm"), requires_grad=False
        )

        log_master(
            logger.debug,
            f"vision positional_embedding_vlm.shape: {self.positional_embedding_vlm.shape}"
        )
        self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights)

        # layer norms
@@ -1507,4 +1292,4 @@ class Llama4ForConditionalGeneration(nn.Module):
            attention_mask
        )

        return logits, speculative_logits
        return logits, speculative_logits