Fix the errors reported by pre-commit

Signed-off-by: yuanwu <yuan.wu@intel.com>
Author: yuanwu
Date: 2025-05-11 18:31:35 +00:00
Parent: d98116db6e
Commit: 7533b993d5
4 changed files with 287 additions and 132 deletions

View File

@@ -850,14 +850,17 @@ def get_model(
     )
     from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
     adapt_transformers_to_gaudi()
     if SDP_ON_BF16 == 1:
         torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
     if model_type == "gpt_bigcode":
         from text_generation_server.models.starcoder import StarCoder
         return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
     if model_type == "bloom":
         from text_generation_server.models.bloom import BLOOM
         return BLOOM(
             model_id=model_id,
             revision=revision,

View File

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, List, Optional, Tuple, Union, Type
+from typing import List, Optional, Tuple, Union
 import torch
 import math
@@ -47,7 +47,6 @@ from text_generation_server.layers.attention import (
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
-    load_attention,
     FlashLlamaAttention,
     LlamaMLP,
 )
@@ -58,6 +57,7 @@ def reshape_for_broadcast(freqs: torch.Tensor, target):
     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target)]
     return freqs.view(*shape)
 def apply_rotary_emb(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -65,12 +65,12 @@ def apply_rotary_emb(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     query_shape = query.shape
     key_shape = key.shape
-    cos_emb,sin_emb = freqs_ci.split(1, dim=-1)
+    cos_emb, sin_emb = freqs_ci.split(1, dim=-1)
     if len(query.shape) == 3:
         query = query.unsqueeze(0)
         key = key.unsqueeze(0)
     query_reshaped = query.float().reshape(*query.shape[:-1], -1, 2)
     key_reshaped = key.float().reshape(*key.shape[:-1], -1, 2)
     q_shape = query_reshaped.shape[:-1]
@@ -78,12 +78,12 @@ def apply_rotary_emb(
     sin_emb = reshape_for_broadcast(sin_emb, q_shape)
     x_q, y_q = query_reshaped.unbind(-1)
     x_k, y_k = key_reshaped.unbind(-1)
     x_q_rot = x_q * cos_emb - y_q * sin_emb
     y_q_rot = x_q * sin_emb + y_q * cos_emb
     x_k_rot = x_k * cos_emb - y_k * sin_emb
     y_k_rot = x_k * sin_emb + y_k * cos_emb
     query_out = torch.stack([x_q_rot, y_q_rot], dim=-1).flatten(-2)
     key_out = torch.stack([x_k_rot, y_k_rot], dim=-1).flatten(-2)
     query_out = query_out.view(*query_shape)
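
A standalone sketch of the rotation the *_rot lines above apply: each (x, y) feature pair is rotated by its per-position angle, which preserves the pair's norm. The tensor sizes below are illustrative only and are not taken from the model.

import torch

theta = torch.rand(4)  # one rotation angle per (x, y) pair
x, y = torch.randn(4), torch.randn(4)
x_rot = x * torch.cos(theta) - y * torch.sin(theta)
y_rot = x * torch.sin(theta) + y * torch.cos(theta)
# A rotation never changes the length of the pair it rotates.
assert torch.allclose(x_rot**2 + y_rot**2, x**2 + y**2, atol=1e-5)
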
@@ -99,7 +99,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
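
A self-contained sketch of what repeat_kv in the hunk above computes: each KV head is duplicated n_rep times so grouped-query K/V line up with the query heads before scaled-dot-product attention. The head counts below are illustrative, not the model's.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 16, 64)   # 4 KV heads
out = repeat_kv(kv, n_rep=8)     # expanded to match 32 query heads
assert out.shape == (2, 32, 16, 64)
# expand + reshape is equivalent to repeating each head in place
assert torch.equal(out, kv.repeat_interleave(8, dim=1))
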
@@ -108,11 +110,18 @@ class Llama4TextExperts(nn.Module):
         super().__init__()
         self.process_group = weights.process_group
         self.num_experts = config.num_local_experts
-        self.intermediate_size = config.intermediate_size // weights.process_group.size()
+        self.intermediate_size = (
+            config.intermediate_size // weights.process_group.size()
+        )
         self.hidden_size = config.hidden_size
         self.expert_dim = self.intermediate_size
-        self.gate_up_proj = nn.Parameter(weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), requires_grad=False)
-        self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False)
+        self.gate_up_proj = nn.Parameter(
+            weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2),
+            requires_grad=False,
+        )
+        self.down_proj = nn.Parameter(
+            weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False
+        )
         self.act_fn = ACT2FN[config.hidden_act]
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -128,18 +137,18 @@ class Llama4TextExperts(nn.Module):
         Returns:
             torch.Tensor
         """
-        gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2*self.expert_dim)
+        gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2 * self.expert_dim)
         down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1)
         hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
         gate_up = torch.bmm(hidden_states, gate_up_proj)
         gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
         next_states = torch.bmm((up * self.act_fn(gate)), down_proj)
         next_states = next_states.view(-1, self.hidden_size)
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(next_states, group=self.process_group)
         return next_states
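
A shape-only sketch of the batched expert matmuls in the forward above, with toy sizes and a single process (so the all_reduce branch is skipped). The activation is assumed to be SiLU here, whereas the real module takes it from ACT2FN[config.hidden_act].

import torch
import torch.nn.functional as F

num_experts, tokens_per_expert, hidden, expert_dim = 4, 3, 8, 16
hidden_states = torch.randn(num_experts * tokens_per_expert, hidden)
gate_up_proj = torch.randn(num_experts, hidden, 2 * expert_dim)
down_proj = torch.randn(num_experts, expert_dim, hidden)

x = hidden_states.view(num_experts, -1, hidden)        # (E, T, H)
gate_up = torch.bmm(x, gate_up_proj)                   # (E, T, 2*D)
gate, up = gate_up.chunk(2, dim=-1)                    # two (E, T, D) halves
next_states = torch.bmm(up * F.silu(gate), down_proj)  # back to (E, T, H)
assert next_states.view(-1, hidden).shape == (num_experts * tokens_per_expert, hidden)
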
@@ -171,9 +180,7 @@ class Llama4TextMLP(nn.Module):
     def forward(self, x):
         gate_up_states = self.gate_up_proj(x)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(
-            self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
-        )
+        return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])
 class Llama4TextL2Norm(torch.nn.Module):
@@ -202,11 +209,17 @@ class Llama4TextMoe(nn.Module):
         self.top_k = config.num_experts_per_tok
         self.hidden_dim = config.hidden_size
         self.num_experts = config.num_local_experts
-        self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights)
-        self.router = FastLinear.load(config=config, prefix=f"{prefix}.router", weights=weights, bias=False)
-        self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights)
+        self.experts = Llama4TextExperts(
+            config=config, prefix=f"{prefix}.experts", weights=weights
+        )
+        self.router = FastLinear.load(
+            config=config, prefix=f"{prefix}.router", weights=weights, bias=False
+        )
+        self.shared_expert = Llama4TextMLP(
+            config=config, prefix=f"{prefix}.shared_expert", weights=weights
+        )
         self.process_group = weights.process_group
     def forward(self, hidden_states, adapter_data):
         seq_len, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_dim)
@@ -215,16 +228,19 @@ class Llama4TextMoe(nn.Module):
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
-            torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1)
+            torch.full_like(router_logits, float("-inf"))
+            .scatter_(1, router_indices, router_top_value)
+            .transpose(0, 1)
         )
         # We do this to make sure we have -inf for non topK tokens before going through the !
         # Here we are just creating a tensor to index each and every single one of the hidden states. Let's maybe register a buffer for this!
         router_indices = (
-            torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
+            torch.arange(tokens_per_expert, device=hidden_states.device)
+            .view(1, -1)
+            .expand(router_scores.size(0), -1)
         )
         router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
         router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
         routed_in = torch.gather(
             input=hidden_states,
@@ -232,16 +248,17 @@ class Llama4TextMoe(nn.Module):
             index=router_indices,
         ).to(hidden_states.device)
         # we gather inputs corresponding to each expert based on the router indices
         routed_in = routed_in * router_scores.reshape(-1, 1)
         routed_out = self.experts(routed_in)
         out = self.shared_expert(hidden_states)
         # now that we finished expert computation -> we scatter add because we gathered previously
         # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
         # this scales a lot better if you do EP!
-        out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim))
+        out.scatter_add_(
+            dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim)
+        )
         return out
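
A toy illustration of the gather -> experts -> scatter_add_ round trip above: every expert sees every token (the -inf/sigmoid scores zero out the non-top-k contributions), and scatter_add_ folds the expert outputs back per token. Here the "experts" are the identity and the routing scores are omitted, so each token simply accumulates once per expert.

import torch

tokens, experts, hidden = 2, 3, 4
hidden_states = torch.randn(tokens, hidden)
# row t + e * tokens of the index selects token t for expert e
index = torch.arange(tokens).repeat(experts).unsqueeze(1).expand(-1, hidden)
routed_in = torch.gather(hidden_states, dim=0, index=index)  # (experts * tokens, hidden)
routed_out = routed_in                                       # identity "experts"
out = torch.zeros_like(hidden_states)
out.scatter_add_(dim=0, index=index, src=routed_out)
assert torch.allclose(out, hidden_states * experts)          # each token summed once per expert
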
@@ -262,10 +279,15 @@ class Llama4TextRotaryEmbedding(nn.Module):
         self.original_inv_freq = self.inv_freq
     def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        )
         position_ids_expanded = position_ids[:, None, :].float()
-        origin_device = x.device
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        device_type = (
+            x.device.type
+            if isinstance(x.device.type, str) and x.device.type != "mps"
+            else "cpu"
+        )
         inv_freq_expanded = inv_freq_expanded.to(device_type)
         position_ids_expanded = position_ids_expanded.to(device_type)
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
@@ -273,7 +295,7 @@ class Llama4TextRotaryEmbedding(nn.Module):
             cos = torch.cos(freqs) * self.attention_scaling
             sin = torch.sin(freqs) * self.attention_scaling
             cos = cos.reshape(-1, 1, cos.shape[-1])
             sin = sin.reshape(-1, 1, sin.shape[-1])
             freqs_cis = torch.cat([cos, sin], dim=-1) * self.attention_scaling
             freqs_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
         return freqs_cis
@@ -286,15 +308,19 @@ class Llama4TextAttention(FlashLlamaAttention):
         super().__init__(layer_idx, prefix, config, weights)
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
         self.scaling = self.head_dim**-0.5
         self.attn_scale = config.attn_scale
         self.floor_scale = config.floor_scale
         self.attn_temperature_tuning = config.attn_temperature_tuning
         self.attention_dropout = config.attention_dropout
         self.use_rope = int((layer_idx + 1) % 4 != 0)  # rope unused for dense layers
         if self.config.use_qk_norm and self.use_rope:
             self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps)
@@ -323,7 +349,7 @@ class Llama4TextAttention(FlashLlamaAttention):
             ],
             dim=-1,
         )
         query_states = query_states.view(hidden_shape)
         key_states = key_states.view(hidden_shape)
         value_states = value_states.view(hidden_shape)
@@ -347,7 +373,11 @@ class Llama4TextAttention(FlashLlamaAttention):
         # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers
         if self.attn_temperature_tuning and not self.use_rope:
             attn_scales = (
-                torch.log(torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
+                torch.log(
+                    torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0
+                )
+                * self.attn_scale
+                + 1.0
             )
             attn_scales = attn_scales.view(*input_shape, 1, 1)
             query_states = (query_states * attn_scales).to(query_states.dtype)
@@ -355,9 +385,15 @@ class Llama4TextAttention(FlashLlamaAttention):
         # Prefill
         if cu_seqlen_prefill is not None:
             # sdpa
-            query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(1, 2)
-            key = key_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            value = value_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+            query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(
+                1, 2
+            )
+            key = key_states.view(
+                bs, -1, self.num_key_value_heads, self.head_dim
+            ).transpose(1, 2)
+            value = value_states.view(
+                bs, -1, self.num_key_value_heads, self.head_dim
+            ).transpose(1, 2)
             key = repeat_kv(key, self.num_key_value_groups)
             value = repeat_kv(value, self.num_key_value_groups)
@@ -395,14 +431,16 @@ class Llama4TextAttention(FlashLlamaAttention):
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output, adapter_data)
         return attn_output
 class Llama4TextDecoderLayer(nn.Module):
     def __init__(self, prefix, config, weights, layer_idx):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.self_attn = Llama4TextAttention(f"{prefix}.self_attn", config, weights, layer_idx)
+        self.self_attn = Llama4TextAttention(
+            f"{prefix}.self_attn", config, weights, layer_idx
+        )
         self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)  # <=> use rope
         self.is_moe_layer = layer_idx in config.moe_layers
         if self.is_moe_layer:  # the 128E model interleaves dense / sparse
@@ -411,15 +449,15 @@ class Llama4TextDecoderLayer(nn.Module):
             self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights)
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
         self.post_attention_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.post_attention_layernorm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
     def forward(
         self,
@@ -434,7 +472,9 @@ class Llama4TextDecoderLayer(nn.Module):
         chunk_causal_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
         residual = hidden_states
         hidden_states, _ = self.input_layernorm(hidden_states)
@@ -465,6 +505,7 @@ class Llama4TextDecoderLayer(nn.Module):
         hidden_states = residual + hidden_states.view(residual.shape)
         return hidden_states
 class Llama4TextModel(nn.Module):
     def __init__(self, prefix, config, weights):
@@ -473,12 +514,22 @@ class Llama4TextModel(nn.Module):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = TensorParallelEmbedding(prefix=f"{prefix}.embed_tokens", weights=weights)
-        self.layers = nn.ModuleList(
-            [Llama4TextDecoderLayer(prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
-        )
-        #self.norm = Llama4TextRMSNorm(prefix=f"{prefix}.norm", config=config, weights=weights)
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=f"{prefix}.embed_tokens", weights=weights
+        )
+        self.layers = nn.ModuleList(
+            [
+                Llama4TextDecoderLayer(
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                    config=config,
+                    weights=weights,
+                    layer_idx=layer_idx,
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        # self.norm = Llama4TextRMSNorm(prefix=f"{prefix}.norm", config=config, weights=weights)
         self.norm = FastRMSNorm.load(
             prefix=f"{prefix}.norm",
             weights=weights,
@@ -500,7 +551,7 @@ class Llama4TextModel(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
         bs = seqlen.input_lengths.shape[0]
         seq_len = inputs_embeds.shape[0] / bs
@@ -508,11 +559,16 @@ class Llama4TextModel(nn.Module):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
         causal_mask, chunk_causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds.view(bs, int(seq_len), -1), cache_position, None, output_attentions=False, use_cache=False
+            attention_mask,
+            inputs_embeds.view(bs, int(seq_len), -1),
+            cache_position,
+            None,
+            output_attentions=False,
+            use_cache=False,
         )
         freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))
         for i, layer in enumerate(self.layers):
@@ -530,9 +586,8 @@ class Llama4TextModel(nn.Module):
                 hpu_attention_meta=hpu_attention_meta,
             )
         hidden_states, _ = self.norm(hidden_states)
         return hidden_states
     def _update_causal_mask(
@@ -547,7 +602,10 @@ class Llama4TextModel(nn.Module):
     ):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and (attention_mask == 0.0).any():
-                return attention_mask, attention_mask  # flash does not support chunked attn TODO support flash
+                return (
+                    attention_mask,
+                    attention_mask,
+                )  # flash does not support chunked attn TODO support flash
             return None, None
         if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]:
@@ -561,7 +619,11 @@ class Llama4TextModel(nn.Module):
         if past_key_values is not None:
             full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
         else:
-            full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
+            full_cache_length = (
+                attention_mask.shape[-1]
+                if attention_mask is not None
+                else sequence_length
+            )
         cond1 = first_cache_position >= attention_chunk_size
         cond2 = (first_cache_position < attention_chunk_size) & (
@@ -571,7 +633,9 @@ class Llama4TextModel(nn.Module):
             torch.where(
                 cond1,
                 attention_chunk_size + sequence_length - 1,
-                torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
+                torch.where(
+                    cond2, first_cache_position + sequence_length, attention_chunk_size
+                ),
             )
             if use_cache
             else full_cache_length
@@ -586,7 +650,7 @@ class Llama4TextModel(nn.Module):
             dtype=dtype,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
-            device=device
+            device=device,
         )
         if full_cache_length > self.config.attention_chunk_size:
             start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
@@ -598,24 +662,37 @@ class Llama4TextModel(nn.Module):
                 device=device,
             )
-            local_attention_mask = attention_mask[:, start_idx:end_idx]  # offset here as well
+            local_attention_mask = attention_mask[
+                :, start_idx:end_idx
+            ]  # offset here as well
             # It may be smaller than attention_chunk_size -> pad it
             requires_padding = local_attention_mask.shape[-1] < attention_chunk_size
             if requires_padding:
                 local_attention_mask = nn.functional.pad(
-                    local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1])
+                    local_attention_mask,
+                    (0, attention_chunk_size - local_attention_mask.shape[-1]),
                 )
             # Depending on the padding, take the query tokens from the end or the cache_position
             if not requires_padding:
-                chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :]
+                chunked_attention_mask = chunked_attention_mask[
+                    None, None, -sequence_length:, :
+                ]
             else:
-                chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :]
-            chunked_attention_mask = chunked_attention_mask.expand(input_tensor.shape[0], -1, -1, -1)
-            chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :]
+                chunked_attention_mask = chunked_attention_mask[
+                    None, None, cache_position, :
+                ]
+            chunked_attention_mask = chunked_attention_mask.expand(
+                input_tensor.shape[0], -1, -1, -1
+            )
+            chunked_attention_mask = (
+                chunked_attention_mask * local_attention_mask[:, None, None, :]
+            )
             if self.config._attn_implementation == "eager":
                 min_dtype = torch.finfo(dtype).min
-                chunked_attention_mask = torch.where(chunked_attention_mask == 0, min_dtype, 0.0).to(dtype)
+                chunked_attention_mask = torch.where(
+                    chunked_attention_mask == 0, min_dtype, 0.0
+                ).to(dtype)
         if (
             self.config._attn_implementation == "sdpa"
@@ -628,10 +705,15 @@ class Llama4TextModel(nn.Module):
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
             # Details: https://github.com/pytorch/pytorch/issues/110213
             min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, min_dtype
+            )
         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
+        if (
+            self.config._attn_implementation == "sdpa"
+            and chunked_attention_mask is not None
+        ):
             chunked_attention_mask = chunked_attention_mask.bool()
             causal_mask = causal_mask.bool()
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
@@ -661,7 +743,8 @@ class Llama4TextModel(nn.Module):
         """
         arange_vector = torch.arange(start, end, device=device)
         block_pos = torch.abs(
-            arange_vector.unsqueeze(0) // attention_chunk_size - arange_vector.unsqueeze(1) // attention_chunk_size
+            arange_vector.unsqueeze(0) // attention_chunk_size
+            - arange_vector.unsqueeze(1) // attention_chunk_size
        )
         token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1)
         mask = (block_pos == 0) & (token_pos <= 0)
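
A small worked example of the block_pos/token_pos logic above: tokens attend causally but only inside their own chunk, so the first token of a new chunk sees nothing from the previous one. The sizes below are illustrative.

import torch

start, end, attention_chunk_size = 0, 6, 3
arange_vector = torch.arange(start, end)
block_pos = torch.abs(
    arange_vector.unsqueeze(0) // attention_chunk_size
    - arange_vector.unsqueeze(1) // attention_chunk_size
)
token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1)
mask = (block_pos == 0) & (token_pos <= 0)
# Token 3 starts the second chunk and can only attend to itself.
assert mask[3].tolist() == [False, False, False, True, False, False]
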
@@ -706,25 +789,33 @@ class Llama4TextModel(nn.Module):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length),
+                fill_value=min_dtype,
+                dtype=dtype,
+                device=device,
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.to(device).reshape(-1, 1)
+            causal_mask *= torch.arange(
+                target_length, device=device
+            ) > cache_position.to(device).reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                causal_mask = (
+                    causal_mask.clone()
+                )  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(device)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
+                    :, None, None, :
+                ].to(device)
                 padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
+                causal_mask[:, :, :, :mask_length] = causal_mask[
+                    :, :, :, :mask_length
+                ].masked_fill(padding_mask, min_dtype)
         return causal_mask
 class Llama4ForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
@@ -751,7 +842,7 @@ class Llama4ForCausalLM(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         hidden_states = self.model(
             inputs_embeds,
@@ -791,7 +882,10 @@ class Llama4VisionMLP2(torch.nn.Module):
         hidden_states = self.activation_fn(hidden_states)
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = self.fc2(hidden_states)
-        return self.activation_fn(hidden_states)  # TODO: check if we need to apply activation again
+        return self.activation_fn(
+            hidden_states
+        )  # TODO: check if we need to apply activation again
 class Llama4MultiModalProjector(nn.Module):
     def __init__(self, prefix, config, weights):
@@ -812,10 +906,15 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
     input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
     batch_size, height, width, channels = input_tensor.size()
-    reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))
+    reshaped_tensor = input_tensor.view(
+        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
+    )
     reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
     reshaped_tensor = reshaped_tensor.view(
-        batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2))
+        batch_size,
+        int(height * shuffle_ratio),
+        int(width * shuffle_ratio),
+        int(channels / (shuffle_ratio**2)),
     )
     reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
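
A rough shape walk-through of pixel_shuffle above, assuming the lines not shown in this hunk derive patch_size from sqrt(num_patches) and flatten back to (batch, -1, channels / ratio**2) at the end. With a 0.5 ratio, 16 patches of 8 channels become 4 patches of 32 channels, trading spatial resolution for channel depth before the projector.

import math
import torch

def toy_pixel_shuffle(x: torch.Tensor, ratio: float) -> torch.Tensor:
    # toy re-implementation for shape checking only; not the module under review
    b, n, c = x.shape
    p = int(math.sqrt(n))
    x = x.view(b, p, p, c)
    x = x.view(b, p, int(p * ratio), int(c / ratio)).permute(0, 2, 1, 3).contiguous()
    x = x.view(b, int(p * ratio), int(p * ratio), int(c / ratio**2))
    x = x.permute(0, 2, 1, 3).contiguous()
    return x.reshape(b, -1, int(c / ratio**2))

out = toy_pixel_shuffle(torch.randn(1, 16, 8), ratio=0.5)
assert out.shape == (1, 4, 32)
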
@@ -827,14 +926,19 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
-        self.inner_dim = int(config.projector_input_dim // (self.pixel_shuffle_ratio**2))
+        self.inner_dim = int(
+            config.projector_input_dim // (self.pixel_shuffle_ratio**2)
+        )
         self.output_dim = config.projector_output_dim
-        self.mlp = Llama4VisionMLP2(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.mlp = Llama4VisionMLP2(
+            prefix=f"{prefix}.mlp", config=config, weights=weights
+        )
     def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
         encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
         return self.mlp(encoded_patches)
 # TODO there is a different RoPE for vision encoder, defined as below
 def vision_reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):
     ndim = query.ndim
@@ -877,8 +981,7 @@ class Llama4VisionAttention(nn.Module):
         hidden_shape = (*input_shape, -1, self.head_dim)
         qkv = self.qkv_proj(hidden_states)
         query_states, key_states, value_states = qkv.split(
             [
                 self.head_dim * self.num_heads,
@@ -890,18 +993,24 @@ class Llama4VisionAttention(nn.Module):
         query_states = query_states.view(hidden_shape)
         key_states = key_states.view(hidden_shape)
         value_states = value_states.view(hidden_shape)
-        query_states, key_states = apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci)
+        query_states, key_states = apply_rotary_emb(
+            query_states, key_states, freqs_ci=freqs_ci
+        )
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
         attn_output = F.scaled_dot_product_attention(
-            query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False, dropout_p=0
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            is_causal=False,
+            dropout_p=0,
         )
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
@@ -920,7 +1029,6 @@ class Llama4VisionMLP(nn.Module):
             prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
         )
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.fc1(hidden_states)
         hidden_states = self.activation_fn(hidden_states)
@@ -987,17 +1095,21 @@ class Llama4VisionEncoder(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
-        self.layers = nn.ModuleList([
-            Llama4VisionEncoderLayer(prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights)
-            for layer_id in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                Llama4VisionEncoderLayer(
+                    prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights
+                )
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
         self.gradient_checkpointing = False
         self.config = config
     def forward(
         self,
         hidden_states: torch.Tensor,
         freqs_ci: torch.Tensor,  # TODO move this to an attribute instead of keeping it around
         attention_mask: Optional[torch.Tensor] = None,
     ) -> Union[Tuple, BaseModelOutput]:
@@ -1024,7 +1136,6 @@ class Llama4UnfoldConvolution(nn.Module):
             config=config, prefix=f"{prefix}.linear", weights=weights, bias=False
         )
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
         hidden_states = hidden_states.permute(0, 2, 1)
@@ -1037,27 +1148,37 @@ class Llama4VisionRotaryEmbedding(nn.Module):
         super().__init__()
         # Calculate image grid indices
         idx = config.image_size // config.patch_size
-        img_idx = torch.arange(idx**2, dtype=torch.int32, device=weights.device).reshape(idx**2, 1)
+        img_idx = torch.arange(
+            idx**2, dtype=torch.int32, device=weights.device
+        ).reshape(idx**2, 1)
         img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
         img_idx[-1, -1] = -2  # ID_CLS_TOKEN
         # Calculate x and y coordinates
         frequencies_x = img_idx % idx  # x coordinates
-        frequencies_y = torch.div(img_idx, idx, rounding_mode='floor')  # y coordinates
+        frequencies_y = torch.div(img_idx, idx, rounding_mode="floor")  # y coordinates
         # Calculate frequency components
         freq_dim = config.hidden_size // config.num_attention_heads // 2
-        rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2, device=weights.device)[: (freq_dim // 2)].float() / freq_dim))
+        rope_freq = 1.0 / (
+            config.rope_theta
+            ** (
+                torch.arange(0, freq_dim, 2, device=weights.device)[
+                    : (freq_dim // 2)
+                ].float()
+                / freq_dim
+            )
+        )
         # Compute frequencies for x and y directions
-        freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :])
+        freqs_x = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
         freqs_x = freqs_x.repeat_interleave(2, dim=-1)
-        freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :])
+        freqs_y = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
         freqs_y = freqs_y.repeat_interleave(2, dim=-1)
         # Combine frequencies and mask special tokens
         freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
         freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
         freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
         self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2
@@ -1090,9 +1211,10 @@ class Llama4VisionModel(nn.Module):
         )
         self.positional_embedding_vlm = nn.Parameter(
-            weights.get_tensor(f"{prefix}.positional_embedding_vlm"), requires_grad=False
+            weights.get_tensor(f"{prefix}.positional_embedding_vlm"),
+            requires_grad=False,
         )
         self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights)
         # layer norms
@@ -1126,22 +1248,31 @@ class Llama4VisionModel(nn.Module):
         # Add cls token
         hidden_state = hidden_state.reshape(
-            batch_size_times_num_tiles * num_concurrent_media * num_chunks, num_patches, hidden_dim
+            batch_size_times_num_tiles * num_concurrent_media * num_chunks,
+            num_patches,
+            hidden_dim,
         )
-        class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1])
+        class_embedding = self.class_embedding.expand(
+            hidden_state.shape[0], 1, hidden_state.shape[-1]
+        )
         hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
         num_patches += 1
         # Position embeddings
         hidden_state = hidden_state.reshape(
-            batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim
+            batch_size_times_num_tiles * num_concurrent_media,
+            num_chunks,
+            num_patches,
+            hidden_dim,
        )
-        positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device)
+        positional_embedding = self.positional_embedding_vlm.to(
+            dtype=hidden_state.dtype, device=hidden_state.device
+        )
         hidden_state = hidden_state + positional_embedding
         hidden_state = self.layernorm_pre(hidden_state)
         hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim)
         freqs_ci = self.rotary_embedding(pixel_values)
         hidden_state = self.model(
             hidden_state,
             attention_mask=None,
@@ -1156,6 +1287,7 @@ class Llama4VisionModel(nn.Module):
         hidden_state = self.vision_adapter(hidden_state)
         return hidden_state
 class Llama4ForConditionalGeneration(nn.Module):
     def __init__(self, prefix: str, config, weights):
@@ -1170,16 +1302,18 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.vision_model = Llama4VisionModel(
             prefix="vision_model", config=config.vision_config, weights=weights
         )
         self.multi_modal_projector = Llama4MultiModalProjector(
             prefix="multi_modal_projector", config=config, weights=weights
         )
         self.text_model = Llama4ForCausalLM(
             prefix="language_model", config=config.text_config, weights=weights
         )
         self.vocab_size = config.text_config.vocab_size
-        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        self.pad_token_id = (
+            self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        )
         self.config = config
         self.dtype = weights.dtype
         self.device = weights.device
@@ -1208,7 +1342,9 @@ class Llama4ForConditionalGeneration(nn.Module):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         if vision_feature_select_strategy not in ["default", "full"]:
-            raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
+            raise ValueError(
+                f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
+            )
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         hidden_state = self.vision_model(pixel_values)
         return hidden_state
@@ -1232,7 +1368,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
         **lm_kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         def _get_padding_mask(input_ids, pad_token_id=0):
             return (input_ids != pad_token_id).long()
@@ -1275,11 +1411,15 @@ class Llama4ForConditionalGeneration(nn.Module):
                     f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
                 )
-            expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
-            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat)
+            expanded_mask = final_mask_1d.unsqueeze(-1).expand(
+                -1, inputs_embeds.size(-1)
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(
+                expanded_mask, projected_vision_flat
+            )
             inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
-        logits, speculative_logits= self.text_model(
+        logits, speculative_logits = self.text_model(
             inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
@@ -1289,7 +1429,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             hpu_attention_meta,
             adapter_data,
             lm_head_indices,
-            attention_mask
+            attention_mask,
         )
         return logits, speculative_logits

View File

@@ -1,18 +1,31 @@
 from packaging.version import Version
 from packaging import version
 import subprocess
 def get_driver_version():
     """
     Returns the driver version.
     """
     # Enable console printing for `hl-smi` check
     output = subprocess.run(
-        "hl-smi", shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env={"ENABLE_CONSOLE": "true"}
+        "hl-smi",
+        shell=True,
+        text=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        env={"ENABLE_CONSOLE": "true"},
     )
     if output.returncode == 0 and output.stdout:
-        return version.parse(output.stdout.split("\n")[2].replace(" ", "").split(":")[1][:-1].split("-")[0])
+        return version.parse(
+            output.stdout.split("\n")[2]
+            .replace(" ", "")
+            .split(":")[1][:-1]
+            .split("-")[0]
+        )
     return None
 MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")
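
A toy walk-through of the parsing chain in get_driver_version, applied to a made-up hl-smi line; the real hl-smi output format is not shown in this diff, so the sample string below is an assumption for illustration only.

from packaging import version
from packaging.version import Version

MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")
sample_line = "| Driver Version: 1.19.0-abc123 |"  # hypothetical hl-smi line
parsed = version.parse(sample_line.replace(" ", "").split(":")[1][:-1].split("-")[0])
assert parsed == Version("1.19.0")
assert parsed >= MIN_TGI_GAUDI_SYNAPSE_VERSION
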

View File

@@ -315,7 +315,6 @@ class Weights:
             tensors_slices += range(block_offset + start, block_offset + stop)
             block_offset += block_size
         if dim == 0:
             tensor = slice_[tensors_slices, ...]
         elif dim == 1 or dim == -2: