diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index d8ea0077..98b5d6a7 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -850,14 +850,17 @@ def get_model( ) from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi + adapt_transformers_to_gaudi() if SDP_ON_BF16 == 1: torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) if model_type == "gpt_bigcode": from text_generation_server.models.starcoder import StarCoder + return StarCoder(model_id=model_id, revision=revision, dtype=dtype) if model_type == "bloom": from text_generation_server.models.bloom import BLOOM + return BLOOM( model_id=model_id, revision=revision, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py index 2f9b5aeb..ca354934 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Optional, Tuple, Union, Type +from typing import List, Optional, Tuple, Union import torch import math @@ -47,7 +47,6 @@ from text_generation_server.layers.attention import ( HPUPagedAttentionMetadata, ) from text_generation_server.models.custom_modeling.flash_llama_modeling import ( - load_attention, FlashLlamaAttention, LlamaMLP, ) @@ -58,6 +57,7 @@ def reshape_for_broadcast(freqs: torch.Tensor, target): shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target)] return freqs.view(*shape) + def apply_rotary_emb( query: torch.Tensor, key: torch.Tensor, @@ -65,12 +65,12 @@ def apply_rotary_emb( ) -> Tuple[torch.Tensor, torch.Tensor]: query_shape = query.shape key_shape = key.shape - cos_emb,sin_emb = freqs_ci.split(1, dim=-1) - + cos_emb, sin_emb = freqs_ci.split(1, dim=-1) + if len(query.shape) == 3: query = query.unsqueeze(0) key = key.unsqueeze(0) - + query_reshaped = query.float().reshape(*query.shape[:-1], -1, 2) key_reshaped = key.float().reshape(*key.shape[:-1], -1, 2) q_shape = query_reshaped.shape[:-1] @@ -78,12 +78,12 @@ def apply_rotary_emb( sin_emb = reshape_for_broadcast(sin_emb, q_shape) x_q, y_q = query_reshaped.unbind(-1) x_k, y_k = key_reshaped.unbind(-1) - + x_q_rot = x_q * cos_emb - y_q * sin_emb y_q_rot = x_q * sin_emb + y_q * cos_emb x_k_rot = x_k * cos_emb - y_k * sin_emb y_k_rot = x_k * sin_emb + y_k * cos_emb - + query_out = torch.stack([x_q_rot, y_q_rot], dim=-1).flatten(-2) key_out = torch.stack([x_k_rot, y_k_rot], dim=-1).flatten(-2) query_out = query_out.view(*query_shape) @@ -99,7 +99,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -108,11 +110,18 @@ class Llama4TextExperts(nn.Module): super().__init__() 
self.process_group = weights.process_group self.num_experts = config.num_local_experts - self.intermediate_size = config.intermediate_size // weights.process_group.size() + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), requires_grad=False) - self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False) + self.gate_up_proj = nn.Parameter( + weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), + requires_grad=False, + ) + self.down_proj = nn.Parameter( + weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False + ) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -128,18 +137,18 @@ class Llama4TextExperts(nn.Module): Returns: torch.Tensor """ - gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2*self.expert_dim) + gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2 * self.expert_dim) down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1) hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, gate_up_proj) gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors next_states = torch.bmm((up * self.act_fn(gate)), down_proj) next_states = next_states.view(-1, self.hidden_size) - + # Reduce sum if self.process_group.size() > 1: torch.distributed.all_reduce(next_states, group=self.process_group) - + return next_states @@ -171,9 +180,7 @@ class Llama4TextMLP(nn.Module): def forward(self, x): gate_up_states = self.gate_up_proj(x) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1] - ) + return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]) class Llama4TextL2Norm(torch.nn.Module): @@ -202,11 +209,17 @@ class Llama4TextMoe(nn.Module): self.top_k = config.num_experts_per_tok self.hidden_dim = config.hidden_size self.num_experts = config.num_local_experts - self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights) - self.router = FastLinear.load(config=config, prefix=f"{prefix}.router", weights=weights, bias=False) - self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights) + self.experts = Llama4TextExperts( + config=config, prefix=f"{prefix}.experts", weights=weights + ) + self.router = FastLinear.load( + config=config, prefix=f"{prefix}.router", weights=weights, bias=False + ) + self.shared_expert = Llama4TextMLP( + config=config, prefix=f"{prefix}.shared_expert", weights=weights + ) self.process_group = weights.process_group - + def forward(self, hidden_states, adapter_data): seq_len, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_dim) @@ -215,16 +228,19 @@ class Llama4TextMoe(nn.Module): router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) router_scores = ( - torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1) + torch.full_like(router_logits, float("-inf")) + .scatter_(1, router_indices, router_top_value) + .transpose(0, 1) ) # We do this to make sure we have -inf for non topK tokens before going through the ! 
# Here we are just creating a tensor to index each and every single one of the hidden states. Let s maybe register a buffer for this! router_indices = ( - torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1) + torch.arange(tokens_per_expert, device=hidden_states.device) + .view(1, -1) + .expand(router_scores.size(0), -1) ) router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) - router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim) routed_in = torch.gather( input=hidden_states, @@ -232,16 +248,17 @@ class Llama4TextMoe(nn.Module): index=router_indices, ).to(hidden_states.device) - # we gather inputs corresponding to each expert based on the router indices routed_in = routed_in * router_scores.reshape(-1, 1) routed_out = self.experts(routed_in) out = self.shared_expert(hidden_states) - + # now that we finished expert computation -> we scatter add because we gathered previously # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound # this scales a lot better if you do EP! - out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim)) + out.scatter_add_( + dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim) + ) return out @@ -262,10 +279,15 @@ class Llama4TextRotaryEmbedding(nn.Module): self.original_inv_freq = self.inv_freq def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) position_ids_expanded = position_ids[:, None, :].float() - origin_device = x.device - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) inv_freq_expanded = inv_freq_expanded.to(device_type) position_ids_expanded = position_ids_expanded.to(device_type) with torch.autocast(device_type=device_type, enabled=False): # Force float32 @@ -273,7 +295,7 @@ class Llama4TextRotaryEmbedding(nn.Module): cos = torch.cos(freqs) * self.attention_scaling sin = torch.sin(freqs) * self.attention_scaling cos = cos.reshape(-1, 1, cos.shape[-1]) - sin = sin.reshape(-1, 1, sin.shape[-1]) + sin = sin.reshape(-1, 1, sin.shape[-1]) freqs_cis = torch.cat([cos, sin], dim=-1) * self.attention_scaling freqs_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) return freqs_cis @@ -286,15 +308,19 @@ class Llama4TextAttention(FlashLlamaAttention): super().__init__(layer_idx, prefix, config, weights) self.config = config self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) self.scaling = self.head_dim**-0.5 self.attn_scale = config.attn_scale self.floor_scale = config.floor_scale self.attn_temperature_tuning = config.attn_temperature_tuning self.attention_dropout = config.attention_dropout self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers - + if self.config.use_qk_norm and self.use_rope: self.qk_norm = 
Llama4TextL2Norm(config.rms_norm_eps) @@ -323,7 +349,7 @@ class Llama4TextAttention(FlashLlamaAttention): ], dim=-1, ) - + query_states = query_states.view(hidden_shape) key_states = key_states.view(hidden_shape) value_states = value_states.view(hidden_shape) @@ -347,7 +373,11 @@ class Llama4TextAttention(FlashLlamaAttention): # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers if self.attn_temperature_tuning and not self.use_rope: attn_scales = ( - torch.log(torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0 + torch.log( + torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0 + ) + * self.attn_scale + + 1.0 ) attn_scales = attn_scales.view(*input_shape, 1, 1) query_states = (query_states * attn_scales).to(query_states.dtype) @@ -355,9 +385,15 @@ class Llama4TextAttention(FlashLlamaAttention): # Prefill if cu_seqlen_prefill is not None: # sdpa - query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(1, 2) - key = key_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value = value_states.view(bs, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose( + 1, 2 + ) + key = key_states.view( + bs, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value = value_states.view( + bs, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) key = repeat_kv(key, self.num_key_value_groups) value = repeat_kv(value, self.num_key_value_groups) @@ -395,14 +431,16 @@ class Llama4TextAttention(FlashLlamaAttention): attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output, adapter_data) - return attn_output + return attn_output class Llama4TextDecoderLayer(nn.Module): def __init__(self, prefix, config, weights, layer_idx): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Llama4TextAttention(f"{prefix}.self_attn", config, weights, layer_idx) + self.self_attn = Llama4TextAttention( + f"{prefix}.self_attn", config, weights, layer_idx + ) self.use_chunked_attention = int((layer_idx + 1) % 4 != 0) # <=> use rope self.is_moe_layer = layer_idx in config.moe_layers if self.is_moe_layer: # the 128E model interleaves dense / sparse @@ -411,15 +449,15 @@ class Llama4TextDecoderLayer(nn.Module): self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights) self.input_layernorm = FastRMSNorm.load( - prefix=f"{prefix}.input_layernorm", - weights=weights, - eps=config.rms_norm_eps, - ) + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) self.post_attention_layernorm = FastRMSNorm.load( - prefix=f"{prefix}.post_attention_layernorm", - weights=weights, - eps=config.rms_norm_eps, - ) + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) def forward( self, @@ -434,7 +472,9 @@ class Llama4TextDecoderLayer(nn.Module): chunk_causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: residual = hidden_states hidden_states, _ = self.input_layernorm(hidden_states) @@ -465,6 +505,7 @@ class Llama4TextDecoderLayer(nn.Module): hidden_states = residual + 
hidden_states.view(residual.shape) return hidden_states + class Llama4TextModel(nn.Module): def __init__(self, prefix, config, weights): @@ -473,12 +514,22 @@ class Llama4TextModel(nn.Module): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = TensorParallelEmbedding(prefix=f"{prefix}.embed_tokens", weights=weights) - self.layers = nn.ModuleList( - [Llama4TextDecoderLayer(prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)] + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights ) - - #self.norm = Llama4TextRMSNorm(prefix=f"{prefix}.norm", config=config, weights=weights) + self.layers = nn.ModuleList( + [ + Llama4TextDecoderLayer( + prefix=f"{prefix}.layers.{layer_idx}", + config=config, + weights=weights, + layer_idx=layer_idx, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + # self.norm = Llama4TextRMSNorm(prefix=f"{prefix}.norm", config=config, weights=weights) self.norm = FastRMSNorm.load( prefix=f"{prefix}.norm", weights=weights, @@ -500,7 +551,7 @@ class Llama4TextModel(nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - + hidden_states = inputs_embeds bs = seqlen.input_lengths.shape[0] seq_len = inputs_embeds.shape[0] / bs @@ -508,11 +559,16 @@ class Llama4TextModel(nn.Module): if position_ids is None: position_ids = cache_position.unsqueeze(0) - + causal_mask, chunk_causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds.view(bs, int(seq_len), -1), cache_position, None, output_attentions=False, use_cache=False + attention_mask, + inputs_embeds.view(bs, int(seq_len), -1), + cache_position, + None, + output_attentions=False, + use_cache=False, ) - + freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1)) for i, layer in enumerate(self.layers): @@ -530,9 +586,8 @@ class Llama4TextModel(nn.Module): hpu_attention_meta=hpu_attention_meta, ) - hidden_states, _ = self.norm(hidden_states) - + return hidden_states def _update_causal_mask( @@ -547,7 +602,10 @@ class Llama4TextModel(nn.Module): ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask, attention_mask # flash does not support chunked attn TODO support flash + return ( + attention_mask, + attention_mask, + ) # flash does not support chunked attn TODO support flash return None, None if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]: @@ -561,7 +619,11 @@ class Llama4TextModel(nn.Module): if past_key_values is not None: full_cache_length = past_key_values.get_max_cache_shape() or sequence_length else: - full_cache_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length + full_cache_length = ( + attention_mask.shape[-1] + if attention_mask is not None + else sequence_length + ) cond1 = first_cache_position >= attention_chunk_size cond2 = (first_cache_position < attention_chunk_size) & ( @@ -571,7 +633,9 @@ class Llama4TextModel(nn.Module): torch.where( cond1, attention_chunk_size + sequence_length - 1, - torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size), + torch.where( + cond2, first_cache_position + sequence_length, attention_chunk_size + ), ) if use_cache else full_cache_length @@ -586,7 +650,7 @@ class Llama4TextModel(nn.Module): dtype=dtype, 
cache_position=cache_position, batch_size=input_tensor.shape[0], - device=device + device=device, ) if full_cache_length > self.config.attention_chunk_size: start_idx = max(first_cache_position - attention_chunk_size + 1, 0) @@ -598,24 +662,37 @@ class Llama4TextModel(nn.Module): device=device, ) - local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well + local_attention_mask = attention_mask[ + :, start_idx:end_idx + ] # offset here as well # It may be smaller than attention_chunk_size -> pad it requires_padding = local_attention_mask.shape[-1] < attention_chunk_size if requires_padding: local_attention_mask = nn.functional.pad( - local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1]) + local_attention_mask, + (0, attention_chunk_size - local_attention_mask.shape[-1]), ) # Depending on the padding, take the query tokens from the end or the cache_position if not requires_padding: - chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :] + chunked_attention_mask = chunked_attention_mask[ + None, None, -sequence_length:, : + ] else: - chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :] + chunked_attention_mask = chunked_attention_mask[ + None, None, cache_position, : + ] - chunked_attention_mask = chunked_attention_mask.expand(input_tensor.shape[0], -1, -1, -1) - chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :] + chunked_attention_mask = chunked_attention_mask.expand( + input_tensor.shape[0], -1, -1, -1 + ) + chunked_attention_mask = ( + chunked_attention_mask * local_attention_mask[:, None, None, :] + ) if self.config._attn_implementation == "eager": min_dtype = torch.finfo(dtype).min - chunked_attention_mask = torch.where(chunked_attention_mask == 0, min_dtype, 0.0).to(dtype) + chunked_attention_mask = torch.where( + chunked_attention_mask == 0, min_dtype, 0.0 + ).to(dtype) if ( self.config._attn_implementation == "sdpa" @@ -628,10 +705,15 @@ class Llama4TextModel(nn.Module): # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
# Details: https://github.com/pytorch/pytorch/issues/110213 min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None: + if ( + self.config._attn_implementation == "sdpa" + and chunked_attention_mask is not None + ): chunked_attention_mask = chunked_attention_mask.bool() causal_mask = causal_mask.bool() if AttentionMaskConverter._ignore_causal_mask_sdpa( @@ -661,7 +743,8 @@ class Llama4TextModel(nn.Module): """ arange_vector = torch.arange(start, end, device=device) block_pos = torch.abs( - arange_vector.unsqueeze(0) // attention_chunk_size - arange_vector.unsqueeze(1) // attention_chunk_size + arange_vector.unsqueeze(0) // attention_chunk_size + - arange_vector.unsqueeze(1) // attention_chunk_size ) token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1) mask = (block_pos == 0) & (token_pos <= 0) @@ -706,25 +789,33 @@ class Llama4TextModel(nn.Module): else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.to(device).reshape(-1, 1) + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.to(device).reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(device) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ + :, None, None, : + ].to(device) padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) return causal_mask - class Llama4ForCausalLM(nn.Module): def __init__(self, prefix, config, weights): super().__init__() @@ -751,7 +842,7 @@ class Llama4ForCausalLM(nn.Module): lm_head_indices: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) hidden_states = self.model( inputs_embeds, @@ -791,7 +882,10 @@ class Llama4VisionMLP2(torch.nn.Module): hidden_states = self.activation_fn(hidden_states) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = self.fc2(hidden_states) - return self.activation_fn(hidden_states) # TODO: check if we need to apply activation again + return self.activation_fn( + hidden_states + ) # TODO: check if we need to apply activation again + class Llama4MultiModalProjector(nn.Module): def __init__(self, prefix, config, weights): @@ -812,10 +906,15 @@ def 
pixel_shuffle(input_tensor, shuffle_ratio): input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) batch_size, height, width, channels = input_tensor.size() - reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)) + reshaped_tensor = input_tensor.view( + batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio) + ) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() reshaped_tensor = reshaped_tensor.view( - batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2)) + batch_size, + int(height * shuffle_ratio), + int(width * shuffle_ratio), + int(channels / (shuffle_ratio**2)), ) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() @@ -827,14 +926,19 @@ class Llama4VisionPixelShuffleMLP(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.pixel_shuffle_ratio = config.pixel_shuffle_ratio - self.inner_dim = int(config.projector_input_dim // (self.pixel_shuffle_ratio**2)) + self.inner_dim = int( + config.projector_input_dim // (self.pixel_shuffle_ratio**2) + ) self.output_dim = config.projector_output_dim - self.mlp = Llama4VisionMLP2(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.mlp = Llama4VisionMLP2( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio) return self.mlp(encoded_patches) + # TODO there is a different RoPE for vision encoder, defined as below def vision_reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor): ndim = query.ndim @@ -877,8 +981,7 @@ class Llama4VisionAttention(nn.Module): hidden_shape = (*input_shape, -1, self.head_dim) qkv = self.qkv_proj(hidden_states) - - + query_states, key_states, value_states = qkv.split( [ self.head_dim * self.num_heads, @@ -890,18 +993,24 @@ class Llama4VisionAttention(nn.Module): query_states = query_states.view(hidden_shape) key_states = key_states.view(hidden_shape) value_states = value_states.view(hidden_shape) - - query_states, key_states = apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci) - + + query_states, key_states = apply_rotary_emb( + query_states, key_states, freqs_ci=freqs_ci + ) + query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = F.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False, dropout_p=0 + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=False, + dropout_p=0, ) - + attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -920,7 +1029,6 @@ class Llama4VisionMLP(nn.Module): prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) @@ -987,17 +1095,21 @@ class Llama4VisionEncoder(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.config = config - self.layers = nn.ModuleList([ - Llama4VisionEncoderLayer(prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights) - for layer_id in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + 
Llama4VisionEncoderLayer( + prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights + ) + for layer_id in range(config.num_hidden_layers) + ] + ) self.gradient_checkpointing = False self.config = config def forward( self, hidden_states: torch.Tensor, - freqs_ci: torch.Tensor, # TODO move this to an attribute instead of keeping it around + freqs_ci: torch.Tensor, # TODO move this to an attribute instead of keeping it around attention_mask: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutput]: @@ -1024,7 +1136,6 @@ class Llama4UnfoldConvolution(nn.Module): config=config, prefix=f"{prefix}.linear", weights=weights, bias=False ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.unfold(hidden_states) hidden_states = hidden_states.permute(0, 2, 1) @@ -1037,27 +1148,37 @@ class Llama4VisionRotaryEmbedding(nn.Module): super().__init__() # Calculate image grid indices idx = config.image_size // config.patch_size - img_idx = torch.arange(idx**2, dtype=torch.int32, device=weights.device).reshape(idx**2, 1) + img_idx = torch.arange( + idx**2, dtype=torch.int32, device=weights.device + ).reshape(idx**2, 1) img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) img_idx[-1, -1] = -2 # ID_CLS_TOKEN # Calculate x and y coordinates frequencies_x = img_idx % idx # x coordinates - frequencies_y = torch.div(img_idx, idx, rounding_mode='floor') # y coordinates + frequencies_y = torch.div(img_idx, idx, rounding_mode="floor") # y coordinates # Calculate frequency components freq_dim = config.hidden_size // config.num_attention_heads // 2 - rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2, device=weights.device)[: (freq_dim // 2)].float() / freq_dim)) - + rope_freq = 1.0 / ( + config.rope_theta + ** ( + torch.arange(0, freq_dim, 2, device=weights.device)[ + : (freq_dim // 2) + ].float() + / freq_dim + ) + ) + # Compute frequencies for x and y directions - freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]) + freqs_x = (frequencies_x + 1)[..., None] * rope_freq[None, None, :] freqs_x = freqs_x.repeat_interleave(2, dim=-1) - freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]) + freqs_y = (frequencies_y + 1)[..., None] * rope_freq[None, None, :] freqs_y = freqs_y.repeat_interleave(2, dim=-1) - + # Combine frequencies and mask special tokens freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2] freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) - + freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2 @@ -1090,9 +1211,10 @@ class Llama4VisionModel(nn.Module): ) self.positional_embedding_vlm = nn.Parameter( - weights.get_tensor(f"{prefix}.positional_embedding_vlm"), requires_grad=False + weights.get_tensor(f"{prefix}.positional_embedding_vlm"), + requires_grad=False, ) - + self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights) # layer norms @@ -1126,22 +1248,31 @@ class Llama4VisionModel(nn.Module): # Add cls token hidden_state = hidden_state.reshape( - batch_size_times_num_tiles * num_concurrent_media * num_chunks, num_patches, hidden_dim + batch_size_times_num_tiles * num_concurrent_media * num_chunks, + num_patches, + hidden_dim, + ) + class_embedding = self.class_embedding.expand( + hidden_state.shape[0], 1, hidden_state.shape[-1] ) - class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1]) hidden_state = torch.cat([hidden_state, class_embedding], dim=1) 
num_patches += 1 # Position embeddings hidden_state = hidden_state.reshape( - batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim + batch_size_times_num_tiles * num_concurrent_media, + num_chunks, + num_patches, + hidden_dim, + ) + positional_embedding = self.positional_embedding_vlm.to( + dtype=hidden_state.dtype, device=hidden_state.device ) - positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device) hidden_state = hidden_state + positional_embedding hidden_state = self.layernorm_pre(hidden_state) hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim) freqs_ci = self.rotary_embedding(pixel_values) - + hidden_state = self.model( hidden_state, attention_mask=None, @@ -1156,6 +1287,7 @@ class Llama4VisionModel(nn.Module): hidden_state = self.vision_adapter(hidden_state) return hidden_state + class Llama4ForConditionalGeneration(nn.Module): def __init__(self, prefix: str, config, weights): @@ -1170,16 +1302,18 @@ class Llama4ForConditionalGeneration(nn.Module): self.vision_model = Llama4VisionModel( prefix="vision_model", config=config.vision_config, weights=weights ) - + self.multi_modal_projector = Llama4MultiModalProjector( prefix="multi_modal_projector", config=config, weights=weights ) - + self.text_model = Llama4ForCausalLM( prefix="language_model", config=config.text_config, weights=weights ) self.vocab_size = config.text_config.vocab_size - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.pad_token_id = ( + self.config.pad_token_id if self.config.pad_token_id is not None else -1 + ) self.config = config self.dtype = weights.dtype self.device = weights.device @@ -1208,7 +1342,9 @@ class Llama4ForConditionalGeneration(nn.Module): image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
""" if vision_feature_select_strategy not in ["default", "full"]: - raise ValueError(f"Unexpected select feature strategy: {self.vision_feature_select_strategy}") + raise ValueError( + f"Unexpected select feature strategy: {self.vision_feature_select_strategy}" + ) kwargs = {k: v for k, v in kwargs.items() if v is not None} hidden_state = self.vision_model(pixel_values) return hidden_state @@ -1232,7 +1368,7 @@ class Llama4ForConditionalGeneration(nn.Module): adapter_data: Optional[torch.Tensor] = None, **lm_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - + def _get_padding_mask(input_ids, pad_token_id=0): return (input_ids != pad_token_id).long() @@ -1275,11 +1411,15 @@ class Llama4ForConditionalGeneration(nn.Module): f"but multi_modal_projector returned {projected_vision_flat.size(0)}" ) - expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1)) - inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat) + expanded_mask = final_mask_1d.unsqueeze(-1).expand( + -1, inputs_embeds.size(-1) + ) + inputs_embeds = inputs_embeds.masked_scatter( + expanded_mask, projected_vision_flat + ) inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) - logits, speculative_logits= self.text_model( + logits, speculative_logits = self.text_model( inputs_embeds, position_ids, cu_seqlen_prefill, @@ -1289,7 +1429,7 @@ class Llama4ForConditionalGeneration(nn.Module): hpu_attention_meta, adapter_data, lm_head_indices, - attention_mask + attention_mask, ) - return logits, speculative_logits \ No newline at end of file + return logits, speculative_logits diff --git a/backends/gaudi/server/text_generation_server/utils/version.py b/backends/gaudi/server/text_generation_server/utils/version.py index 380b5dac..74b53dfa 100644 --- a/backends/gaudi/server/text_generation_server/utils/version.py +++ b/backends/gaudi/server/text_generation_server/utils/version.py @@ -1,18 +1,31 @@ from packaging.version import Version from packaging import version import subprocess + + def get_driver_version(): """ Returns the driver version. """ # Enable console printing for `hl-smi` check output = subprocess.run( - "hl-smi", shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env={"ENABLE_CONSOLE": "true"} + "hl-smi", + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env={"ENABLE_CONSOLE": "true"}, ) if output.returncode == 0 and output.stdout: - return version.parse(output.stdout.split("\n")[2].replace(" ", "").split(":")[1][:-1].split("-")[0]) + return version.parse( + output.stdout.split("\n")[2] + .replace(" ", "") + .split(":")[1][:-1] + .split("-")[0] + ) return None + MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0") diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py index a16b503a..dec22942 100644 --- a/backends/gaudi/server/text_generation_server/utils/weights.py +++ b/backends/gaudi/server/text_generation_server/utils/weights.py @@ -315,7 +315,6 @@ class Weights: tensors_slices += range(block_offset + start, block_offset + stop) block_offset += block_size - if dim == 0: tensor = slice_[tensors_slices, ...] elif dim == 1 or dim == -2: