Fix pre-commit errors

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-05-11 18:31:35 +00:00
parent d98116db6e
commit 7533b993d5
4 changed files with 287 additions and 132 deletions

@@ -850,14 +850,17 @@ def get_model(
)
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()
if SDP_ON_BF16 == 1:
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
if model_type == "gpt_bigcode":
from text_generation_server.models.starcoder import StarCoder
return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
if model_type == "bloom":
from text_generation_server.models.bloom import BLOOM
return BLOOM(
model_id=model_id,
revision=revision,

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional, Tuple, Union, Type
from typing import List, Optional, Tuple, Union
import torch
import math
@@ -47,7 +47,6 @@ from text_generation_server.layers.attention import (
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
load_attention,
FlashLlamaAttention,
LlamaMLP,
)
@@ -58,6 +57,7 @@ def reshape_for_broadcast(freqs: torch.Tensor, target):
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target)]
return freqs.view(*shape)
def apply_rotary_emb(
query: torch.Tensor,
key: torch.Tensor,
@@ -65,7 +65,7 @@ def apply_rotary_emb(
) -> Tuple[torch.Tensor, torch.Tensor]:
query_shape = query.shape
key_shape = key.shape
cos_emb,sin_emb = freqs_ci.split(1, dim=-1)
cos_emb, sin_emb = freqs_ci.split(1, dim=-1)
if len(query.shape) == 3:
query = query.unsqueeze(0)
@@ -99,7 +99,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
@@ -108,11 +110,18 @@ class Llama4TextExperts(nn.Module):
super().__init__()
self.process_group = weights.process_group
self.num_experts = config.num_local_experts
self.intermediate_size = config.intermediate_size // weights.process_group.size()
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
self.gate_up_proj = nn.Parameter(weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), requires_grad=False)
self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False)
self.gate_up_proj = nn.Parameter(
weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2),
requires_grad=False,
)
self.down_proj = nn.Parameter(
weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -128,7 +137,7 @@ class Llama4TextExperts(nn.Module):
Returns:
torch.Tensor
"""
gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2*self.expert_dim)
gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2 * self.expert_dim)
down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1)
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
gate_up = torch.bmm(hidden_states, gate_up_proj)
@@ -171,9 +180,7 @@ class Llama4TextMLP(nn.Module):
def forward(self, x):
gate_up_states = self.gate_up_proj(x)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
)
return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])
class Llama4TextL2Norm(torch.nn.Module):
@@ -202,9 +209,15 @@ class Llama4TextMoe(nn.Module):
self.top_k = config.num_experts_per_tok
self.hidden_dim = config.hidden_size
self.num_experts = config.num_local_experts
self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights)
self.router = FastLinear.load(config=config, prefix=f"{prefix}.router", weights=weights, bias=False)
self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights)
self.experts = Llama4TextExperts(
config=config, prefix=f"{prefix}.experts", weights=weights
)
self.router = FastLinear.load(
config=config, prefix=f"{prefix}.router", weights=weights, bias=False
)
self.shared_expert = Llama4TextMLP(
config=config, prefix=f"{prefix}.shared_expert", weights=weights
)
self.process_group = weights.process_group
def forward(self, hidden_states, adapter_data):
@@ -215,16 +228,19 @@ class Llama4TextMoe(nn.Module):
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
router_scores = (
torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1)
torch.full_like(router_logits, float("-inf"))
.scatter_(1, router_indices, router_top_value)
.transpose(0, 1)
)
# We do this to make sure we have -inf for non topK tokens before going through the !
# Here we are just creating a tensor to index each and every single one of the hidden states. Let s maybe register a buffer for this!
router_indices = (
torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
torch.arange(tokens_per_expert, device=hidden_states.device)
.view(1, -1)
.expand(router_scores.size(0), -1)
)
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
routed_in = torch.gather(
input=hidden_states,
@@ -232,7 +248,6 @@ class Llama4TextMoe(nn.Module):
index=router_indices,
).to(hidden_states.device)
# we gather inputs corresponding to each expert based on the router indices
routed_in = routed_in * router_scores.reshape(-1, 1)
routed_out = self.experts(routed_in)
@@ -241,7 +256,9 @@ class Llama4TextMoe(nn.Module):
# now that we finished expert computation -> we scatter add because we gathered previously
# we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
# this scales a lot better if you do EP!
out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim))
out.scatter_add_(
dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim)
)
return out
@@ -262,10 +279,15 @@ class Llama4TextRotaryEmbedding(nn.Module):
self.original_inv_freq = self.inv_freq
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
origin_device = x.device
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
device_type = (
x.device.type
if isinstance(x.device.type, str) and x.device.type != "mps"
else "cpu"
)
inv_freq_expanded = inv_freq_expanded.to(device_type)
position_ids_expanded = position_ids_expanded.to(device_type)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
@@ -286,8 +308,12 @@ class Llama4TextAttention(FlashLlamaAttention):
super().__init__(layer_idx, prefix, config, weights)
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5
self.attn_scale = config.attn_scale
self.floor_scale = config.floor_scale
@@ -347,7 +373,11 @@ class Llama4TextAttention(FlashLlamaAttention):
# Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers
if self.attn_temperature_tuning and not self.use_rope:
attn_scales = (
torch.log(torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
torch.log(
torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0
)
* self.attn_scale
+ 1.0
)
attn_scales = attn_scales.view(*input_shape, 1, 1)
query_states = (query_states * attn_scales).to(query_states.dtype)
@@ -355,9 +385,15 @@ class Llama4TextAttention(FlashLlamaAttention):
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(1, 2)
key = key_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value = value_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(
1, 2
)
key = key_states.view(
bs, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value = value_states.view(
bs, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
key = repeat_kv(key, self.num_key_value_groups)
value = repeat_kv(value, self.num_key_value_groups)
@@ -402,7 +438,9 @@ class Llama4TextDecoderLayer(nn.Module):
def __init__(self, prefix, config, weights, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Llama4TextAttention(f"{prefix}.self_attn", config, weights, layer_idx)
self.self_attn = Llama4TextAttention(
f"{prefix}.self_attn", config, weights, layer_idx
)
self.use_chunked_attention = int((layer_idx + 1) % 4 != 0) # <=> use rope
self.is_moe_layer = layer_idx in config.moe_layers
if self.is_moe_layer: # the 128E model interleaves dense / sparse
@@ -411,15 +449,15 @@ class Llama4TextDecoderLayer(nn.Module):
self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
@@ -434,7 +472,9 @@ class Llama4TextDecoderLayer(nn.Module):
chunk_causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
hidden_states, _ = self.input_layernorm(hidden_states)
@@ -465,6 +505,7 @@ class Llama4TextDecoderLayer(nn.Module):
hidden_states = residual + hidden_states.view(residual.shape)
return hidden_states
class Llama4TextModel(nn.Module):
def __init__(self, prefix, config, weights):
@@ -473,12 +514,22 @@ class Llama4TextModel(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = TensorParallelEmbedding(prefix=f"{prefix}.embed_tokens", weights=weights)
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[Llama4TextDecoderLayer(prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
[
Llama4TextDecoderLayer(
prefix=f"{prefix}.layers.{layer_idx}",
config=config,
weights=weights,
layer_idx=layer_idx,
)
for layer_idx in range(config.num_hidden_layers)
]
)
#self.norm = Llama4TextRMSNorm(prefix=f"{prefix}.norm", config=config, weights=weights)
# self.norm = Llama4TextRMSNorm(prefix=f"{prefix}.norm", config=config, weights=weights)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm",
weights=weights,
@@ -510,7 +561,12 @@ class Llama4TextModel(nn.Module):
position_ids = cache_position.unsqueeze(0)
causal_mask, chunk_causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds.view(bs, int(seq_len), -1), cache_position, None, output_attentions=False, use_cache=False
attention_mask,
inputs_embeds.view(bs, int(seq_len), -1),
cache_position,
None,
output_attentions=False,
use_cache=False,
)
freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))
@@ -530,7 +586,6 @@ class Llama4TextModel(nn.Module):
hpu_attention_meta=hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states)
return hidden_states
@@ -547,7 +602,10 @@ class Llama4TextModel(nn.Module):
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask, attention_mask # flash does not support chunked attn TODO support flash
return (
attention_mask,
attention_mask,
) # flash does not support chunked attn TODO support flash
return None, None
if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]:
@@ -561,7 +619,11 @@ class Llama4TextModel(nn.Module):
if past_key_values is not None:
full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
else:
full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
full_cache_length = (
attention_mask.shape[-1]
if attention_mask is not None
else sequence_length
)
cond1 = first_cache_position >= attention_chunk_size
cond2 = (first_cache_position < attention_chunk_size) & (
@@ -571,7 +633,9 @@ class Llama4TextModel(nn.Module):
torch.where(
cond1,
attention_chunk_size + sequence_length - 1,
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
torch.where(
cond2, first_cache_position + sequence_length, attention_chunk_size
),
)
if use_cache
else full_cache_length
@@ -586,7 +650,7 @@ class Llama4TextModel(nn.Module):
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
device=device
device=device,
)
if full_cache_length > self.config.attention_chunk_size:
start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
@@ -598,24 +662,37 @@ class Llama4TextModel(nn.Module):
device=device,
)
local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well
local_attention_mask = attention_mask[
:, start_idx:end_idx
] # offset here as well
# It may be smaller than attention_chunk_size -> pad it
requires_padding = local_attention_mask.shape[-1] < attention_chunk_size
if requires_padding:
local_attention_mask = nn.functional.pad(
local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1])
local_attention_mask,
(0, attention_chunk_size - local_attention_mask.shape[-1]),
)
# Depending on the padding, take the query tokens from the end or the cache_position
if not requires_padding:
chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :]
chunked_attention_mask = chunked_attention_mask[
None, None, -sequence_length:, :
]
else:
chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :]
chunked_attention_mask = chunked_attention_mask[
None, None, cache_position, :
]
chunked_attention_mask = chunked_attention_mask.expand(input_tensor.shape[0], -1, -1, -1)
chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :]
chunked_attention_mask = chunked_attention_mask.expand(
input_tensor.shape[0], -1, -1, -1
)
chunked_attention_mask = (
chunked_attention_mask * local_attention_mask[:, None, None, :]
)
if self.config._attn_implementation == "eager":
min_dtype = torch.finfo(dtype).min
chunked_attention_mask = torch.where(chunked_attention_mask == 0, min_dtype, 0.0).to(dtype)
chunked_attention_mask = torch.where(
chunked_attention_mask == 0, min_dtype, 0.0
).to(dtype)
if (
self.config._attn_implementation == "sdpa"
@@ -628,10 +705,15 @@ class Llama4TextModel(nn.Module):
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
if (
self.config._attn_implementation == "sdpa"
and chunked_attention_mask is not None
):
chunked_attention_mask = chunked_attention_mask.bool()
causal_mask = causal_mask.bool()
if AttentionMaskConverter._ignore_causal_mask_sdpa(
@@ -661,7 +743,8 @@ class Llama4TextModel(nn.Module):
"""
arange_vector = torch.arange(start, end, device=device)
block_pos = torch.abs(
arange_vector.unsqueeze(0) // attention_chunk_size - arange_vector.unsqueeze(1) // attention_chunk_size
arange_vector.unsqueeze(0) // attention_chunk_size
- arange_vector.unsqueeze(1) // attention_chunk_size
)
token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1)
mask = (block_pos == 0) & (token_pos <= 0)
@@ -706,25 +789,33 @@ class Llama4TextModel(nn.Module):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.to(device).reshape(-1, 1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.to(device).reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(device)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
:, None, None, :
].to(device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
return causal_mask
class Llama4ForCausalLM(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
@@ -791,7 +882,10 @@ class Llama4VisionMLP2(torch.nn.Module):
hidden_states = self.activation_fn(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
return self.activation_fn(hidden_states) # TODO: check if we need to apply activation again
return self.activation_fn(
hidden_states
) # TODO: check if we need to apply activation again
class Llama4MultiModalProjector(nn.Module):
def __init__(self, prefix, config, weights):
@@ -812,10 +906,15 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
batch_size, height, width, channels = input_tensor.size()
reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))
reshaped_tensor = input_tensor.view(
batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
)
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
reshaped_tensor = reshaped_tensor.view(
batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2))
batch_size,
int(height * shuffle_ratio),
int(width * shuffle_ratio),
int(channels / (shuffle_ratio**2)),
)
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
@@ -827,14 +926,19 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
self.inner_dim = int(config.projector_input_dim // (self.pixel_shuffle_ratio**2))
self.inner_dim = int(
config.projector_input_dim // (self.pixel_shuffle_ratio**2)
)
self.output_dim = config.projector_output_dim
self.mlp = Llama4VisionMLP2(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.mlp = Llama4VisionMLP2(
prefix=f"{prefix}.mlp", config=config, weights=weights
)
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
return self.mlp(encoded_patches)
# TODO there is a different RoPE for vision encoder, defined as below
def vision_reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):
ndim = query.ndim
@@ -878,7 +982,6 @@ class Llama4VisionAttention(nn.Module):
qkv = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_dim * self.num_heads,
@@ -891,15 +994,21 @@ class Llama4VisionAttention(nn.Module):
key_states = key_states.view(hidden_shape)
value_states = value_states.view(hidden_shape)
query_states, key_states = apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci)
query_states, key_states = apply_rotary_emb(
query_states, key_states, freqs_ci=freqs_ci
)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False, dropout_p=0
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
dropout_p=0,
)
attn_output = attn_output.transpose(1, 2)
@@ -920,7 +1029,6 @@ class Llama4VisionMLP(nn.Module):
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
@@ -987,17 +1095,21 @@ class Llama4VisionEncoder(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Llama4VisionEncoderLayer(prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights)
for layer_id in range(config.num_hidden_layers)
])
self.layers = nn.ModuleList(
[
Llama4VisionEncoderLayer(
prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights
)
for layer_id in range(config.num_hidden_layers)
]
)
self.gradient_checkpointing = False
self.config = config
def forward(
self,
hidden_states: torch.Tensor,
freqs_ci: torch.Tensor, # TODO move this to an attribute instead of keeping it around
freqs_ci: torch.Tensor, # TODO move this to an attribute instead of keeping it around
attention_mask: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutput]:
@@ -1024,7 +1136,6 @@ class Llama4UnfoldConvolution(nn.Module):
config=config, prefix=f"{prefix}.linear", weights=weights, bias=False
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.unfold(hidden_states)
hidden_states = hidden_states.permute(0, 2, 1)
@@ -1037,21 +1148,31 @@ class Llama4VisionRotaryEmbedding(nn.Module):
super().__init__()
# Calculate image grid indices
idx = config.image_size // config.patch_size
img_idx = torch.arange(idx**2, dtype=torch.int32, device=weights.device).reshape(idx**2, 1)
img_idx = torch.arange(
idx**2, dtype=torch.int32, device=weights.device
).reshape(idx**2, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
img_idx[-1, -1] = -2 # ID_CLS_TOKEN
# Calculate x and y coordinates
frequencies_x = img_idx % idx # x coordinates
frequencies_y = torch.div(img_idx, idx, rounding_mode='floor') # y coordinates
frequencies_y = torch.div(img_idx, idx, rounding_mode="floor") # y coordinates
# Calculate frequency components
freq_dim = config.hidden_size // config.num_attention_heads // 2
rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2, device=weights.device)[: (freq_dim // 2)].float() / freq_dim))
rope_freq = 1.0 / (
config.rope_theta
** (
torch.arange(0, freq_dim, 2, device=weights.device)[
: (freq_dim // 2)
].float()
/ freq_dim
)
)
# Compute frequencies for x and y directions
freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :])
freqs_x = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
freqs_x = freqs_x.repeat_interleave(2, dim=-1)
freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :])
freqs_y = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
freqs_y = freqs_y.repeat_interleave(2, dim=-1)
# Combine frequencies and mask special tokens
@@ -1090,7 +1211,8 @@ class Llama4VisionModel(nn.Module):
)
self.positional_embedding_vlm = nn.Parameter(
weights.get_tensor(f"{prefix}.positional_embedding_vlm"), requires_grad=False
weights.get_tensor(f"{prefix}.positional_embedding_vlm"),
requires_grad=False,
)
self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights)
@@ -1126,17 +1248,26 @@ class Llama4VisionModel(nn.Module):
# Add cls token
hidden_state = hidden_state.reshape(
batch_size_times_num_tiles * num_concurrent_media * num_chunks, num_patches, hidden_dim
batch_size_times_num_tiles * num_concurrent_media * num_chunks,
num_patches,
hidden_dim,
)
class_embedding = self.class_embedding.expand(
hidden_state.shape[0], 1, hidden_state.shape[-1]
)
class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1])
hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
num_patches += 1
# Position embeddings
hidden_state = hidden_state.reshape(
batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim
batch_size_times_num_tiles * num_concurrent_media,
num_chunks,
num_patches,
hidden_dim,
)
positional_embedding = self.positional_embedding_vlm.to(
dtype=hidden_state.dtype, device=hidden_state.device
)
positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device)
hidden_state = hidden_state + positional_embedding
hidden_state = self.layernorm_pre(hidden_state)
hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim)
@@ -1156,6 +1287,7 @@ class Llama4VisionModel(nn.Module):
hidden_state = self.vision_adapter(hidden_state)
return hidden_state
class Llama4ForConditionalGeneration(nn.Module):
def __init__(self, prefix: str, config, weights):
@@ -1179,7 +1311,9 @@ class Llama4ForConditionalGeneration(nn.Module):
prefix="language_model", config=config.text_config, weights=weights
)
self.vocab_size = config.text_config.vocab_size
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1
)
self.config = config
self.dtype = weights.dtype
self.device = weights.device
@@ -1208,7 +1342,9 @@ class Llama4ForConditionalGeneration(nn.Module):
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
raise ValueError(
f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
)
kwargs = {k: v for k, v in kwargs.items() if v is not None}
hidden_state = self.vision_model(pixel_values)
return hidden_state
@@ -1275,11 +1411,15 @@ class Llama4ForConditionalGeneration(nn.Module):
f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
)
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat)
expanded_mask = final_mask_1d.unsqueeze(-1).expand(
-1, inputs_embeds.size(-1)
)
inputs_embeds = inputs_embeds.masked_scatter(
expanded_mask, projected_vision_flat
)
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
logits, speculative_logits= self.text_model(
logits, speculative_logits = self.text_model(
inputs_embeds,
position_ids,
cu_seqlen_prefill,
@@ -1289,7 +1429,7 @@ class Llama4ForConditionalGeneration(nn.Module):
hpu_attention_meta,
adapter_data,
lm_head_indices,
attention_mask
attention_mask,
)
return logits, speculative_logits

@@ -1,18 +1,31 @@
from packaging.version import Version
from packaging import version
import subprocess
def get_driver_version():
"""
Returns the driver version.
"""
# Enable console printing for `hl-smi` check
output = subprocess.run(
"hl-smi", shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env={"ENABLE_CONSOLE": "true"}
"hl-smi",
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env={"ENABLE_CONSOLE": "true"},
)
if output.returncode == 0 and output.stdout:
return version.parse(output.stdout.split("\n")[2].replace(" ", "").split(":")[1][:-1].split("-")[0])
return version.parse(
output.stdout.split("\n")[2]
.replace(" ", "")
.split(":")[1][:-1]
.split("-")[0]
)
return None
MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")

@@ -315,7 +315,6 @@ class Weights:
tensors_slices += range(block_offset + start, block_offset + stop)
block_offset += block_size
if dim == 0:
tensor = slice_[tensors_slices, ...]
elif dim == 1 or dim == -2: