diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index de8b8955..2f9b5aeb 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -1,5 +1,6 @@
 # coding=utf-8
-# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
+#
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,16 +25,12 @@ import torch.nn.functional as F
 from transformers import Llama4TextConfig
 from transformers.cache_utils import Cache
 from transformers.activations import ACT2FN
-from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from transformers.modeling_outputs import (
     BaseModelOutput,
 )
 
-import habana_frameworks.torch as htorch
-from transformers.processing_utils import Unpack
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from text_generation_server.layers.rotary import PositionRotaryEmbedding
 
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
@@ -41,58 +38,19 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     SpeculativeHead,
     FastLinear,
-    TensorParallelAdapterRowLinear
 )
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.attention import (
     KVCache,
-    get_kv_scales,
     paged_attention,
-    attention,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     load_attention,
     FlashLlamaAttention,
-    FlashLlamaForCausalLM,
     LlamaMLP,
 )
-from habana_frameworks.torch.hpex.kernels import FusedSDPA
-from vllm_hpu_extension.utils import ModuleFusedSDPA
-from text_generation_server.utils.import_utils import (
-    synchronize,
-    get_free_memory,
-)
-
-from loguru import logger
-from text_generation_server.utils.log import log_master
-from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
-
-_CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B"
-_CONFIG_FOR_DOC = "Llama4Config"
-def print_0(*args, **kwargs):
-    """
-    Only print on rank 0 in distributed training.
-    Works like built-in print() function but only executes on rank 0.
- """ - # 检查是否处于分布式环境 - if torch.distributed.is_initialized(): - # 获取当前rank - if torch.distributed.get_rank() == 0: - print(*args, **kwargs) - else: - # 如果不是分布式环境,正常打印 - print(*args, **kwargs, flush=True) - -def torch_save(tensor, name): - pass - # Only save on the main process (rank 0) when using distributed training - # if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - # torch.save(tensor, name) -def torch_load(name): - rank = torch.distributed.get_rank() - return torch.load(f"{name}.{rank}") def reshape_for_broadcast(freqs: torch.Tensor, target): @@ -218,8 +176,6 @@ class Llama4TextMLP(nn.Module): ) - - class Llama4TextL2Norm(torch.nn.Module): def __init__(self, eps: float = 1e-6): super().__init__() @@ -235,26 +191,6 @@ class Llama4TextL2Norm(torch.nn.Module): return f"eps={self.eps}" -class Llama4TextRMSNorm(nn.Module): - def __init__(self, prefix, config, weights): - """ - Llama4RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.eps = config.rms_norm_eps - self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"), requires_grad=False) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.eps}" - - class Llama4TextMoe(nn.Module): def __init__( self, @@ -351,78 +287,14 @@ class Llama4TextAttention(FlashLlamaAttention): self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_attention_heads = config.num_attention_heads self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.num_key_value_heads = config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attn_scale = config.attn_scale self.floor_scale = config.floor_scale self.attn_temperature_tuning = config.attn_temperature_tuning self.attention_dropout = config.attention_dropout - self.is_causal = True self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers - # `config.attention_multiplier` is used in Granite - self.softmax_scale = getattr( - config, "attention_multiplier", self.head_dim**-0.5 - ) - - if self.num_attention_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_attention_heads` must be divisible by `num_shards` (got `num_attention_heads`: {self.num_attention_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - if config.num_key_value_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.num_heads = self.num_attention_heads // weights.process_group.size() - self.num_key_value_heads = ( - config.num_key_value_heads // weights.process_group.size() - ) - - #self.query_key_value = load_attention(config, prefix, weights, layer_idx) - - self.kv_scales = get_kv_scales(weights, f"{prefix}") - self.q_proj = TensorParallelColumnLinear.load( - config=config, - prefix=f"{prefix}.q_proj", - weights=weights, - bias=getattr(config, "attention_bias", False), - ) - self.k_proj = TensorParallelColumnLinear.load( - config=config, - prefix=f"{prefix}.k_proj", - weights=weights, - bias=getattr(config, "attention_bias", False), - ) - self.v_proj = TensorParallelColumnLinear.load( - 
-            config=config,
-            prefix=f"{prefix}.v_proj",
-            weights=weights,
-            bias=getattr(config, "attention_bias", False),
-        )
-
-        self.o_proj = TensorParallelRowLinear.load(
-            config=config,
-            prefix=f"{prefix}.o_proj",
-            weights=weights,
-            bias=getattr(config, "attention_bias", False),
-        )
-
-        # self.o_proj = TensorParallelAdapterRowLinear.load(
-        #     o_proj,
-        #     layer_idx,
-        #     "o_proj",
-        #     process_group=weights.process_group,
-        # )
-
-        self.num_groups = self.num_heads // self.num_key_value_heads
-        self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
-
         if self.config.use_qk_norm and self.use_rope:
             self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps)
@@ -442,29 +314,19 @@ class Llama4TextAttention(FlashLlamaAttention):
         bs = seqlen.input_lengths.shape[0]
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
-        #qkv = self.query_key_value(hidden_states, adapter_data)
-        # query_states, kv_states = qkv.split(
-        #     [
-        #         self.head_size * self.num_heads,
-        #         2 * self.head_size * self.num_key_value_heads,
-        #     ],
-        #     dim=-1,
-        # )
-        # query_states, key_states, value_states = qkv.split(
-        #     [
-        #         self.head_size * self.num_heads,
-        #         self.head_size * self.num_key_value_heads,
-        #         self.head_size * self.num_key_value_heads,
-        #     ],
-        #     dim=-1,
-        # )
-        query_states = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
-        key_states = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
-        value_states = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
+        qkv = self.query_key_value(hidden_states, adapter_data)
+        query_states, key_states, value_states = qkv.split(
+            [
+                self.head_dim * self.num_heads,
+                self.head_dim * self.num_key_value_heads,
+                self.head_dim * self.num_key_value_heads,
+            ],
+            dim=-1,
+        )
-        # query_states = query_states.view(-1, self.num_heads, self.head_size)
-        # key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
-        # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
+        query_states = query_states.view(hidden_shape)
+        key_states = key_states.view(hidden_shape)
+        value_states = value_states.view(hidden_shape)
 
         if self.use_rope: # the 16E model skips rope for long context on certain layers
             query_states, key_states = apply_rotary_emb(
@@ -475,20 +337,13 @@ class Llama4TextAttention(FlashLlamaAttention):
             query_states = self.qk_norm(query_states)
             key_states = self.qk_norm(key_states)
-
-
-        # query_states = query_states.view(-1, self.num_heads, self.head_size)
-        # key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
-        # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
-
-        # query_states = query_states.transpose(1, 2)
-        # key_states = key_states.transpose(1, 2)
 
         kv_cache.store(
             key=key_states,
             value=value_states,
             slots=slots,
             kv_scales=self.kv_scales,
         )
+
         # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers
         if self.attn_temperature_tuning and not self.use_rope:
             attn_scales = (
@@ -500,16 +355,6 @@ class Llama4TextAttention(FlashLlamaAttention):
         # Prefill
         if cu_seqlen_prefill is not None:
             # sdpa
-            # attn_output = attention(
-            #     query=query_states,
-            #     key=key_states,
-            #     value=value_states,
-            #     kv_scales=self.kv_scales,
-            #     kv_cache=kv_cache,
-            #     seqlen=seqlen,
-            #     softmax_scale=self.softmax_scale,
-            #     causal=True
-            # )
             query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(1, 2)
             key = key_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
             value = value_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -549,7 +394,7 @@ class Llama4TextAttention(FlashLlamaAttention):
         )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output, adapter_data)
 
         return attn_output
@@ -565,18 +410,16 @@ class Llama4TextDecoderLayer(nn.Module):
         else:
             self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights)
 
-        self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
-        self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
-        # self.input_layernorm = FastRMSNorm.load(
-        #     prefix=f"{prefix}.input_layernorm",
-        #     weights=weights,
-        #     eps=config.rms_norm_eps,
-        # )
-        # self.post_attention_layernorm = FastRMSNorm.load(
-        #     prefix=f"{prefix}.post_attention_layernorm",
-        #     weights=weights,
-        #     eps=config.rms_norm_eps,
-        # )
+        self.input_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.input_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
+        self.post_attention_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.post_attention_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
 
     def forward(
         self,
@@ -593,7 +436,7 @@ class Llama4TextDecoderLayer(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states, _ = self.input_layernorm(hidden_states)
 
         # use local attention mask for ROPE layers
         if self.use_chunked_attention and chunk_causal_mask is not None:
@@ -617,7 +460,7 @@ class Llama4TextDecoderLayer(nn.Module):
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, _ = self.post_attention_layernorm(hidden_states)
         hidden_states = self.feed_forward(hidden_states, adapter_data)
         hidden_states = residual + hidden_states.view(residual.shape)
         return hidden_states
@@ -945,13 +788,9 @@ class Llama4VisionMLP2(torch.nn.Module):
     def forward(self, hidden_states):
         hidden_states = self.fc1(hidden_states)
-        torch_save(hidden_states, f"trans.mlp.fc1.hidden_states.pt")
         hidden_states = self.activation_fn(hidden_states)
-        torch_save(hidden_states, f"trans.mlp.activation_fn.hidden_states.pt")
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
-        torch_save(hidden_states, f"trans.mlp.dropout.hidden_states.pt")
         hidden_states = self.fc2(hidden_states)
-        torch_save(hidden_states, f"trans.mlp.fc2.hidden_states.pt")
         return self.activation_fn(hidden_states) # TODO: check if we need to apply activation again
 
 
 class Llama4MultiModalProjector(nn.Module):
@@ -973,19 +812,14 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
     input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
     batch_size, height, width, channels = input_tensor.size()
-    torch_save(input_tensor, f"pixel_shuffle.input_tensor.pt")
 
     reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))
-    torch_save(reshaped_tensor, f"pixel_shuffle.reshaped_tensor.pt")
     reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
f"pixel_shuffle.permute.reshaped_tensor.pt") reshaped_tensor = reshaped_tensor.view( batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2)) ) - torch_save(reshaped_tensor, f"pixel_shuffle.final_viewed_tensor.pt") reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1]) - torch_save(output_tensor, f"pixel_shuffle.output_tensor.pt") return output_tensor @@ -1019,30 +853,6 @@ class Llama4VisionAttention(nn.Module): self.head_dim = config.hidden_size // config.num_attention_heads self.num_key_value_groups = 1 self.attention_dropout = config.attention_dropout - self.q_proj = TensorParallelColumnLinear.load( - config=config, - prefix=f"{prefix}.q_proj", - weights=weights, - bias=True, - ) - # self.k_proj = TensorParallelColumnLinear.load( - # config=config, - # prefix=f"{prefix}.k_proj", - # weights=weights, - # bias=True, - # ) - # self.v_proj = TensorParallelColumnLinear.load( - # config=config, - # prefix=f"{prefix}.v_proj", - # weights=weights, - # bias=True, - # ) - # self.o_proj = TensorParallelRowLinear.load( - # config=config, - # prefix=f"{prefix}.o_proj", - # weights=weights, - # bias=True, - # ) self.qkv_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], @@ -1066,9 +876,6 @@ class Llama4VisionAttention(nn.Module): input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - # query_states = self.q_proj(hidden_states).view(hidden_shape) - # key_states = self.k_proj(hidden_states).view(hidden_shape) - # value_states = self.v_proj(hidden_states).view(hidden_shape) qkv = self.qkv_proj(hidden_states) @@ -1090,10 +897,6 @@ class Llama4VisionAttention(nn.Module): key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - # if hasattr(self, "num_key_value_groups"): - # key_states = repeat_kv(key_states, self.num_key_value_groups) - # value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_output = F.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False, dropout_p=0 @@ -1217,9 +1020,6 @@ class Llama4UnfoldConvolution(nn.Module): if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) - # self.linear = TensorParallelColumnLinear.load( - # config=config, prefix=f"{prefix}.linear", weights=weights, bias=False - # ) self.linear = FastLinear.load( config=config, prefix=f"{prefix}.linear", weights=weights, bias=False ) @@ -1239,37 +1039,26 @@ class Llama4VisionRotaryEmbedding(nn.Module): idx = config.image_size // config.patch_size img_idx = torch.arange(idx**2, dtype=torch.int32, device=weights.device).reshape(idx**2, 1) img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) - torch_save(img_idx, f"trans.vision.img_idx.pt") img_idx[-1, -1] = -2 # ID_CLS_TOKEN # Calculate x and y coordinates frequencies_x = img_idx % idx # x coordinates - torch_save(frequencies_x, f"trans.vision.frequencies_x.pt") frequencies_y = torch.div(img_idx, idx, rounding_mode='floor') # y coordinates - torch_save(frequencies_y, f"trans.vision.frequencies_y.pt") # Calculate frequency components freq_dim = config.hidden_size // config.num_attention_heads // 2 rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2, device=weights.device)[: (freq_dim // 2)].float() / freq_dim)) - 
-        torch_save(rope_freq, f"trans.vision.rope_freq.pt")
 
         # Compute frequencies for x and y directions
         freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :])
-        torch_save(freqs_x, f"trans.vision.freqs_x.pt")
         freqs_x = freqs_x.repeat_interleave(2, dim=-1)
-        torch_save(freqs_x, f"trans.vision.repeat.freqs_x.pt")
         freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :])
-        torch_save(freqs_y, f"trans.vision.freqs_y.pt")
         freqs_y = freqs_y.repeat_interleave(2, dim=-1)
-        torch_save(freqs_y, f"trans.vision.repeat.freqs_y.pt")
 
         # Combine frequencies and mask special tokens
         freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
         freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
-        torch_save(freqs, f"trans.vision.freqs.pt")
-        #freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
         freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
-        #freq_cis = torch.concat([torch.cos(freqs), torch.sin(freqs)], dim=-1)
         self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2
 
     def forward(self, hidden_states):
@@ -1304,10 +1093,6 @@ class Llama4VisionModel(nn.Module):
             weights.get_tensor(f"{prefix}.positional_embedding_vlm"), requires_grad=False
         )
-        log_master(
-            logger.debug,
-            f"vision positional_embedding_vlm.shape: {self.positional_embedding_vlm.shape}"
-        )
 
         self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights)
 
         # layer norms
@@ -1507,4 +1292,4 @@ class Llama4ForConditionalGeneration(nn.Module):
             attention_mask
         )
 
-        return logits, speculative_logits
+        return logits, speculative_logits
\ No newline at end of file
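Note on the main text-attention change above: Llama4TextAttention drops its per-projection q_proj/k_proj/v_proj modules and instead calls the fused query_key_value projection inherited from FlashLlamaAttention, splitting the packed output by head counts before reshaping to (tokens, heads, head_dim). The standalone sketch below only illustrates that split; the head counts, head_dim, and token count are made-up example values, not taken from the Llama 4 config.

import torch

num_heads = 8        # attention heads on this tensor-parallel shard (example value)
num_kv_heads = 2     # key/value heads on this shard (example value)
head_dim = 64        # per-head dimension (example value)
tokens = 5

# Packed output of a fused QKV projection: [Q | K | V] along the last dimension.
qkv = torch.randn(tokens, (num_heads + 2 * num_kv_heads) * head_dim)

query, key, value = qkv.split(
    [
        head_dim * num_heads,     # query slice
        head_dim * num_kv_heads,  # key slice
        head_dim * num_kv_heads,  # value slice
    ],
    dim=-1,
)

# Reshape each slice to (tokens, heads, head_dim), mirroring the
# view(hidden_shape) calls added by the patch.
query = query.view(tokens, num_heads, head_dim)
key = key.view(tokens, num_kv_heads, head_dim)
value = value.view(tokens, num_kv_heads, head_dim)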
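The decoder layer also switches from the removed Llama4TextRMSNorm back to FastRMSNorm.load(...), which is why the forward pass now unpacks a tuple (hidden_states, _ = self.input_layernorm(hidden_states)). TGI's fast RMS norm modules conventionally return the normalized activations together with a residual tensor; the minimal module below only mimics that tuple-returning interface for illustration and is not the library implementation.

import torch
from torch import nn

class TupleRMSNorm(nn.Module):
    """RMS norm returning (normalized, residual), mimicking TGI's FastRMSNorm call convention."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states, residual=None):
        # Optionally fuse the residual add into the norm, as the fast kernels do.
        if residual is not None:
            hidden_states = hidden_states + residual
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(variance + self.eps)
        # Second element is the pre-normalization tensor, reused as the next residual.
        return normed * self.weight, hidden_states

norm = TupleRMSNorm(16)
hidden_states, _ = norm(torch.randn(4, 16))  # same unpacking pattern as in the patch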