Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-24 00:12:08 +00:00

Commit: add model
parent 06663162b4
commit 8e01191b4c
@@ -286,17 +286,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
        return UnquantizedWeight(w)

-    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int, flag=True):
+    def get_multi_weights_col(
+        self, weights: "Weights", prefixes: List[str], dim: int, flag=True
+    ):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        if flag:
            w = [
-                weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
-            ]
-        else:
-            w = [
-                weights.get_sharded(f"{p}", dim=2, to_device=False)
+                weights.get_sharded(f"{p}.weight", dim=0, to_device=False)
+                for p in prefixes
+            ]
+        else:
+            w = [weights.get_sharded(f"{p}", dim=2, to_device=False) for p in prefixes]
        shapes = [x.shape for x in w]

        # Concat then send to the device
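A minimal, self-contained sketch of the control flow shown in this hunk, assuming the `flag` argument distinguishes the usual ".weight"-suffixed 2-D tensors (sharded on dim 0) from bare, stacked 3-D expert tensors (sharded on dim 2). DummyWeights and the example shapes are stand-ins, not TGI's real Weights class; the deferred torch.cat mirrors the FIXME about concatenating fp8 shards on the host before moving them to the device.

from typing import List

import torch


class DummyWeights:
    # Stand-in for TGI's Weights object: looks tensors up by name instead of
    # reading safetensors shards.
    def __init__(self, tensors):
        self.tensors = tensors
        self.device = torch.device("cpu")

    def get_sharded(self, name: str, dim: int, to_device: bool = True):
        # The real method would return only this rank's slice along `dim`.
        return self.tensors[name]


def get_multi_weights_col(weights, prefixes: List[str], dim: int, flag: bool = True):
    # flag=True: dense layers stored as "<prefix>.weight", sharded on dim 0.
    # flag=False: stacked expert tensors stored under the bare prefix, sharded on dim 2.
    if flag:
        w = [weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes]
    else:
        w = [weights.get_sharded(f"{p}", dim=2, to_device=False) for p in prefixes]
    # Concat on host first, then send to the device.
    return torch.cat(w, dim=dim).to(weights.device)


weights = DummyWeights(
    {
        "gate_proj.weight": torch.randn(8, 4),
        "up_proj.weight": torch.randn(8, 4),
        "experts.gate_up_proj": torch.randn(16, 4, 8),  # (n_experts, in, out)
    }
)
print(get_multi_weights_col(weights, ["gate_proj", "up_proj"], dim=0).shape)  # (16, 4)
print(get_multi_weights_col(weights, ["experts.gate_up_proj"], dim=2, flag=False).shape)  # (16, 4, 8)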
@@ -365,7 +365,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
            w = weights.get_sharded(f"{prefix}.weight", dim=1, to_device=False)
        else:
            w = weights.get_sharded(f"{prefix}", dim=1, to_device=False)


        w = w.to(weights.device)
        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
@@ -5,9 +5,10 @@ import torch.nn as nn

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
-from text_generation_server.utils.weights import UnquantizedWeight, Weights
+from text_generation_server.utils.weights import Weights
from text_generation_server.utils.log import log_master
from loguru import logger

if SYSTEM == "ipex":
    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
elif SYSTEM == "cuda":
@@ -114,9 +115,11 @@ def _load_expert_multi_weights_col(
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
-    all_weight = weights.get_multi_weights_col(
-        [f"{prefix}.gate_up_proj"], 0, flag=False
-    ).weight.transpose(2, 1).contiguous()
+    all_weight = (
+        weights.get_multi_weights_col([f"{prefix}.gate_up_proj"], 0, flag=False)
+        .weight.transpose(2, 1)
+        .contiguous()
+    )
    # for i in range(n_experts):
    #     # weight = weights.get_weights_col(
    #     #     f"language_model.model.layers.0.feed_forward.experts.gate_up_proj",
@@ -128,7 +131,7 @@ def _load_expert_multi_weights_col(
    #     weight = weights.get_multi_weights_col(
    #         [f"{prefix}.gate_up_proj"], 0, flag=False
    #     )


    # from pdb import set_trace; set_trace()
    # assert isinstance(weight, UnquantizedWeight)

@@ -155,9 +158,11 @@ def _load_expert_weights_row(
    weights: Weights,
) -> torch.Tensor:
    all_weight = None
-    all_weight = weights.get_weights_row(
-        f"{prefix}.{name}", flag=False
-    ).weight.transpose(1,2).contiguous()
+    all_weight = (
+        weights.get_weights_row(f"{prefix}.{name}", flag=False)
+        .weight.transpose(1, 2)
+        .contiguous()
+    )
    # for i in range(n_experts):
    #     weight = weights.get_weights_row(
    #         f"{prefix}.{name}", flag=False
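Both expert-loading hunks reformat a chained call, and the substance in each is the trailing .weight.transpose(...).contiguous(). A short sketch of why that matters, under the assumption that the checkpoint stores stacked expert matrices as (n_experts, in_features, out_features) while the MoE kernel wants the last two dims swapped; the shapes here are illustrative, not taken from the model.

import torch

n_experts, in_features, out_features = 16, 4, 8

# Hypothetical on-disk layout of the stacked gate_up_proj expert weight.
gate_up_proj = torch.randn(n_experts, in_features, out_features)

# What the hunks do after loading: swap the last two dims, then repack memory.
all_weight = gate_up_proj.transpose(2, 1).contiguous()

print(all_weight.shape)            # torch.Size([16, 8, 4])
print(all_weight.is_contiguous())  # True; transpose() alone only changes strides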
@@ -1034,23 +1034,24 @@ def get_model(
        )
    elif model_type == LLAMA4:
-        return VlmCausalLM(
-            model_id=model_id,
-            model_class=Llama4ForConditionalGeneration,
-            revision=revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            kv_cache_dtype=kv_cache_dtype,
-            # TODO: once implemented in transformers, use the config class
-            # and processor class from there.
-            # config_class=Gemma3Config,
-            # processor_class=Gemma3Processor,
-            default_dtype=torch.bfloat16,
-            trust_remote_code=trust_remote_code,
-            lora_adapter_ids=lora_adapter_ids,
-        )
+            model_id=model_id,
+            model_class=Llama4ForConditionalGeneration,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
+            # TODO: once implemented in transformers, use the config class
+            # and processor class from there.
+            # config_class=Gemma3Config,
+            # processor_class=Gemma3Processor,
+            default_dtype=torch.bfloat16,
+            trust_remote_code=trust_remote_code,
+            lora_adapter_ids=lora_adapter_ids,
+        )
        if FLASH_TRANSFORMERS_BACKEND:
            from transformers import Llama4ForConditionalGeneration as Llama4Model

            return TransformersFlashVlmCausalLM.fallback(
                model_id,
                Llama4Model,
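For orientation, get_model dispatches on the checkpoint's model_type; this hunk wires the LLAMA4 case to a VlmCausalLM wrapper and also shows a Transformers-backed path guarded by FLASH_TRANSFORMERS_BACKEND. A stripped-down sketch of that dispatch shape follows; the wrapper class is a placeholder and the model id is hypothetical, only the structure is taken from the diff.

import torch


class VlmCausalLM:
    # Placeholder for TGI's wrapper class; it only records the arguments here.
    def __init__(self, model_id, model_class, default_dtype, **kwargs):
        self.model_id = model_id
        self.model_class = model_class
        self.dtype = default_dtype


def get_model(model_type: str, model_id: str, **kwargs):
    if model_type == "llama4":
        return VlmCausalLM(
            model_id=model_id,
            model_class="Llama4ForConditionalGeneration",  # placeholder for the real class
            default_dtype=torch.bfloat16,
            **kwargs,
        )
    raise NotImplementedError(f"Unsupported model_type: {model_type}")


model = get_model("llama4", "meta-llama/Llama-4")  # hypothetical model id
print(model.dtype)  # torch.bfloat16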
@@ -406,7 +406,7 @@ class Llama4MoE(nn.Module):
        weights,
    ):
        super().__init__()

        self.config = config
        self.hidden_dim = config.hidden_size

        # Gating
@@ -416,7 +416,7 @@ class Llama4MoE(nn.Module):
            prefix=f"{prefix}.experts",
            n_experts=config.num_local_experts,
            n_expert_group=None,
-            renormalize=True,
+            renormalize=False,
            topk=config.num_experts_per_tok,
            topk_group=None,
            scoring_func="sigmoid",
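The visible change here flips renormalize from True to False. In most MoE routers this flag decides whether the selected top-k gate scores are rescaled to sum to 1 before the expert outputs are mixed; whether TGI's layer implements it exactly this way is an assumption, but the usual semantics look like this.

import torch


def route(router_logits: torch.Tensor, topk: int, renormalize: bool, scoring_func: str = "sigmoid"):
    # Score every expert per token (the hunk configures scoring_func="sigmoid").
    if scoring_func == "sigmoid":
        scores = torch.sigmoid(router_logits)
    else:
        scores = torch.softmax(router_logits, dim=-1)
    weights, ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        # Rescale the selected scores so they sum to 1 for each token.
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, ids


logits = torch.randn(2, 8)  # 2 tokens, 8 experts
for flag in (True, False):
    w, _ = route(logits, topk=2, renormalize=flag)
    print(flag, w.sum(dim=-1))  # sums to exactly 1 only when renormalize=True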
@@ -434,17 +434,14 @@ class Llama4MoE(nn.Module):
        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        from pdb import set_trace; set_trace()
        if self.shared_experts is not None:
            shared_output = self.shared_experts(x, reduce=False)
        else:
            shared_output = None

        router_logits = self.gate(x)
-        from pdb import set_trace; set_trace()

        out = self.moe_layer(x, gating_output=router_logits)
-        from pdb import set_trace; set_trace()

        if shared_output is not None:
            out = out + shared_output
@@ -452,7 +449,7 @@ class Llama4MoE(nn.Module):
        # Reduce sum
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()

        return out.view(*x.shape)
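Taken together, the forward-pass hunks (with the pdb traces stripped) describe: run the shared expert without reducing, route the tokens through the MoE layer, add the two results, and all-reduce once across the tensor-parallel group. Below is a toy single-process sketch of that order of operations; the expert layers are stand-ins, and the all_reduce is only indicated in a comment since no process group exists here.

import torch
import torch.nn as nn


class ToyLlama4MoE(nn.Module):
    def __init__(self, hidden: int, n_experts: int):
        super().__init__()
        self.gate = nn.Linear(hidden, n_experts, bias=False)
        self.moe_layer = nn.Linear(hidden, hidden, bias=False)       # stand-in for the routed experts
        self.shared_experts = nn.Linear(hidden, hidden, bias=False)  # stand-in for the shared expert MLP

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shared expert runs on every token; the cross-rank reduction is deferred.
        shared_output = self.shared_experts(x)

        # Router logits drive the expert mixture; a single dense "expert"
        # weighted by the strongest routing probability stands in for
        # self.moe_layer(x, gating_output=router_logits).
        router_logits = self.gate(x)
        gate = torch.softmax(router_logits, dim=-1).amax(dim=-1, keepdim=True)
        out = gate * self.moe_layer(x)

        out = out + shared_output

        # Real module: torch.distributed.all_reduce(out, group=self.process_group)
        # when the tensor-parallel world size is greater than 1.
        return out.view(*x.shape)


moe = ToyLlama4MoE(hidden=16, n_experts=4)
print(moe(torch.randn(3, 16)).shape)  # torch.Size([3, 16])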
@@ -526,13 +523,13 @@ class Llama4Layer(nn.Module):
            max_s,
            adapter_data,
        )
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()

        # faster post attention rms norm
        normed_attn_res_output, residual = self.post_attention_layernorm(
            attn_output, residual
        )
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()

        output = self.mlp(normed_attn_res_output)
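The layer hunk also shows the fused residual convention: post_attention_layernorm(attn_output, residual) returns both the normed activations and the updated residual, so the add and the RMS norm happen in one call. A plain-PyTorch sketch of that calling convention; FusedAddRMSNorm is a simple reimplementation for illustration, not TGI's FastRMSNorm kernel.

import torch
import torch.nn as nn


class FusedAddRMSNorm(nn.Module):
    # Folds the residual add into the norm and returns both the normed
    # activations and the updated residual, mimicking the signature in the diff.
    def __init__(self, hidden: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden))
        self.eps = eps

    def forward(self, x: torch.Tensor, residual: torch.Tensor = None):
        if residual is not None:
            x = x + residual          # fused add ...
        residual = x                  # ... and the sum becomes the next residual
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * x * rms, residual


post_attention_layernorm = FusedAddRMSNorm(16)
attn_output, residual = torch.randn(3, 16), torch.randn(3, 16)
normed_attn_res_output, residual = post_attention_layernorm(attn_output, residual)
print(normed_attn_res_output.shape, residual.shape)  # torch.Size([3, 16]) twice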
@@ -551,7 +548,7 @@ class Llama4Model(torch.nn.Module):
                    config,
                    weights,
                )
-                for layer_id in range(1)
+                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = FastRMSNorm.load(
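The final hunk replaces a debugging shortcut that built a single decoder layer with the full stack. The pattern is the usual nn.ModuleList comprehension over config.num_hidden_layers; a compact sketch with placeholder classes:

import torch.nn as nn


class ToyLayer(nn.Module):
    def __init__(self, layer_id: int, hidden: int):
        super().__init__()
        self.layer_id = layer_id
        self.mlp = nn.Linear(hidden, hidden)


class ToyModel(nn.Module):
    def __init__(self, num_hidden_layers: int, hidden: int):
        super().__init__()
        # One entry per transformer layer, instead of range(1) from the debug version.
        self.layers = nn.ModuleList(
            [ToyLayer(layer_id, hidden) for layer_id in range(num_hidden_layers)]
        )


print(len(ToyModel(num_hidden_layers=4, hidden=16).layers))  # 4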