Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

Commit 7533b993d5 — Fix the errors of pre-commit

Signed-off-by: yuanwu <yuan.wu@intel.com>

Parent: d98116db6e
@@ -850,14 +850,17 @@ def get_model(
     )
 
     from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
     adapt_transformers_to_gaudi()
     if SDP_ON_BF16 == 1:
         torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
     if model_type == "gpt_bigcode":
         from text_generation_server.models.starcoder import StarCoder
 
         return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
     if model_type == "bloom":
         from text_generation_server.models.bloom import BLOOM
 
         return BLOOM(
             model_id=model_id,
             revision=revision,
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, List, Optional, Tuple, Union, Type
+from typing import List, Optional, Tuple, Union
 
 import torch
 import math
@@ -47,7 +47,6 @@ from text_generation_server.layers.attention import (
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
-    load_attention,
     FlashLlamaAttention,
     LlamaMLP,
 )
@@ -58,6 +57,7 @@ def reshape_for_broadcast(freqs: torch.Tensor, target):
     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target)]
     return freqs.view(*shape)
 
+
 def apply_rotary_emb(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -65,12 +65,12 @@ def apply_rotary_emb(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     query_shape = query.shape
     key_shape = key.shape
-    cos_emb,sin_emb = freqs_ci.split(1, dim=-1)
+    cos_emb, sin_emb = freqs_ci.split(1, dim=-1)
 
     if len(query.shape) == 3:
         query = query.unsqueeze(0)
         key = key.unsqueeze(0)
 
     query_reshaped = query.float().reshape(*query.shape[:-1], -1, 2)
     key_reshaped = key.float().reshape(*key.shape[:-1], -1, 2)
     q_shape = query_reshaped.shape[:-1]
@@ -78,12 +78,12 @@ def apply_rotary_emb(
     sin_emb = reshape_for_broadcast(sin_emb, q_shape)
     x_q, y_q = query_reshaped.unbind(-1)
     x_k, y_k = key_reshaped.unbind(-1)
 
     x_q_rot = x_q * cos_emb - y_q * sin_emb
     y_q_rot = x_q * sin_emb + y_q * cos_emb
     x_k_rot = x_k * cos_emb - y_k * sin_emb
     y_k_rot = x_k * sin_emb + y_k * cos_emb
 
     query_out = torch.stack([x_q_rot, y_q_rot], dim=-1).flatten(-2)
     key_out = torch.stack([x_k_rot, y_k_rot], dim=-1).flatten(-2)
     query_out = query_out.view(*query_shape)
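For orientation, the two `apply_rotary_emb` hunks above only re-wrap long lines; the rotation itself is unchanged. A minimal standalone sketch of the same pairwise rotation, with made-up toy shapes and a hypothetical helper name:

```python
import torch

def rotate_pairs(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., head_dim) with an even head_dim; cos/sin broadcast over the pair axis.
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    a, b = pairs.unbind(-1)
    # Standard 2D rotation of every (a, b) pair by the per-position angle.
    rotated = torch.stack([a * cos - b * sin, a * sin + b * cos], dim=-1)
    return rotated.flatten(-2).type_as(x)

q = torch.randn(1, 4, 2, 8)                    # (batch, seq, heads, head_dim)
angles = torch.arange(4.0)[:, None, None].expand(4, 1, 4)
q_rot = rotate_pairs(q, torch.cos(angles), torch.sin(angles))
print(q_rot.shape)                             # torch.Size([1, 4, 2, 8])
```
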
@@ -99,7 +99,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
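The reformatted `repeat_kv` is equivalent to repeating each KV head `n_rep` times along the head dimension; a quick self-contained check with toy shapes:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 5, 16)            # 4 KV heads
expanded = repeat_kv(kv, 3)              # shared by 12 query heads
assert expanded.shape == (2, 12, 5, 16)
assert torch.equal(expanded, kv.repeat_interleave(3, dim=1))
```
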
@@ -108,11 +110,18 @@ class Llama4TextExperts(nn.Module):
         super().__init__()
         self.process_group = weights.process_group
         self.num_experts = config.num_local_experts
-        self.intermediate_size = config.intermediate_size // weights.process_group.size()
+        self.intermediate_size = (
+            config.intermediate_size // weights.process_group.size()
+        )
         self.hidden_size = config.hidden_size
         self.expert_dim = self.intermediate_size
-        self.gate_up_proj = nn.Parameter(weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), requires_grad=False)
-        self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False)
+        self.gate_up_proj = nn.Parameter(
+            weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2),
+            requires_grad=False,
+        )
+        self.down_proj = nn.Parameter(
+            weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False
+        )
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -128,18 +137,18 @@ class Llama4TextExperts(nn.Module):
         Returns:
             torch.Tensor
         """
-        gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2*self.expert_dim)
+        gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2 * self.expert_dim)
         down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1)
         hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
         gate_up = torch.bmm(hidden_states, gate_up_proj)
         gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
         next_states = torch.bmm((up * self.act_fn(gate)), down_proj)
         next_states = next_states.view(-1, self.hidden_size)
 
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(next_states, group=self.process_group)
 
         return next_states
 
 
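As a reading aid for the `Llama4TextExperts.forward` hunk above, here is the same gated, batched expert matmul on plain tensors (toy sizes, SiLU assumed for the activation, no tensor-parallel all-reduce):

```python
import torch

num_experts, tokens_per_expert, hidden, expert_dim = 4, 3, 8, 16
hidden_states = torch.randn(num_experts * tokens_per_expert, hidden)
gate_up_proj = torch.randn(num_experts, hidden, 2 * expert_dim)
down_proj = torch.randn(num_experts, expert_dim, hidden)

x = hidden_states.view(num_experts, -1, hidden)            # (E, T, H)
gate_up = torch.bmm(x, gate_up_proj)                       # (E, T, 2*D)
gate, up = gate_up.chunk(2, dim=-1)                        # two (E, T, D) halves
next_states = torch.bmm(up * torch.nn.functional.silu(gate), down_proj)  # (E, T, H)
print(next_states.view(-1, hidden).shape)                  # torch.Size([12, 8])
```
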
@@ -171,9 +180,7 @@ class Llama4TextMLP(nn.Module):
     def forward(self, x):
         gate_up_states = self.gate_up_proj(x)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(
-            self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
-        )
+        return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])
 
 
 class Llama4TextL2Norm(torch.nn.Module):
@@ -202,11 +209,17 @@ class Llama4TextMoe(nn.Module):
         self.top_k = config.num_experts_per_tok
         self.hidden_dim = config.hidden_size
         self.num_experts = config.num_local_experts
-        self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights)
-        self.router = FastLinear.load(config=config, prefix=f"{prefix}.router", weights=weights, bias=False)
-        self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights)
+        self.experts = Llama4TextExperts(
+            config=config, prefix=f"{prefix}.experts", weights=weights
+        )
+        self.router = FastLinear.load(
+            config=config, prefix=f"{prefix}.router", weights=weights, bias=False
+        )
+        self.shared_expert = Llama4TextMLP(
+            config=config, prefix=f"{prefix}.shared_expert", weights=weights
+        )
         self.process_group = weights.process_group
 
     def forward(self, hidden_states, adapter_data):
         seq_len, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_dim)
@@ -215,16 +228,19 @@ class Llama4TextMoe(nn.Module):
 
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
-            torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1)
+            torch.full_like(router_logits, float("-inf"))
+            .scatter_(1, router_indices, router_top_value)
+            .transpose(0, 1)
         )
         # We do this to make sure we have -inf for non topK tokens before going through the !
         # Here we are just creating a tensor to index each and every single one of the hidden states. Let s maybe register a buffer for this!
         router_indices = (
-            torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
+            torch.arange(tokens_per_expert, device=hidden_states.device)
+            .view(1, -1)
+            .expand(router_scores.size(0), -1)
         )
         router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
 
-
         router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
         routed_in = torch.gather(
             input=hidden_states,
@@ -232,16 +248,17 @@ class Llama4TextMoe(nn.Module):
             index=router_indices,
         ).to(hidden_states.device)
 
-
         # we gather inputs corresponding to each expert based on the router indices
         routed_in = routed_in * router_scores.reshape(-1, 1)
         routed_out = self.experts(routed_in)
         out = self.shared_expert(hidden_states)
 
         # now that we finished expert computation -> we scatter add because we gathered previously
         # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
         # this scales a lot better if you do EP!
-        out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim))
+        out.scatter_add_(
+            dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim)
+        )
         return out
 
 
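For context, the routing reformatted in the two MoE hunks follows the usual top-k pattern: mask non-top-k logits to -inf, gather the selected token states per expert, and scatter-add the expert outputs back per token. A compact sketch with hypothetical sizes (the expert computation and the shared-expert baseline are stubbed out):

```python
import torch

tokens, hidden, num_experts, top_k = 6, 8, 4, 1
hidden_states = torch.randn(tokens, hidden)
router_logits = torch.randn(tokens, num_experts)

top_val, top_idx = torch.topk(router_logits, top_k, dim=1)
scores = (
    torch.full_like(router_logits, float("-inf"))
    .scatter_(1, top_idx, top_val)
    .transpose(0, 1)
)
scores = torch.sigmoid(scores.float()).to(hidden_states.dtype)  # -inf -> weight 0

token_idx = (
    torch.arange(tokens)
    .view(1, -1)
    .expand(num_experts, -1)
    .reshape(-1, 1)
    .expand(-1, hidden)
)
routed_in = torch.gather(hidden_states, 0, token_idx) * scores.reshape(-1, 1)
routed_out = routed_in * 2.0              # stand-in for the experts
out = torch.zeros_like(hidden_states)     # the real model starts from the shared expert
out.scatter_add_(0, token_idx, routed_out)
print(out.shape)                          # torch.Size([6, 8])
```
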
@@ -262,10 +279,15 @@ class Llama4TextRotaryEmbedding(nn.Module):
         self.original_inv_freq = self.inv_freq
 
     def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        )
         position_ids_expanded = position_ids[:, None, :].float()
-        origin_device = x.device
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        device_type = (
+            x.device.type
+            if isinstance(x.device.type, str) and x.device.type != "mps"
+            else "cpu"
+        )
         inv_freq_expanded = inv_freq_expanded.to(device_type)
         position_ids_expanded = position_ids_expanded.to(device_type)
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
@@ -273,7 +295,7 @@ class Llama4TextRotaryEmbedding(nn.Module):
             cos = torch.cos(freqs) * self.attention_scaling
             sin = torch.sin(freqs) * self.attention_scaling
             cos = cos.reshape(-1, 1, cos.shape[-1])
             sin = sin.reshape(-1, 1, sin.shape[-1])
             freqs_cis = torch.cat([cos, sin], dim=-1) * self.attention_scaling
             freqs_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
         return freqs_cis
@@ -286,15 +308,19 @@ class Llama4TextAttention(FlashLlamaAttention):
         super().__init__(layer_idx, prefix, config, weights)
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
         self.scaling = self.head_dim**-0.5
         self.attn_scale = config.attn_scale
         self.floor_scale = config.floor_scale
         self.attn_temperature_tuning = config.attn_temperature_tuning
         self.attention_dropout = config.attention_dropout
         self.use_rope = int((layer_idx + 1) % 4 != 0)  # rope unused for dense layers
 
         if self.config.use_qk_norm and self.use_rope:
             self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps)
 
@@ -323,7 +349,7 @@ class Llama4TextAttention(FlashLlamaAttention):
             ],
             dim=-1,
         )
 
         query_states = query_states.view(hidden_shape)
         key_states = key_states.view(hidden_shape)
         value_states = value_states.view(hidden_shape)
@@ -347,7 +373,11 @@ class Llama4TextAttention(FlashLlamaAttention):
         # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers
         if self.attn_temperature_tuning and not self.use_rope:
             attn_scales = (
-                torch.log(torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
+                torch.log(
+                    torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0
+                )
+                * self.attn_scale
+                + 1.0
             )
             attn_scales = attn_scales.view(*input_shape, 1, 1)
             query_states = (query_states * attn_scales).to(query_states.dtype)
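The temperature-tuning hunk only re-wraps a long expression; the scale it applies to the queries grows logarithmically with position. Illustrative values only (attn_scale and floor_scale are config fields in the real model):

```python
import torch

attn_scale, floor_scale = 0.1, 8192.0
position_ids = torch.tensor([0.0, 1024.0, 8192.0, 131072.0])
attn_scales = (
    torch.log(torch.floor((position_ids + 1.0) / floor_scale) + 1.0) * attn_scale + 1.0
)
print(attn_scales)  # stays 1.0 near the start, grows slowly with position
```
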
@@ -355,9 +385,15 @@ class Llama4TextAttention(FlashLlamaAttention):
         # Prefill
         if cu_seqlen_prefill is not None:
             # sdpa
-            query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(1, 2)
-            key = key_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            value = value_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+            query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(
+                1, 2
+            )
+            key = key_states.view(
+                bs, -1, self.num_key_value_heads, self.head_dim
+            ).transpose(1, 2)
+            value = value_states.view(
+                bs, -1, self.num_key_value_heads, self.head_dim
+            ).transpose(1, 2)
             key = repeat_kv(key, self.num_key_value_groups)
             value = repeat_kv(value, self.num_key_value_groups)
 
@@ -395,14 +431,16 @@ class Llama4TextAttention(FlashLlamaAttention):
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output, adapter_data)
         return attn_output
 
 
 class Llama4TextDecoderLayer(nn.Module):
     def __init__(self, prefix, config, weights, layer_idx):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.self_attn = Llama4TextAttention(f"{prefix}.self_attn", config, weights, layer_idx)
+        self.self_attn = Llama4TextAttention(
+            f"{prefix}.self_attn", config, weights, layer_idx
+        )
         self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)  # <=> use rope
         self.is_moe_layer = layer_idx in config.moe_layers
         if self.is_moe_layer:  # the 128E model interleaves dense / sparse
@@ -411,15 +449,15 @@ class Llama4TextDecoderLayer(nn.Module):
             self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights)
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
         self.post_attention_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.post_attention_layernorm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
 
     def forward(
         self,
@@ -434,7 +472,9 @@ class Llama4TextDecoderLayer(nn.Module):
         chunk_causal_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
         residual = hidden_states
         hidden_states, _ = self.input_layernorm(hidden_states)
 
@@ -465,6 +505,7 @@ class Llama4TextDecoderLayer(nn.Module):
         hidden_states = residual + hidden_states.view(residual.shape)
         return hidden_states
 
+
 class Llama4TextModel(nn.Module):
 
     def __init__(self, prefix, config, weights):
@@ -473,12 +514,22 @@ class Llama4TextModel(nn.Module):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = TensorParallelEmbedding(prefix=f"{prefix}.embed_tokens", weights=weights)
-        self.layers = nn.ModuleList(
-            [Llama4TextDecoderLayer(prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
-        )
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=f"{prefix}.embed_tokens", weights=weights
+        )
+        self.layers = nn.ModuleList(
+            [
+                Llama4TextDecoderLayer(
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                    config=config,
+                    weights=weights,
+                    layer_idx=layer_idx,
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
 
-        #self.norm = Llama4TextRMSNorm(prefix=f"{prefix}.norm", config=config, weights=weights)
+        # self.norm = Llama4TextRMSNorm(prefix=f"{prefix}.norm", config=config, weights=weights)
         self.norm = FastRMSNorm.load(
             prefix=f"{prefix}.norm",
             weights=weights,
@@ -500,7 +551,7 @@ class Llama4TextModel(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
         hidden_states = inputs_embeds
         bs = seqlen.input_lengths.shape[0]
         seq_len = inputs_embeds.shape[0] / bs
@@ -508,11 +559,16 @@ class Llama4TextModel(nn.Module):
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
         causal_mask, chunk_causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds.view(bs, int(seq_len), -1), cache_position, None, output_attentions=False, use_cache=False
+            attention_mask,
+            inputs_embeds.view(bs, int(seq_len), -1),
+            cache_position,
+            None,
+            output_attentions=False,
+            use_cache=False,
         )
 
         freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))
 
         for i, layer in enumerate(self.layers):
@@ -530,9 +586,8 @@ class Llama4TextModel(nn.Module):
                 hpu_attention_meta=hpu_attention_meta,
             )
 
-
         hidden_states, _ = self.norm(hidden_states)
 
         return hidden_states
 
     def _update_causal_mask(
@@ -547,7 +602,10 @@ class Llama4TextModel(nn.Module):
     ):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and (attention_mask == 0.0).any():
-                return attention_mask, attention_mask  # flash does not support chunked attn TODO support flash
+                return (
+                    attention_mask,
+                    attention_mask,
+                )  # flash does not support chunked attn TODO support flash
             return None, None
 
         if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]:
@@ -561,7 +619,11 @@ class Llama4TextModel(nn.Module):
         if past_key_values is not None:
             full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
         else:
-            full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
+            full_cache_length = (
+                attention_mask.shape[-1]
+                if attention_mask is not None
+                else sequence_length
+            )
 
         cond1 = first_cache_position >= attention_chunk_size
         cond2 = (first_cache_position < attention_chunk_size) & (
@@ -571,7 +633,9 @@ class Llama4TextModel(nn.Module):
             torch.where(
                 cond1,
                 attention_chunk_size + sequence_length - 1,
-                torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
+                torch.where(
+                    cond2, first_cache_position + sequence_length, attention_chunk_size
+                ),
             )
             if use_cache
             else full_cache_length
@@ -586,7 +650,7 @@ class Llama4TextModel(nn.Module):
             dtype=dtype,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
-            device=device
+            device=device,
         )
         if full_cache_length > self.config.attention_chunk_size:
             start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
@@ -598,24 +662,37 @@ class Llama4TextModel(nn.Module):
                 device=device,
             )
 
-            local_attention_mask = attention_mask[:, start_idx:end_idx]  # offset here as well
+            local_attention_mask = attention_mask[
+                :, start_idx:end_idx
+            ]  # offset here as well
             # It may be smaller than attention_chunk_size -> pad it
             requires_padding = local_attention_mask.shape[-1] < attention_chunk_size
             if requires_padding:
                 local_attention_mask = nn.functional.pad(
-                    local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1])
+                    local_attention_mask,
+                    (0, attention_chunk_size - local_attention_mask.shape[-1]),
                 )
             # Depending on the padding, take the query tokens from the end or the cache_position
             if not requires_padding:
-                chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :]
+                chunked_attention_mask = chunked_attention_mask[
+                    None, None, -sequence_length:, :
+                ]
             else:
-                chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :]
+                chunked_attention_mask = chunked_attention_mask[
+                    None, None, cache_position, :
+                ]
 
-            chunked_attention_mask = chunked_attention_mask.expand(input_tensor.shape[0], -1, -1, -1)
-            chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :]
+            chunked_attention_mask = chunked_attention_mask.expand(
+                input_tensor.shape[0], -1, -1, -1
+            )
+            chunked_attention_mask = (
+                chunked_attention_mask * local_attention_mask[:, None, None, :]
+            )
             if self.config._attn_implementation == "eager":
                 min_dtype = torch.finfo(dtype).min
-                chunked_attention_mask = torch.where(chunked_attention_mask == 0, min_dtype, 0.0).to(dtype)
+                chunked_attention_mask = torch.where(
+                    chunked_attention_mask == 0, min_dtype, 0.0
+                ).to(dtype)
 
         if (
             self.config._attn_implementation == "sdpa"
@@ -628,10 +705,15 @@ class Llama4TextModel(nn.Module):
             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
             # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, min_dtype
+            )
 
         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
+        if (
+            self.config._attn_implementation == "sdpa"
+            and chunked_attention_mask is not None
+        ):
             chunked_attention_mask = chunked_attention_mask.bool()
             causal_mask = causal_mask.bool()
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
@@ -661,7 +743,8 @@ class Llama4TextModel(nn.Module):
         """
         arange_vector = torch.arange(start, end, device=device)
         block_pos = torch.abs(
-            arange_vector.unsqueeze(0) // attention_chunk_size - arange_vector.unsqueeze(1) // attention_chunk_size
+            arange_vector.unsqueeze(0) // attention_chunk_size
+            - arange_vector.unsqueeze(1) // attention_chunk_size
         )
         token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1)
         mask = (block_pos == 0) & (token_pos <= 0)
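The `block_pos` / `token_pos` expression reformatted above builds a chunk-local causal mask: a key is visible only when it sits in the same attention_chunk_size block as the query and is not in the future. Standalone sketch:

```python
import torch

def chunked_causal_mask(start: int, end: int, chunk: int) -> torch.Tensor:
    pos = torch.arange(start, end)
    block_pos = torch.abs(pos.unsqueeze(0) // chunk - pos.unsqueeze(1) // chunk)
    token_pos = pos.unsqueeze(0) - pos.unsqueeze(1)
    return (block_pos == 0) & (token_pos <= 0)

# Rows are queries, columns keys: causal within each 3-token block, no cross-block attention.
print(chunked_causal_mask(0, 6, chunk=3).int())
```
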
@@ -706,25 +789,33 @@ class Llama4TextModel(nn.Module):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length),
+                fill_value=min_dtype,
+                dtype=dtype,
+                device=device,
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.to(device).reshape(-1, 1)
+            causal_mask *= torch.arange(
+                target_length, device=device
+            ) > cache_position.to(device).reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                causal_mask = (
+                    causal_mask.clone()
+                )  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(device)
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
+                    :, None, None, :
+                ].to(device)
                 padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
+                causal_mask[:, :, :, :mask_length] = causal_mask[
+                    :, :, :, :mask_length
+                ].masked_fill(padding_mask, min_dtype)
 
         return causal_mask
 
-
 
 
 class Llama4ForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
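For reference, the `else` branch reformatted above follows the standard 4D-causal-mask recipe: fill with the dtype minimum, keep the strict upper triangle masked, and re-open the positions each query may see. A toy version without the padding-mask step:

```python
import torch

seq_len, target_len, batch = 3, 5, 1
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(seq_len)            # query positions

causal = torch.full((seq_len, target_len), fill_value=min_dtype, dtype=dtype)
causal = torch.triu(causal, diagonal=1)
causal *= torch.arange(target_len) > cache_position.reshape(-1, 1)
causal = causal[None, None, :, :].expand(batch, 1, -1, -1)
print((causal == 0).int())                        # 1 where a query may attend
```
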
@@ -751,7 +842,7 @@ class Llama4ForCausalLM(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         hidden_states = self.model(
             inputs_embeds,
@@ -791,7 +882,10 @@ class Llama4VisionMLP2(torch.nn.Module):
         hidden_states = self.activation_fn(hidden_states)
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = self.fc2(hidden_states)
-        return self.activation_fn(hidden_states)  # TODO: check if we need to apply activation again
+        return self.activation_fn(
+            hidden_states
+        )  # TODO: check if we need to apply activation again
 
 
 class Llama4MultiModalProjector(nn.Module):
     def __init__(self, prefix, config, weights):
@@ -812,10 +906,15 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
 
     input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
     batch_size, height, width, channels = input_tensor.size()
-    reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))
+    reshaped_tensor = input_tensor.view(
+        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
+    )
     reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
     reshaped_tensor = reshaped_tensor.view(
-        batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2))
+        batch_size,
+        int(height * shuffle_ratio),
+        int(width * shuffle_ratio),
+        int(channels / (shuffle_ratio**2)),
     )
     reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
 
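The `pixel_shuffle` hunk is again pure re-wrapping; the operation trades spatial patches for channels. A simplified sketch with toy values (the flatten back to a patch sequence at the end is an assumption of this toy version):

```python
import torch

def pixel_shuffle(x: torch.Tensor, ratio: float) -> torch.Tensor:
    b, s, c = x.shape
    side = int(s**0.5)                       # patches form a square grid
    x = x.view(b, side, side, -1)
    x = x.view(b, side, int(side * ratio), int(c / ratio))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(b, int(side * ratio), int(side * ratio), int(c / (ratio**2)))
    x = x.permute(0, 2, 1, 3).contiguous()
    return x.view(b, -1, x.shape[-1])

patches = torch.randn(2, 16, 8)              # 4x4 grid of 8-dim patches
print(pixel_shuffle(patches, 0.5).shape)     # torch.Size([2, 4, 32]): 4x fewer patches, 4x channels
```
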
@@ -827,14 +926,19 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
-        self.inner_dim = int(config.projector_input_dim // (self.pixel_shuffle_ratio**2))
+        self.inner_dim = int(
+            config.projector_input_dim // (self.pixel_shuffle_ratio**2)
+        )
         self.output_dim = config.projector_output_dim
-        self.mlp = Llama4VisionMLP2(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.mlp = Llama4VisionMLP2(
+            prefix=f"{prefix}.mlp", config=config, weights=weights
+        )
 
     def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
         encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
         return self.mlp(encoded_patches)
 
+
 # TODO there is a different RoPE for vision encoder, defined as below
 def vision_reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):
     ndim = query.ndim
@@ -877,8 +981,7 @@ class Llama4VisionAttention(nn.Module):
         hidden_shape = (*input_shape, -1, self.head_dim)
 
         qkv = self.qkv_proj(hidden_states)
 
-
         query_states, key_states, value_states = qkv.split(
             [
                 self.head_dim * self.num_heads,
@@ -890,18 +993,24 @@ class Llama4VisionAttention(nn.Module):
         query_states = query_states.view(hidden_shape)
         key_states = key_states.view(hidden_shape)
         value_states = value_states.view(hidden_shape)
 
-        query_states, key_states = apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci)
+        query_states, key_states = apply_rotary_emb(
+            query_states, key_states, freqs_ci=freqs_ci
+        )
 
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-
         attn_output = F.scaled_dot_product_attention(
-            query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False, dropout_p=0
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            is_causal=False,
+            dropout_p=0,
         )
 
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
@@ -920,7 +1029,6 @@ class Llama4VisionMLP(nn.Module):
             prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
         )
 
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.fc1(hidden_states)
         hidden_states = self.activation_fn(hidden_states)
@@ -987,17 +1095,21 @@ class Llama4VisionEncoder(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
-        self.layers = nn.ModuleList([
-            Llama4VisionEncoderLayer(prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights)
-            for layer_id in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                Llama4VisionEncoderLayer(
+                    prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights
+                )
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
         self.gradient_checkpointing = False
         self.config = config
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         freqs_ci: torch.Tensor,  # TODO move this to an attribute instead of keeping it around
         attention_mask: Optional[torch.Tensor] = None,
     ) -> Union[Tuple, BaseModelOutput]:
 
@@ -1024,7 +1136,6 @@ class Llama4UnfoldConvolution(nn.Module):
             config=config, prefix=f"{prefix}.linear", weights=weights, bias=False
         )
 
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
         hidden_states = hidden_states.permute(0, 2, 1)
@@ -1037,27 +1148,37 @@ class Llama4VisionRotaryEmbedding(nn.Module):
         super().__init__()
         # Calculate image grid indices
         idx = config.image_size // config.patch_size
-        img_idx = torch.arange(idx**2, dtype=torch.int32, device=weights.device).reshape(idx**2, 1)
+        img_idx = torch.arange(
+            idx**2, dtype=torch.int32, device=weights.device
+        ).reshape(idx**2, 1)
         img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
 
         img_idx[-1, -1] = -2  # ID_CLS_TOKEN
         # Calculate x and y coordinates
         frequencies_x = img_idx % idx  # x coordinates
-        frequencies_y = torch.div(img_idx, idx, rounding_mode='floor')  # y coordinates
+        frequencies_y = torch.div(img_idx, idx, rounding_mode="floor")  # y coordinates
         # Calculate frequency components
         freq_dim = config.hidden_size // config.num_attention_heads // 2
-        rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2, device=weights.device)[: (freq_dim // 2)].float() / freq_dim))
+        rope_freq = 1.0 / (
+            config.rope_theta
+            ** (
+                torch.arange(0, freq_dim, 2, device=weights.device)[
+                    : (freq_dim // 2)
+                ].float()
+                / freq_dim
+            )
+        )
 
         # Compute frequencies for x and y directions
-        freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :])
+        freqs_x = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
         freqs_x = freqs_x.repeat_interleave(2, dim=-1)
-        freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :])
+        freqs_y = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
         freqs_y = freqs_y.repeat_interleave(2, dim=-1)
 
         # Combine frequencies and mask special tokens
         freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
         freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
 
         freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
         self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2
 
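The vision-RoPE hunk re-wraps the inverse-frequency table; it is the familiar 1 / theta^(2i/d) schedule, later indexed by the x/y grid coordinates. Just the table, with hypothetical sizes:

```python
import torch

rope_theta, head_dim = 10000.0, 64
freq_dim = head_dim // 2
rope_freq = 1.0 / (
    rope_theta
    ** (torch.arange(0, freq_dim, 2)[: freq_dim // 2].float() / freq_dim)
)
print(rope_freq.shape)  # (freq_dim // 2,) frequencies, highest frequency first
```
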
@@ -1090,9 +1211,10 @@ class Llama4VisionModel(nn.Module):
         )
 
         self.positional_embedding_vlm = nn.Parameter(
-            weights.get_tensor(f"{prefix}.positional_embedding_vlm"), requires_grad=False
+            weights.get_tensor(f"{prefix}.positional_embedding_vlm"),
+            requires_grad=False,
         )
 
         self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights)
 
         # layer norms
@@ -1126,22 +1248,31 @@ class Llama4VisionModel(nn.Module):
 
         # Add cls token
         hidden_state = hidden_state.reshape(
-            batch_size_times_num_tiles * num_concurrent_media * num_chunks, num_patches, hidden_dim
+            batch_size_times_num_tiles * num_concurrent_media * num_chunks,
+            num_patches,
+            hidden_dim,
+        )
+        class_embedding = self.class_embedding.expand(
+            hidden_state.shape[0], 1, hidden_state.shape[-1]
         )
-        class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1])
         hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
         num_patches += 1
 
         # Position embeddings
         hidden_state = hidden_state.reshape(
-            batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim
+            batch_size_times_num_tiles * num_concurrent_media,
+            num_chunks,
+            num_patches,
+            hidden_dim,
+        )
+        positional_embedding = self.positional_embedding_vlm.to(
+            dtype=hidden_state.dtype, device=hidden_state.device
         )
-        positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device)
         hidden_state = hidden_state + positional_embedding
         hidden_state = self.layernorm_pre(hidden_state)
         hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim)
         freqs_ci = self.rotary_embedding(pixel_values)
 
         hidden_state = self.model(
             hidden_state,
             attention_mask=None,
@@ -1156,6 +1287,7 @@ class Llama4VisionModel(nn.Module):
         hidden_state = self.vision_adapter(hidden_state)
         return hidden_state
 
+
 class Llama4ForConditionalGeneration(nn.Module):
 
     def __init__(self, prefix: str, config, weights):
@@ -1170,16 +1302,18 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.vision_model = Llama4VisionModel(
             prefix="vision_model", config=config.vision_config, weights=weights
         )
 
         self.multi_modal_projector = Llama4MultiModalProjector(
             prefix="multi_modal_projector", config=config, weights=weights
         )
 
         self.text_model = Llama4ForCausalLM(
             prefix="language_model", config=config.text_config, weights=weights
         )
         self.vocab_size = config.text_config.vocab_size
-        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        self.pad_token_id = (
+            self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        )
         self.config = config
         self.dtype = weights.dtype
         self.device = weights.device
@@ -1208,7 +1342,9 @@ class Llama4ForConditionalGeneration(nn.Module):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         if vision_feature_select_strategy not in ["default", "full"]:
-            raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
+            raise ValueError(
+                f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
+            )
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         hidden_state = self.vision_model(pixel_values)
         return hidden_state
@@ -1232,7 +1368,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
         **lm_kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         def _get_padding_mask(input_ids, pad_token_id=0):
             return (input_ids != pad_token_id).long()
 
@@ -1275,11 +1411,15 @@ class Llama4ForConditionalGeneration(nn.Module):
                 f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
             )
 
-            expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
-            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat)
+            expanded_mask = final_mask_1d.unsqueeze(-1).expand(
+                -1, inputs_embeds.size(-1)
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(
+                expanded_mask, projected_vision_flat
+            )
             inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
 
-        logits, speculative_logits= self.text_model(
+        logits, speculative_logits = self.text_model(
             inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
@@ -1289,7 +1429,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             hpu_attention_meta,
             adapter_data,
             lm_head_indices,
-            attention_mask
+            attention_mask,
         )
 
         return logits, speculative_logits
@@ -1,18 +1,31 @@
 from packaging.version import Version
 from packaging import version
 import subprocess
 
+
 def get_driver_version():
     """
     Returns the driver version.
     """
     # Enable console printing for `hl-smi` check
     output = subprocess.run(
-        "hl-smi", shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env={"ENABLE_CONSOLE": "true"}
+        "hl-smi",
+        shell=True,
+        text=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        env={"ENABLE_CONSOLE": "true"},
     )
     if output.returncode == 0 and output.stdout:
-        return version.parse(output.stdout.split("\n")[2].replace(" ", "").split(":")[1][:-1].split("-")[0])
+        return version.parse(
+            output.stdout.split("\n")[2]
+            .replace(" ", "")
+            .split(":")[1][:-1]
+            .split("-")[0]
+        )
     return None
 
+
 MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")
 
+
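The reformatted `get_driver_version` parses the third line of `hl-smi` output. A hedged illustration of the same string handling on a made-up line (the real `hl-smi` format may differ between driver releases):

```python
from packaging import version

# Made-up line purely to show the string handling; not real hl-smi output.
line = "Driver Version : 1.19.0-123"
parsed = version.parse(line.replace(" ", "").split(":")[1][:-1].split("-")[0])
print(parsed)  # 1.19.0
```
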
@@ -315,7 +315,6 @@ class Weights:
             tensors_slices += range(block_offset + start, block_offset + stop)
             block_offset += block_size
 
-
         if dim == 0:
             tensor = slice_[tensors_slices, ...]
         elif dim == 1 or dim == -2: