add model

Mohit Sharma 2025-04-01 16:11:19 +00:00
parent 06663162b4
commit 8e01191b4c
4 changed files with 42 additions and 39 deletions


@@ -286,17 +286,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
         return UnquantizedWeight(w)
 
-    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int, flag=True):
+    def get_multi_weights_col(
+        self, weights: "Weights", prefixes: List[str], dim: int, flag=True
+    ):
         # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
         if flag:
             w = [
-                weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+                weights.get_sharded(f"{p}.weight", dim=0, to_device=False)
+                for p in prefixes
             ]
         else:
-            w = [
-                weights.get_sharded(f"{p}", dim=2, to_device=False)
-                for p in prefixes
-            ]
+            w = [weights.get_sharded(f"{p}", dim=2, to_device=False) for p in prefixes]
         shapes = [x.shape for x in w]
 
         # Concat then send to the device
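
The FIXME above is the point of the flag plumbing: shards are fetched with to_device=False because torch.cat is not yet implemented for fp8 tensors on the accelerator. A minimal sketch of that workaround (shapes and dtype are illustrative assumptions):

    import torch

    # Stand-ins for fp8 shards; the real ones come from weights.get_sharded(...).
    shards = [torch.randn(128, 256).to(torch.float8_e4m3fn) for _ in range(3)]

    merged = torch.cat(shards, dim=0)  # concatenate on CPU, where fp8 cat works
    if torch.cuda.is_available():
        merged = merged.to("cuda")     # a single device transfer after the merge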


@@ -5,9 +5,10 @@ import torch.nn as nn
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
-from text_generation_server.utils.weights import UnquantizedWeight, Weights
+from text_generation_server.utils.weights import Weights
 from text_generation_server.utils.log import log_master
 from loguru import logger
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
@@ -114,9 +115,11 @@ def _load_expert_multi_weights_col(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = weights.get_multi_weights_col(
-        [f"{prefix}.gate_up_proj"], 0, flag=False
-    ).weight.transpose(2, 1).contiguous()
+    all_weight = (
+        weights.get_multi_weights_col([f"{prefix}.gate_up_proj"], 0, flag=False)
+        .weight.transpose(2, 1)
+        .contiguous()
+    )
     # for i in range(n_experts):
     #     # weight = weights.get_weights_col(
     #     #     f"language_model.model.layers.0.feed_forward.experts.gate_up_proj",
@@ -155,9 +158,11 @@ def _load_expert_weights_row(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = weights.get_weights_row(
-        f"{prefix}.{name}", flag=False
-    ).weight.transpose(1,2).contiguous()
+    all_weight = (
+        weights.get_weights_row(f"{prefix}.{name}", flag=False)
+        .weight.transpose(1, 2)
+        .contiguous()
+    )
     # for i in range(n_experts):
     #     weight = weights.get_weights_row(
     #         f"{prefix}.{name}", flag=False


@@ -1051,6 +1051,7 @@ def get_model(
         )
         if FLASH_TRANSFORMERS_BACKEND:
             from transformers import Llama4ForConditionalGeneration as Llama4Model
+
             return TransformersFlashVlmCausalLM.fallback(
                 model_id,
                 Llama4Model,
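
Keeping the Llama4ForConditionalGeneration import inside the branch means a transformers install without the class only fails when this backend is actually selected. The same guard pattern, sketched with a hypothetical helper:

    def load_llama4_fallback(model_id: str, flash_backend_enabled: bool):
        # Deferred import: resolved only when the flash transformers backend is on.
        if flash_backend_enabled:
            from transformers import Llama4ForConditionalGeneration as Llama4Model

            return Llama4Model.from_pretrained(model_id)
        raise ValueError(f"no flash backend available for {model_id}")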


@@ -406,7 +406,7 @@ class Llama4MoE(nn.Module):
         weights,
     ):
         super().__init__()
+        self.config = config
         self.hidden_dim = config.hidden_size
 
         # Gating
@@ -416,7 +416,7 @@ class Llama4MoE(nn.Module):
             prefix=f"{prefix}.experts",
             n_experts=config.num_local_experts,
             n_expert_group=None,
-            renormalize=True,
+            renormalize=False,
             topk=config.num_experts_per_tok,
             topk_group=None,
             scoring_func="sigmoid",
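
The renormalize flip matches scoring_func="sigmoid": the sigmoid gate values are used as-is instead of being rescaled to sum to one. A simplified sketch of the routing this configures (the fused MoE kernel does the real work):

    import torch

    def route(router_logits: torch.Tensor, topk: int, renormalize: bool):
        scores = torch.sigmoid(router_logits)           # per-expert gates in (0, 1)
        weights, expert_ids = torch.topk(scores, topk, dim=-1)
        if renormalize:                                 # the behavior switched off here
            weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights, expert_ids

    logits = torch.randn(2, 16)                         # [tokens, n_experts]
    w, idx = route(logits, topk=1, renormalize=False)
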
@@ -434,17 +434,14 @@ class Llama4MoE(nn.Module):
         self.process_group = weights.process_group
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        from pdb import set_trace; set_trace()
         if self.shared_experts is not None:
             shared_output = self.shared_experts(x, reduce=False)
         else:
             shared_output = None
 
         router_logits = self.gate(x)
-        from pdb import set_trace; set_trace()
         out = self.moe_layer(x, gating_output=router_logits)
-        from pdb import set_trace; set_trace()
 
         if shared_output is not None:
             out = out + shared_output
@@ -452,7 +449,7 @@ class Llama4MoE(nn.Module):
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()
 
         return out.view(*x.shape)
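
This all-reduce pairs with reduce=False on the shared experts earlier in forward: both partial outputs are summed locally so one collective covers the routed and shared paths. A sketch, assuming an initialized process group:

    import torch
    import torch.distributed as dist

    def combine_and_reduce(routed: torch.Tensor, shared, group) -> torch.Tensor:
        # Sum the tensor-parallel partials locally first...
        out = routed if shared is None else routed + shared
        # ...then pay for a single collective instead of one per path.
        if dist.get_world_size(group=group) > 1:
            dist.all_reduce(out, group=group)
        return out
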
@@ -526,13 +523,13 @@ class Llama4Layer(nn.Module):
             max_s,
             adapter_data,
         )
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()
 
         # faster post attention rms norm
         normed_attn_res_output, residual = self.post_attention_layernorm(
             attn_output, residual
         )
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()
 
         output = self.mlp(normed_attn_res_output)
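
The "faster post attention rms norm" comment refers to the two-argument form of FastRMSNorm, which fuses the residual add into the norm and hands back the new residual along with the normed activations. Reference semantics as a sketch (eps and the fusion details are assumptions):

    import torch

    def fused_add_rmsnorm(x, residual, weight, eps: float = 1e-5):
        hidden = x + residual                  # residual add, fused in the real kernel
        variance = hidden.pow(2).mean(-1, keepdim=True)
        normed = hidden * torch.rsqrt(variance + eps) * weight
        return normed, hidden                  # (normed output, next residual)
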
@@ -551,7 +548,7 @@ class Llama4Model(torch.nn.Module):
                     config,
                     weights,
                 )
-                for layer_id in range(1)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = FastRMSNorm.load(
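
The range(1) stub was a debugging shortcut that silently built a one-layer model; iterating over config.num_hidden_layers restores the full decoder stack. The pattern, sketched with a placeholder layer class:

    import torch.nn as nn

    class DecoderLayer(nn.Module):   # placeholder standing in for Llama4Layer
        def __init__(self, layer_id: int):
            super().__init__()
            self.layer_id = layer_id

    num_hidden_layers = 48           # read from config.num_hidden_layers in practice
    layers = nn.ModuleList(DecoderLayer(i) for i in range(num_hidden_layers))
    assert len(layers) == num_hidden_layers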