Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
Fix pre-commit errors
Signed-off-by: yuanwu <yuan.wu@intel.com>
Parent: d98116db6e
Commit: 7533b993d5
@@ -850,14 +850,17 @@ def get_model(
)

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()
if SDP_ON_BF16 == 1:
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
if model_type == "gpt_bigcode":
from text_generation_server.models.starcoder import StarCoder

return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
if model_type == "bloom":
from text_generation_server.models.bloom import BLOOM

return BLOOM(
model_id=model_id,
revision=revision,

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Optional, Tuple, Union, Type
from typing import List, Optional, Tuple, Union

import torch
import math
@@ -47,7 +47,6 @@ from text_generation_server.layers.attention import (
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
load_attention,
FlashLlamaAttention,
LlamaMLP,
)

@@ -58,6 +57,7 @@ def reshape_for_broadcast(freqs: torch.Tensor, target):
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target)]
return freqs.view(*shape)


def apply_rotary_emb(
query: torch.Tensor,
key: torch.Tensor,

@@ -99,7 +99,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

@@ -108,11 +110,18 @@ class Llama4TextExperts(nn.Module):
super().__init__()
self.process_group = weights.process_group
self.num_experts = config.num_local_experts
self.intermediate_size = config.intermediate_size // weights.process_group.size()
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
self.gate_up_proj = nn.Parameter(weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), requires_grad=False)
self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False)
self.gate_up_proj = nn.Parameter(
weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2),
requires_grad=False,
)
self.down_proj = nn.Parameter(
weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False
)
self.act_fn = ACT2FN[config.hidden_act]

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

@@ -171,9 +180,7 @@ class Llama4TextMLP(nn.Module):
def forward(self, x):
gate_up_states = self.gate_up_proj(x)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
)
return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])


class Llama4TextL2Norm(torch.nn.Module):
@@ -202,9 +209,15 @@ class Llama4TextMoe(nn.Module):
self.top_k = config.num_experts_per_tok
self.hidden_dim = config.hidden_size
self.num_experts = config.num_local_experts
self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights)
self.router = FastLinear.load(config=config, prefix=f"{prefix}.router", weights=weights, bias=False)
self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights)
self.experts = Llama4TextExperts(
config=config, prefix=f"{prefix}.experts", weights=weights
)
self.router = FastLinear.load(
config=config, prefix=f"{prefix}.router", weights=weights, bias=False
)
self.shared_expert = Llama4TextMLP(
config=config, prefix=f"{prefix}.shared_expert", weights=weights
)
self.process_group = weights.process_group

def forward(self, hidden_states, adapter_data):

@@ -215,16 +228,19 @@ class Llama4TextMoe(nn.Module):

router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
router_scores = (
torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1)
torch.full_like(router_logits, float("-inf"))
.scatter_(1, router_indices, router_top_value)
.transpose(0, 1)
)
# We do this to make sure we have -inf for non topK tokens before going through the !
# Here we are just creating a tensor to index each and every single one of the hidden states. Let s maybe register a buffer for this!
router_indices = (
torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
torch.arange(tokens_per_expert, device=hidden_states.device)
.view(1, -1)
.expand(router_scores.size(0), -1)
)
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)


router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
routed_in = torch.gather(
input=hidden_states,

@@ -232,7 +248,6 @@ class Llama4TextMoe(nn.Module):
index=router_indices,
).to(hidden_states.device)


# we gather inputs corresponding to each expert based on the router indices
routed_in = routed_in * router_scores.reshape(-1, 1)
routed_out = self.experts(routed_in)

@@ -241,7 +256,9 @@ class Llama4TextMoe(nn.Module):
# now that we finished expert computation -> we scatter add because we gathered previously
# we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
# this scales a lot better if you do EP!
out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim))
out.scatter_add_(
dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim)
)
return out

@@ -262,10 +279,15 @@ class Llama4TextRotaryEmbedding(nn.Module):
self.original_inv_freq = self.inv_freq

def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
origin_device = x.device
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
device_type = (
x.device.type
if isinstance(x.device.type, str) and x.device.type != "mps"
else "cpu"
)
inv_freq_expanded = inv_freq_expanded.to(device_type)
position_ids_expanded = position_ids_expanded.to(device_type)
with torch.autocast(device_type=device_type, enabled=False): # Force float32

@@ -286,8 +308,12 @@ class Llama4TextAttention(FlashLlamaAttention):
super().__init__(layer_idx, prefix, config, weights)
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5
self.attn_scale = config.attn_scale
self.floor_scale = config.floor_scale

@@ -347,7 +373,11 @@ class Llama4TextAttention(FlashLlamaAttention):
# Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers
if self.attn_temperature_tuning and not self.use_rope:
attn_scales = (
torch.log(torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
torch.log(
torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0
)
* self.attn_scale
+ 1.0
)
attn_scales = attn_scales.view(*input_shape, 1, 1)
query_states = (query_states * attn_scales).to(query_states.dtype)
@@ -355,9 +385,15 @@ class Llama4TextAttention(FlashLlamaAttention):
# Prefill
if cu_seqlen_prefill is not None:
# sdpa
query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(1, 2)
key = key_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value = value_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(
1, 2
)
key = key_states.view(
bs, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value = value_states.view(
bs, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
key = repeat_kv(key, self.num_key_value_groups)
value = repeat_kv(value, self.num_key_value_groups)


@@ -402,7 +438,9 @@ class Llama4TextDecoderLayer(nn.Module):
def __init__(self, prefix, config, weights, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Llama4TextAttention(f"{prefix}.self_attn", config, weights, layer_idx)
self.self_attn = Llama4TextAttention(
f"{prefix}.self_attn", config, weights, layer_idx
)
self.use_chunked_attention = int((layer_idx + 1) % 4 != 0) # <=> use rope
self.is_moe_layer = layer_idx in config.moe_layers
if self.is_moe_layer: # the 128E model interleaves dense / sparse

@@ -434,7 +472,9 @@ class Llama4TextDecoderLayer(nn.Module):
chunk_causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
hidden_states, _ = self.input_layernorm(hidden_states)


@@ -465,6 +505,7 @@ class Llama4TextDecoderLayer(nn.Module):
hidden_states = residual + hidden_states.view(residual.shape)
return hidden_states


class Llama4TextModel(nn.Module):

def __init__(self, prefix, config, weights):
@@ -473,9 +514,19 @@ class Llama4TextModel(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = TensorParallelEmbedding(prefix=f"{prefix}.embed_tokens", weights=weights)
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[Llama4TextDecoderLayer(prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
[
Llama4TextDecoderLayer(
prefix=f"{prefix}.layers.{layer_idx}",
config=config,
weights=weights,
layer_idx=layer_idx,
)
for layer_idx in range(config.num_hidden_layers)
]
)

# self.norm = Llama4TextRMSNorm(prefix=f"{prefix}.norm", config=config, weights=weights)

@@ -510,7 +561,12 @@ class Llama4TextModel(nn.Module):
position_ids = cache_position.unsqueeze(0)

causal_mask, chunk_causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds.view(bs, int(seq_len), -1), cache_position, None, output_attentions=False, use_cache=False
attention_mask,
inputs_embeds.view(bs, int(seq_len), -1),
cache_position,
None,
output_attentions=False,
use_cache=False,
)

freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))

@@ -530,7 +586,6 @@ class Llama4TextModel(nn.Module):
hpu_attention_meta=hpu_attention_meta,
)


hidden_states, _ = self.norm(hidden_states)

return hidden_states
@@ -547,7 +602,10 @@ class Llama4TextModel(nn.Module):
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask, attention_mask # flash does not support chunked attn TODO support flash
return (
attention_mask,
attention_mask,
) # flash does not support chunked attn TODO support flash
return None, None

if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]:

@@ -561,7 +619,11 @@ class Llama4TextModel(nn.Module):
if past_key_values is not None:
full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
else:
full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length
full_cache_length = (
attention_mask.shape[-1]
if attention_mask is not None
else sequence_length
)

cond1 = first_cache_position >= attention_chunk_size
cond2 = (first_cache_position < attention_chunk_size) & (

@@ -571,7 +633,9 @@ class Llama4TextModel(nn.Module):
torch.where(
cond1,
attention_chunk_size + sequence_length - 1,
torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
torch.where(
cond2, first_cache_position + sequence_length, attention_chunk_size
),
)
if use_cache
else full_cache_length

@@ -586,7 +650,7 @@ class Llama4TextModel(nn.Module):
dtype=dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
device=device
device=device,
)
if full_cache_length > self.config.attention_chunk_size:
start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
@@ -598,24 +662,37 @@ class Llama4TextModel(nn.Module):
device=device,
)

local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well
local_attention_mask = attention_mask[
:, start_idx:end_idx
] # offset here as well
# It may be smaller than attention_chunk_size -> pad it
requires_padding = local_attention_mask.shape[-1] < attention_chunk_size
if requires_padding:
local_attention_mask = nn.functional.pad(
local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1])
local_attention_mask,
(0, attention_chunk_size - local_attention_mask.shape[-1]),
)
# Depending on the padding, take the query tokens from the end or the cache_position
if not requires_padding:
chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :]
chunked_attention_mask = chunked_attention_mask[
None, None, -sequence_length:, :
]
else:
chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :]
chunked_attention_mask = chunked_attention_mask[
None, None, cache_position, :
]

chunked_attention_mask = chunked_attention_mask.expand(input_tensor.shape[0], -1, -1, -1)
chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :]
chunked_attention_mask = chunked_attention_mask.expand(
input_tensor.shape[0], -1, -1, -1
)
chunked_attention_mask = (
chunked_attention_mask * local_attention_mask[:, None, None, :]
)
if self.config._attn_implementation == "eager":
min_dtype = torch.finfo(dtype).min
chunked_attention_mask = torch.where(chunked_attention_mask == 0, min_dtype, 0.0).to(dtype)
chunked_attention_mask = torch.where(
chunked_attention_mask == 0, min_dtype, 0.0
).to(dtype)

if (
self.config._attn_implementation == "sdpa"
@@ -628,10 +705,15 @@ class Llama4TextModel(nn.Module):
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
causal_mask = AttentionMaskConverter._unmask_unattended(
causal_mask, min_dtype
)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
if (
self.config._attn_implementation == "sdpa"
and chunked_attention_mask is not None
):
chunked_attention_mask = chunked_attention_mask.bool()
causal_mask = causal_mask.bool()
if AttentionMaskConverter._ignore_causal_mask_sdpa(

@@ -661,7 +743,8 @@ class Llama4TextModel(nn.Module):
"""
arange_vector = torch.arange(start, end, device=device)
block_pos = torch.abs(
arange_vector.unsqueeze(0) // attention_chunk_size - arange_vector.unsqueeze(1) // attention_chunk_size
arange_vector.unsqueeze(0) // attention_chunk_size
- arange_vector.unsqueeze(1) // attention_chunk_size
)
token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1)
mask = (block_pos == 0) & (token_pos <= 0)
@@ -706,25 +789,33 @@ class Llama4TextModel(nn.Module):
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.to(device).reshape(-1, 1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.to(device).reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(device)
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
:, None, None, :
].to(device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)

return causal_mask


class Llama4ForCausalLM(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
@@ -791,7 +882,10 @@ class Llama4VisionMLP2(torch.nn.Module):
hidden_states = self.activation_fn(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
return self.activation_fn(hidden_states) # TODO: check if we need to apply activation again
return self.activation_fn(
hidden_states
) # TODO: check if we need to apply activation again


class Llama4MultiModalProjector(nn.Module):
def __init__(self, prefix, config, weights):

@@ -812,10 +906,15 @@ def pixel_shuffle(input_tensor, shuffle_ratio):

input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
batch_size, height, width, channels = input_tensor.size()
reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))
reshaped_tensor = input_tensor.view(
batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
)
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
reshaped_tensor = reshaped_tensor.view(
batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2))
batch_size,
int(height * shuffle_ratio),
int(width * shuffle_ratio),
int(channels / (shuffle_ratio**2)),
)
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

@@ -827,14 +926,19 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
self.inner_dim = int(config.projector_input_dim // (self.pixel_shuffle_ratio**2))
self.inner_dim = int(
config.projector_input_dim // (self.pixel_shuffle_ratio**2)
)
self.output_dim = config.projector_output_dim
self.mlp = Llama4VisionMLP2(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.mlp = Llama4VisionMLP2(
prefix=f"{prefix}.mlp", config=config, weights=weights
)

def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
return self.mlp(encoded_patches)


# TODO there is a different RoPE for vision encoder, defined as below
def vision_reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):
ndim = query.ndim
@@ -878,7 +982,6 @@ class Llama4VisionAttention(nn.Module):

qkv = self.qkv_proj(hidden_states)


query_states, key_states, value_states = qkv.split(
[
self.head_dim * self.num_heads,

@@ -891,15 +994,21 @@ class Llama4VisionAttention(nn.Module):
key_states = key_states.view(hidden_shape)
value_states = value_states.view(hidden_shape)

query_states, key_states = apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci)
query_states, key_states = apply_rotary_emb(
query_states, key_states, freqs_ci=freqs_ci
)

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)


attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False, dropout_p=0
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
dropout_p=0,
)

attn_output = attn_output.transpose(1, 2)
@@ -920,7 +1029,6 @@ class Llama4VisionMLP(nn.Module):
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
)


def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)

@@ -987,10 +1095,14 @@ class Llama4VisionEncoder(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Llama4VisionEncoderLayer(prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights)
self.layers = nn.ModuleList(
[
Llama4VisionEncoderLayer(
prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights
)
for layer_id in range(config.num_hidden_layers)
])
]
)
self.gradient_checkpointing = False
self.config = config


@@ -1024,7 +1136,6 @@ class Llama4UnfoldConvolution(nn.Module):
config=config, prefix=f"{prefix}.linear", weights=weights, bias=False
)


def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.unfold(hidden_states)
hidden_states = hidden_states.permute(0, 2, 1)
@@ -1037,21 +1148,31 @@ class Llama4VisionRotaryEmbedding(nn.Module):
super().__init__()
# Calculate image grid indices
idx = config.image_size // config.patch_size
img_idx = torch.arange(idx**2, dtype=torch.int32, device=weights.device).reshape(idx**2, 1)
img_idx = torch.arange(
idx**2, dtype=torch.int32, device=weights.device
).reshape(idx**2, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)

img_idx[-1, -1] = -2 # ID_CLS_TOKEN
# Calculate x and y coordinates
frequencies_x = img_idx % idx # x coordinates
frequencies_y = torch.div(img_idx, idx, rounding_mode='floor') # y coordinates
frequencies_y = torch.div(img_idx, idx, rounding_mode="floor") # y coordinates
# Calculate frequency components
freq_dim = config.hidden_size // config.num_attention_heads // 2
rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2, device=weights.device)[: (freq_dim // 2)].float() / freq_dim))
rope_freq = 1.0 / (
config.rope_theta
** (
torch.arange(0, freq_dim, 2, device=weights.device)[
: (freq_dim // 2)
].float()
/ freq_dim
)
)

# Compute frequencies for x and y directions
freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :])
freqs_x = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
freqs_x = freqs_x.repeat_interleave(2, dim=-1)
freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :])
freqs_y = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
freqs_y = freqs_y.repeat_interleave(2, dim=-1)

# Combine frequencies and mask special tokens
@@ -1090,7 +1211,8 @@ class Llama4VisionModel(nn.Module):
)

self.positional_embedding_vlm = nn.Parameter(
weights.get_tensor(f"{prefix}.positional_embedding_vlm"), requires_grad=False
weights.get_tensor(f"{prefix}.positional_embedding_vlm"),
requires_grad=False,
)

self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights)

@@ -1126,17 +1248,26 @@ class Llama4VisionModel(nn.Module):

# Add cls token
hidden_state = hidden_state.reshape(
batch_size_times_num_tiles * num_concurrent_media * num_chunks, num_patches, hidden_dim
batch_size_times_num_tiles * num_concurrent_media * num_chunks,
num_patches,
hidden_dim,
)
class_embedding = self.class_embedding.expand(
hidden_state.shape[0], 1, hidden_state.shape[-1]
)
class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1])
hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
num_patches += 1

# Position embeddings
hidden_state = hidden_state.reshape(
batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim
batch_size_times_num_tiles * num_concurrent_media,
num_chunks,
num_patches,
hidden_dim,
)
positional_embedding = self.positional_embedding_vlm.to(
dtype=hidden_state.dtype, device=hidden_state.device
)
positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device)
hidden_state = hidden_state + positional_embedding
hidden_state = self.layernorm_pre(hidden_state)
hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim)
@@ -1156,6 +1287,7 @@ class Llama4VisionModel(nn.Module):
hidden_state = self.vision_adapter(hidden_state)
return hidden_state


class Llama4ForConditionalGeneration(nn.Module):

def __init__(self, prefix: str, config, weights):

@@ -1179,7 +1311,9 @@ class Llama4ForConditionalGeneration(nn.Module):
prefix="language_model", config=config.text_config, weights=weights
)
self.vocab_size = config.text_config.vocab_size
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1
)
self.config = config
self.dtype = weights.dtype
self.device = weights.device

@@ -1208,7 +1342,9 @@ class Llama4ForConditionalGeneration(nn.Module):
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}")
raise ValueError(
f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
)
kwargs = {k: v for k, v in kwargs.items() if v is not None}
hidden_state = self.vision_model(pixel_values)
return hidden_state
@@ -1275,8 +1411,12 @@ class Llama4ForConditionalGeneration(nn.Module):
f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
)

expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat)
expanded_mask = final_mask_1d.unsqueeze(-1).expand(
-1, inputs_embeds.size(-1)
)
inputs_embeds = inputs_embeds.masked_scatter(
expanded_mask, projected_vision_flat
)
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)

logits, speculative_logits = self.text_model(

@@ -1289,7 +1429,7 @@ class Llama4ForConditionalGeneration(nn.Module):
hpu_attention_meta,
adapter_data,
lm_head_indices,
attention_mask
attention_mask,
)

return logits, speculative_logits
@@ -1,18 +1,31 @@
from packaging.version import Version
from packaging import version
import subprocess


def get_driver_version():
"""
Returns the driver version.
"""
# Enable console printing for `hl-smi` check
output = subprocess.run(
"hl-smi", shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env={"ENABLE_CONSOLE": "true"}
"hl-smi",
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env={"ENABLE_CONSOLE": "true"},
)
if output.returncode == 0 and output.stdout:
return version.parse(output.stdout.split("\n")[2].replace(" ", "").split(":")[1][:-1].split("-")[0])
return version.parse(
output.stdout.split("\n")[2]
.replace(" ", "")
.split(":")[1][:-1]
.split("-")[0]
)
return None


MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")

@@ -315,7 +315,6 @@ class Weights:
tensors_slices += range(block_offset + start, block_offset + stop)
block_offset += block_size


if dim == 0:
tensor = slice_[tensors_slices, ...]
elif dim == 1 or dim == -2: