add model

Mohit Sharma 2025-04-01 16:11:19 +00:00
parent 06663162b4
commit 8e01191b4c
4 changed files with 42 additions and 39 deletions

View File

@@ -286,17 +286,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
         return UnquantizedWeight(w)
 
-    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int, flag=True):
+    def get_multi_weights_col(
+        self, weights: "Weights", prefixes: List[str], dim: int, flag=True
+    ):
         # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
         if flag:
             w = [
-                weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
-            ]
-        else:
-            w = [
-                weights.get_sharded(f"{p}", dim=2, to_device=False)
+                weights.get_sharded(f"{p}.weight", dim=0, to_device=False)
                 for p in prefixes
             ]
+        else:
+            w = [weights.get_sharded(f"{p}", dim=2, to_device=False) for p in prefixes]
         shapes = [x.shape for x in w]
 
         # Concat then send to the device
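
Note on the FIXME above: the shards are fetched with to_device=False so the concatenation happens off-device and only the merged tensor is transferred. A minimal sketch of that pattern with plain tensors (the shapes and two-shard split are made up for illustration, not the real Weights API):

import torch

# Stand-ins for weights.get_sharded(..., to_device=False): shards kept on CPU.
shards = [torch.randn(128, 512) for _ in range(2)]

# Concatenate on CPU first (some FP8 layouts do not support torch.cat on device),
# then move the single merged tensor to the accelerator.
merged = torch.cat(shards, dim=0)
device = "cuda" if torch.cuda.is_available() else "cpu"
merged = merged.to(device)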

View File

@@ -5,9 +5,10 @@ import torch.nn as nn
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
-from text_generation_server.utils.weights import UnquantizedWeight, Weights
+from text_generation_server.utils.weights import Weights
+from text_generation_server.utils.log import log_master
 from loguru import logger
 
 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
@@ -114,9 +115,11 @@ def _load_expert_multi_weights_col(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = weights.get_multi_weights_col(
-        [f"{prefix}.gate_up_proj"], 0, flag=False
-    ).weight.transpose(2, 1).contiguous()
+    all_weight = (
+        weights.get_multi_weights_col([f"{prefix}.gate_up_proj"], 0, flag=False)
+        .weight.transpose(2, 1)
+        .contiguous()
+    )
     # for i in range(n_experts):
     #     # weight = weights.get_weights_col(
     #     #     f"language_model.model.layers.0.feed_forward.experts.gate_up_proj",
@@ -155,9 +158,11 @@ def _load_expert_weights_row(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = weights.get_weights_row(
-        f"{prefix}.{name}", flag=False
-    ).weight.transpose(1,2).contiguous()
+    all_weight = (
+        weights.get_weights_row(f"{prefix}.{name}", flag=False)
+        .weight.transpose(1, 2)
+        .contiguous()
+    )
     # for i in range(n_experts):
     #     weight = weights.get_weights_row(
     #         f"{prefix}.{name}", flag=False

View File

@@ -1034,23 +1034,24 @@ def get_model(
         )
     elif model_type == LLAMA4:
         return VlmCausalLM(
-                model_id=model_id,
-                model_class=Llama4ForConditionalGeneration,
-                revision=revision,
-                quantize=quantize,
-                speculator=speculator,
-                dtype=dtype,
-                kv_cache_dtype=kv_cache_dtype,
-                # TODO: once implemented in transformers, use the config class
-                # and processor class from there.
-                # config_class=Gemma3Config,
-                # processor_class=Gemma3Processor,
-                default_dtype=torch.bfloat16,
-                trust_remote_code=trust_remote_code,
-                lora_adapter_ids=lora_adapter_ids,
-            )
+            model_id=model_id,
+            model_class=Llama4ForConditionalGeneration,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
+            # TODO: once implemented in transformers, use the config class
+            # and processor class from there.
+            # config_class=Gemma3Config,
+            # processor_class=Gemma3Processor,
+            default_dtype=torch.bfloat16,
+            trust_remote_code=trust_remote_code,
+            lora_adapter_ids=lora_adapter_ids,
+        )
         if FLASH_TRANSFORMERS_BACKEND:
             from transformers import Llama4ForConditionalGeneration as Llama4Model
+
             return TransformersFlashVlmCausalLM.fallback(
                 model_id,
                 Llama4Model,

View File

@@ -406,7 +406,7 @@ class Llama4MoE(nn.Module):
         weights,
     ):
         super().__init__()
-        self.config = config
+        self.hidden_dim = config.hidden_size
 
         # Gating
@@ -416,7 +416,7 @@ class Llama4MoE(nn.Module):
             prefix=f"{prefix}.experts",
             n_experts=config.num_local_experts,
             n_expert_group=None,
-            renormalize=True,
+            renormalize=False,
             topk=config.num_experts_per_tok,
             topk_group=None,
             scoring_func="sigmoid",
@@ -434,17 +434,14 @@ class Llama4MoE(nn.Module):
 
         self.process_group = weights.process_group
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        from pdb import set_trace; set_trace()
         if self.shared_experts is not None:
             shared_output = self.shared_experts(x, reduce=False)
         else:
             shared_output = None
 
         router_logits = self.gate(x)
-        from pdb import set_trace; set_trace()
         out = self.moe_layer(x, gating_output=router_logits)
-        from pdb import set_trace; set_trace()
 
         if shared_output is not None:
             out = out + shared_output
@@ -452,7 +449,7 @@ class Llama4MoE(nn.Module):
 
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()
 
         return out.view(*x.shape)
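
With the set_trace() calls stripped, the forward above is a shared-expert plus routed-experts flow: compute router logits with the gate, run the fused MoE layer, add the shared-expert output, then all-reduce across the tensor-parallel group. A self-contained toy sketch of that flow, assuming sigmoid scoring and renormalize=False as configured above (module names and shapes are illustrative, not the real MoELayer/DenseMoELayer API):

import torch
import torch.nn as nn

class ToyLlama4StyleMoE(nn.Module):
    def __init__(self, hidden: int = 32, n_experts: int = 4, top_k: int = 1):
        super().__init__()
        self.gate = nn.Linear(hidden, n_experts, bias=False)  # router
        self.experts = nn.ModuleList(
            nn.Linear(hidden, hidden, bias=False) for _ in range(n_experts)
        )
        self.shared_expert = nn.Linear(hidden, hidden, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shared expert runs on every token.
        shared_output = self.shared_expert(x)

        # Router: sigmoid scores, top-k selection, and no renormalization of
        # the selected weights (mirrors renormalize=False above).
        router_logits = self.gate(x)
        scores = torch.sigmoid(router_logits)
        topk_scores, topk_idx = scores.topk(self.top_k, dim=-1)

        out = torch.zeros_like(x)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = topk_idx[:, k] == e
                if mask.any():
                    out[mask] += topk_scores[mask, k : k + 1] * expert(x[mask])

        # In the real module, `out` is all-reduced across the tensor-parallel
        # process group before being returned.
        return out + shared_output

print(ToyLlama4StyleMoE()(torch.randn(5, 32)).shape)  # torch.Size([5, 32])
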
@@ -526,13 +523,13 @@ class Llama4Layer(nn.Module):
             max_s,
             adapter_data,
         )
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()
 
         # faster post attention rms norm
         normed_attn_res_output, residual = self.post_attention_layernorm(
             attn_output, residual
         )
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()
 
         output = self.mlp(normed_attn_res_output)
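
The traces above were wrapped around the post-attention normalization; the underlying pattern is a residual add followed by an RMSNorm whose output feeds the MLP/MoE block. A minimal sketch of that step, written as a plain function rather than the codebase's FastRMSNorm (the two return values mirror the (normed, residual) pair used above):

import torch

def add_rmsnorm(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5):
    # Residual add first, then RMS-normalize the sum; return both so the
    # caller can keep threading the updated residual through the layer.
    new_residual = x + residual
    variance = new_residual.pow(2).mean(-1, keepdim=True)
    normed = new_residual * torch.rsqrt(variance + eps) * weight
    return normed, new_residual

hidden = 16
attn_output = torch.randn(2, hidden)
residual = torch.randn(2, hidden)
normed_attn_res_output, residual = add_rmsnorm(attn_output, residual, torch.ones(hidden))
# normed_attn_res_output would then be passed to self.mlp(...), as in the diff.
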
@@ -551,7 +548,7 @@ class Llama4Model(torch.nn.Module):
                     config,
                     weights,
                 )
-                for layer_id in range(1)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = FastRMSNorm.load(
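
The range(1) to range(config.num_hidden_layers) fix above means the ModuleList now holds one entry per transformer layer instead of only the first. A small stand-alone sketch of the pattern with a toy layer class and config (both hypothetical):

import torch
from types import SimpleNamespace

class ToyLayer(torch.nn.Module):
    def __init__(self, layer_id: int, hidden: int):
        super().__init__()
        self.layer_id = layer_id
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.proj(x)

config = SimpleNamespace(num_hidden_layers=3, hidden_size=16)

# One module per decoder layer; iterating over range(1) would silently build
# a single-layer model and run only the first layer's weights.
layers = torch.nn.ModuleList(
    ToyLayer(layer_id, config.hidden_size) for layer_id in range(config.num_hidden_layers)
)
assert len(layers) == config.num_hidden_layers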