diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 8366c25f..4043aea9 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -286,17 +286,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         return UnquantizedWeight(w)
 
-    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int, flag=True):
+    def get_multi_weights_col(
+        self, weights: "Weights", prefixes: List[str], dim: int, flag=True
+    ):
         # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
         if flag:
             w = [
-                weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
-            ]
-        else:
-            w = [
-                weights.get_sharded(f"{p}", dim=2, to_device=False)
+                weights.get_sharded(f"{p}.weight", dim=0, to_device=False)
                 for p in prefixes
             ]
+        else:
+            w = [weights.get_sharded(f"{p}", dim=2, to_device=False) for p in prefixes]
         shapes = [x.shape for x in w]
 
         # Concat then send to the device
@@ -365,7 +365,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
             w = weights.get_sharded(f"{prefix}.weight", dim=1, to_device=False)
         else:
             w = weights.get_sharded(f"{prefix}", dim=1, to_device=False)
-        
+
         w = w.to(weights.device)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index c4eb1073..a6ef467d 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -5,9 +5,10 @@ import torch.nn as nn
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
-from text_generation_server.utils.weights import UnquantizedWeight, Weights
+from text_generation_server.utils.weights import Weights
 from text_generation_server.utils.log import log_master
 from loguru import logger
+
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
@@ -114,9 +115,11 @@ def _load_expert_multi_weights_col(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = weights.get_multi_weights_col(
-        [f"{prefix}.gate_up_proj"], 0, flag=False
-    ).weight.transpose(2, 1).contiguous()
+    all_weight = (
+        weights.get_multi_weights_col([f"{prefix}.gate_up_proj"], 0, flag=False)
+        .weight.transpose(2, 1)
+        .contiguous()
+    )
     # for i in range(n_experts):
     #     # weight = weights.get_weights_col(
     #     #     f"language_model.model.layers.0.feed_forward.experts.gate_up_proj",
@@ -128,7 +131,7 @@
     #     weight = weights.get_multi_weights_col(
     #         [f"{prefix}.gate_up_proj"], 0, flag=False
     #     )
-    
+
     #     from pdb import set_trace; set_trace()
 
     #     assert isinstance(weight, UnquantizedWeight)
@@ -155,9 +158,11 @@ def _load_expert_weights_row(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = weights.get_weights_row(
-        f"{prefix}.{name}", flag=False
-    ).weight.transpose(1,2).contiguous()
+    all_weight = (
+        weights.get_weights_row(f"{prefix}.{name}", flag=False)
+        .weight.transpose(1, 2)
+        .contiguous()
+    )
     # for i in range(n_experts):
     #     weight = weights.get_weights_row(
     #         f"{prefix}.{name}", flag=False
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 0c55b51b..14a59018 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -1034,23 +1034,24 @@ def get_model(
             )
     elif model_type == LLAMA4:
         return VlmCausalLM(
-            model_id=model_id,
-            model_class=Llama4ForConditionalGeneration,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            kv_cache_dtype=kv_cache_dtype,
-            # TODO: once implemented in transformers, use the config class
-            # and processor class from there.
-            # config_class=Gemma3Config,
-            # processor_class=Gemma3Processor,
-            default_dtype=torch.bfloat16,
-            trust_remote_code=trust_remote_code,
-            lora_adapter_ids=lora_adapter_ids,
-        )
+                model_id=model_id,
+                model_class=Llama4ForConditionalGeneration,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                kv_cache_dtype=kv_cache_dtype,
+                # TODO: once implemented in transformers, use the config class
+                # and processor class from there.
+                # config_class=Gemma3Config,
+                # processor_class=Gemma3Processor,
+                default_dtype=torch.bfloat16,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+            )
         if FLASH_TRANSFORMERS_BACKEND:
             from transformers import Llama4ForConditionalGeneration as Llama4Model
+
             return TransformersFlashVlmCausalLM.fallback(
                 model_id,
                 Llama4Model,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index 088b4a5c..e86aec68 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -406,7 +406,7 @@ class Llama4MoE(nn.Module):
         weights,
     ):
         super().__init__()
-        
+
         self.config = config
         self.hidden_dim = config.hidden_size
         # Gating
@@ -416,7 +416,7 @@ class Llama4MoE(nn.Module):
             prefix=f"{prefix}.experts",
             n_experts=config.num_local_experts,
             n_expert_group=None,
-            renormalize=True,
+            renormalize=False,
             topk=config.num_experts_per_tok,
             topk_group=None,
             scoring_func="sigmoid",
@@ -434,17 +434,14 @@ class Llama4MoE(nn.Module):
         self.process_group = weights.process_group
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        from pdb import set_trace; set_trace()
         if self.shared_experts is not None:
             shared_output = self.shared_experts(x, reduce=False)
         else:
             shared_output = None
 
         router_logits = self.gate(x)
-        from pdb import set_trace; set_trace()
 
         out = self.moe_layer(x, gating_output=router_logits)
-        from pdb import set_trace; set_trace()
 
         if shared_output is not None:
             out = out + shared_output
@@ -452,7 +449,7 @@ class Llama4MoE(nn.Module):
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()
 
         return out.view(*x.shape)
 
@@ -526,13 +523,13 @@ class Llama4Layer(nn.Module):
             max_s,
             adapter_data,
         )
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()
 
         # faster post attention rms norm
         normed_attn_res_output, residual = self.post_attention_layernorm(
             attn_output, residual
         )
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()
 
         output = self.mlp(normed_attn_res_output)
 
@@ -551,7 +548,7 @@ class Llama4Model(torch.nn.Module):
                     config,
                     weights,
                 )
-                for layer_id in range(1)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = FastRMSNorm.load(
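
Note (not part of the patch): a minimal sketch of the tensor-layout change that the reworked `_load_expert_multi_weights_col` / `_load_expert_weights_row` paths perform via `flag=False`. Only the `transpose(...).contiguous()` step is taken from the diff; the stored shapes and variable names below are assumptions about the Llama4-style fused expert checkpoint layout.

```python
import torch

# Toy sizes; real checkpoints are much larger.
n_experts, hidden_dim, intermediate_dim = 4, 8, 16

# Assumed storage layout of the fused expert tensors:
#   gate_up_proj: [n_experts, hidden_dim, 2 * intermediate_dim]
#   down_proj:    [n_experts, intermediate_dim, hidden_dim]
gate_up_proj = torch.randn(n_experts, hidden_dim, 2 * intermediate_dim)
down_proj = torch.randn(n_experts, intermediate_dim, hidden_dim)

# Same operation as the patched loaders: swap the last two dims, then
# materialize a contiguous tensor so the fused MoE kernel can index
# per-expert weight slices directly.
w13 = gate_up_proj.transpose(2, 1).contiguous()  # [n_experts, 2 * intermediate_dim, hidden_dim]
w2 = down_proj.transpose(1, 2).contiguous()      # [n_experts, hidden_dim, intermediate_dim]

assert w13.shape == (n_experts, 2 * intermediate_dim, hidden_dim)
assert w2.shape == (n_experts, hidden_dim, intermediate_dim)
```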