diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
index 583c4ad0..5e4bc7fa 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
@@ -344,27 +344,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         return final_hidden_states
 
 
-# @use_kernel_forward_from_hub("RMSNorm")
-# class Qwen3MoeRMSNorm(nn.Module):
-#     def __init__(self, hidden_size, eps=1e-6):
-#         """
-#         Qwen3MoeRMSNorm is equivalent to T5LayerNorm
-#         """
-#         super().__init__()
-#         self.weight = nn.Parameter(torch.ones(hidden_size))
-#         self.variance_epsilon = eps
-
-#     def forward(self, hidden_states):
-#         input_dtype = hidden_states.dtype
-#         hidden_states = hidden_states.to(torch.float32)
-#         variance = hidden_states.pow(2).mean(-1, keepdim=True)
-#         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-#         return self.weight * hidden_states.to(input_dtype)
-
-#     def extra_repr(self):
-#         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen3MoeDecoderLayer(nn.Module):
     def __init__(self, config, prefix, weights, layer_idx: int):
         super().__init__()
@@ -508,239 +487,6 @@ class Qwen3MoeModel(nn.Module):
         return hidden_states
 
 
-# def _update_causal_mask(
-#     self,
-#     attention_mask: Union[torch.Tensor, "BlockMask"],
-#     input_tensor: torch.Tensor,
-#     cache_position: torch.Tensor,
-#     past_key_values: Cache,
-#     output_attentions: bool = False,
-# ):
-#     if self.config._attn_implementation == "flash_attention_2":
-#         if attention_mask is not None and past_key_values is not None:
-#             is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
-#             if is_padding_right:
-#                 raise ValueError(
-#                     "You are attempting to perform batched generation with padding_side='right'"
-#                     " this may lead to unexpected behaviour for Flash Attention version of Qwen3Moe. Make sure to "
-#                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-#                 )
-#         if attention_mask is not None and 0.0 in attention_mask:
-#             return attention_mask
-#         return None
-
-#     # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-#     # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-#     # to infer the attention mask.
-#     past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-#     using_static_cache = isinstance(past_key_values, StaticCache)
-#     using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
-
-#     # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-#     if (
-#         self.config._attn_implementation == "sdpa"
-#         and not (using_static_cache or using_sliding_window_cache)
-#         and not output_attentions
-#     ):
-#         if AttentionMaskConverter._ignore_causal_mask_sdpa(
-#             attention_mask,
-#             inputs_embeds=input_tensor,
-#             past_key_values_length=past_seen_tokens,
-#             sliding_window=self.config.sliding_window,
-#             is_training=self.training,
-#         ):
-#             return None
-
-#     dtype = input_tensor.dtype
-#     min_dtype = torch.finfo(dtype).min
-#     sequence_length = input_tensor.shape[1]
-#     # SlidingWindowCache or StaticCache
-#     if using_sliding_window_cache or using_static_cache:
-#         target_length = past_key_values.get_max_cache_shape()
-#     # DynamicCache or no cache
-#     else:
-#         target_length = (
-#             attention_mask.shape[-1]
-#             if isinstance(attention_mask, torch.Tensor)
-#             else past_seen_tokens + sequence_length + 1
-#         )
-
-#     # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-#     causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-#         attention_mask,
-#         sequence_length=sequence_length,
-#         target_length=target_length,
-#         dtype=dtype,
-#         cache_position=cache_position,
-#         batch_size=input_tensor.shape[0],
-#         config=self.config,
-#         past_key_values=past_key_values,
-#     )
-
-#     if (
-#         self.config._attn_implementation == "sdpa"
-#         and attention_mask is not None
-#         and attention_mask.device.type in ["cuda", "xpu", "npu"]
-#         and not output_attentions
-#     ):
-#         # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-#         # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-#         # Details: https://github.com/pytorch/pytorch/issues/110213
-#         causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-#     return causal_mask
-
-# @staticmethod
-# def _prepare_4d_causal_attention_mask_with_cache_position(
-#     attention_mask: torch.Tensor,
-#     sequence_length: int,
-#     target_length: int,
-#     dtype: torch.dtype,
-#     cache_position: torch.Tensor,
-#     batch_size: int,
-#     config: Qwen3MoeConfig,
-#     past_key_values: Cache,
-# ):
-#     """
-#     Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-#     `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-#     Args:
-#         attention_mask (`torch.Tensor`):
-#             A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
-#         sequence_length (`int`):
-#             The sequence length being processed.
-#         target_length (`int`):
-#             The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
-#         dtype (`torch.dtype`):
-#             The dtype to use for the 4D attention mask.
-#         cache_position (`torch.Tensor`):
-#             Indices depicting the position of the input sequence tokens in the sequence.
-#         batch_size (`torch.Tensor`):
-#             Batch size.
-#         config (`Qwen3MoeConfig`):
-#             The model's configuration class
-#         past_key_values (`Cache`):
-#             The cache class that is being used currently to generate
-#     """
-#     if attention_mask is not None and attention_mask.dim() == 4:
-#         # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-#         causal_mask = attention_mask
-#     else:
-#         min_dtype = torch.finfo(dtype).min
-#         causal_mask = torch.full(
-#             (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
-#         )
-#         diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
-#             -1, 1
-#         )
-#         if config.get_text_config().sliding_window is not None:
-#             # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
-#             # the check is needed to verify is current checkpoint was trained with sliding window or not
-#             if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
-#                 sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
-#                     cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
-#                 )
-#                 diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
-#         causal_mask *= diagonal_attend_mask
-#         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-#         if attention_mask is not None:
-#             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-#             if attention_mask.shape[-1] > target_length:
-#                 attention_mask = attention_mask[:, :target_length]
-#             mask_length = attention_mask.shape[-1]
-#             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
-#                 causal_mask.device
-#             )
-#             padding_mask = padding_mask == 0
-#             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-#                 padding_mask, min_dtype
-#             )
-#     return causal_mask
-
-
-# def load_balancing_loss_func(
-#     gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
-#     num_experts: Optional[int] = None,
-#     top_k=2,
-#     attention_mask: Optional[torch.Tensor] = None,
-# ) -> Union[torch.Tensor, int]:
-#     r"""
-#     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
-
-#     See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
-#     function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
-#     experts is too unbalanced.
-
-#     Args:
-#         gate_logits:
-#             Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
-#             shape [batch_size X sequence_length, num_experts].
-#         num_experts:
-#             Number of experts
-#         top_k:
-#             The number of experts to route per-token, can be also interpreted as the `top-k` routing
-#             parameter.
-#         attention_mask (`torch.Tensor`, *optional*):
-#             The attention_mask used in forward function
-#             shape [batch_size X sequence_length] if not None.
-
-#     Returns:
-#         The auxiliary loss.
-#     """
-#     if gate_logits is None or not isinstance(gate_logits, tuple):
-#         return 0
-
-#     if isinstance(gate_logits, tuple):
-#         compute_device = gate_logits[0].device
-#         concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
-
-#     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
-
-#     _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
-
-#     expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
-
-#     if attention_mask is None:
-#         # Compute the percentage of tokens routed to each experts
-#         tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
-
-#         # Compute the average probability of routing to these experts
-#         router_prob_per_expert = torch.mean(routing_weights, dim=0)
-#     else:
-#         batch_size, sequence_length = attention_mask.shape
-#         num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
-
-#         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
-#         expert_attention_mask = (
-#             attention_mask[None, :, :, None, None]
-#             .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
-#             .reshape(-1, top_k, num_experts)
-#             .to(compute_device)
-#         )
-
-#         # Compute the percentage of tokens routed to each experts
-#         tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
-#             expert_attention_mask, dim=0
-#         )
-
-#         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
-#         router_per_expert_attention_mask = (
-#             attention_mask[None, :, :, None]
-#             .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
-#             .reshape(-1, num_experts)
-#             .to(compute_device)
-#         )
-
-#         # Compute the average probability of routing to these experts
-#         router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
-#             router_per_expert_attention_mask, dim=0
-#         )
-
-#     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
-#     return overall_loss * num_experts
-
-
 class Qwen3MoeForCausalLM(nn.Module):
     def __init__(self, prefix: str, config, weights):