Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
add model

parent 06663162b4
commit 8e01191b4c
@@ -286,17 +286,17 @@ class HybridFP8UnquantLoader(WeightsLoader):

         return UnquantizedWeight(w)

-    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int, flag=True):
+    def get_multi_weights_col(
+        self, weights: "Weights", prefixes: List[str], dim: int, flag=True
+    ):
         # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
         if flag:
             w = [
-                weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+                weights.get_sharded(f"{p}.weight", dim=0, to_device=False)
+                for p in prefixes
             ]
         else:
-            w = [
-                weights.get_sharded(f"{p}", dim=2, to_device=False)
-                for p in prefixes
-            ]
+            w = [weights.get_sharded(f"{p}", dim=2, to_device=False) for p in prefixes]
         shapes = [x.shape for x in w]

         # Concat then send to the device
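Note: a minimal, standalone illustration (hypothetical shapes, not code from this commit) of the two concatenation paths the `flag` argument selects between: per-prefix 2D shards stacked along dim 0, versus fused 3D expert tensors stacked along dim 2, with the concatenation kept off-device as the FIXME requires.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# flag=True path: one 2D shard per prefix, concatenated along the output dim (dim 0).
shards_2d = [torch.randn(256, 512) for _ in ("q_proj", "k_proj", "v_proj")]
w_2d = torch.cat(shards_2d, dim=0).to(device)  # [768, 512]

# flag=False path: fused per-expert tensors [n_experts, hidden, shard], concatenated
# along the last dim (dim 2) while still on CPU, then moved to the device once.
shards_3d = [torch.randn(4, 512, 256) for _ in range(2)]
w_3d = torch.cat(shards_3d, dim=2).to(device)  # [4, 512, 512]

print(w_2d.shape, w_3d.shape)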
@@ -5,9 +5,10 @@ import torch.nn as nn
+
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.kernels import load_kernel
-from text_generation_server.utils.weights import UnquantizedWeight, Weights
+from text_generation_server.utils.weights import Weights
 from text_generation_server.utils.log import log_master
 from loguru import logger

 if SYSTEM == "ipex":
     from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
 elif SYSTEM == "cuda":
@@ -114,9 +115,11 @@ def _load_expert_multi_weights_col(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = weights.get_multi_weights_col(
-        [f"{prefix}.gate_up_proj"], 0, flag=False
-    ).weight.transpose(2, 1).contiguous()
+    all_weight = (
+        weights.get_multi_weights_col([f"{prefix}.gate_up_proj"], 0, flag=False)
+        .weight.transpose(2, 1)
+        .contiguous()
+    )
     # for i in range(n_experts):
     # #     weight = weights.get_weights_col(
     # #         f"language_model.model.layers.0.feed_forward.experts.gate_up_proj",
@@ -155,9 +158,11 @@ def _load_expert_weights_row(
     weights: Weights,
 ) -> torch.Tensor:
     all_weight = None
-    all_weight = weights.get_weights_row(
-        f"{prefix}.{name}", flag=False
-    ).weight.transpose(1,2).contiguous()
+    all_weight = (
+        weights.get_weights_row(f"{prefix}.{name}", flag=False)
+        .weight.transpose(1, 2)
+        .contiguous()
+    )
     # for i in range(n_experts):
     #     weight = weights.get_weights_row(
     #         f"{prefix}.{name}", flag=False
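Note: both expert loaders above fetch the fused tensor with flag=False and then transpose it. A minimal sketch of that reshaping, assuming (an assumption, not stated in the diff) that the checkpoint stores experts as [n_experts, in_features, out_features] and the MoE kernel expects [n_experts, out_features, in_features]:

import torch

n_experts, hidden, intermediate = 4, 512, 1024

# Fused gate_up_proj as assumed stored: [n_experts, hidden, 2 * intermediate]
gate_up = torch.randn(n_experts, hidden, 2 * intermediate)
gate_up_t = gate_up.transpose(2, 1).contiguous()  # -> [4, 2048, 512]

# Fused down_proj as assumed stored: [n_experts, intermediate, hidden]
down = torch.randn(n_experts, intermediate, hidden)
down_t = down.transpose(1, 2).contiguous()  # -> [4, 512, 1024]

# .contiguous() materializes the transposed layout so later kernels see dense memory.
print(gate_up_t.shape, down_t.shape)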
@@ -1051,6 +1051,7 @@ def get_model(
         )
         if FLASH_TRANSFORMERS_BACKEND:
             from transformers import Llama4ForConditionalGeneration as Llama4Model
+
             return TransformersFlashVlmCausalLM.fallback(
                 model_id,
                 Llama4Model,
@@ -406,7 +406,7 @@ class Llama4MoE(nn.Module):
         weights,
     ):
         super().__init__()
-
+        self.config = config
         self.hidden_dim = config.hidden_size

         # Gating
@@ -416,7 +416,7 @@ class Llama4MoE(nn.Module):
             prefix=f"{prefix}.experts",
             n_experts=config.num_local_experts,
             n_expert_group=None,
-            renormalize=True,
+            renormalize=False,
             topk=config.num_experts_per_tok,
             topk_group=None,
             scoring_func="sigmoid",
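Note: this hunk flips renormalize to False while keeping scoring_func="sigmoid". A hedged sketch of what that routing combination means for top-k expert selection (hypothetical names and shapes, not the fused MoE kernel itself):

import torch

def route(router_logits: torch.Tensor, topk: int, renormalize: bool):
    # Sigmoid scoring: each expert gets an independent gate value in (0, 1).
    scores = torch.sigmoid(router_logits)            # [tokens, n_experts]
    gate_weights, expert_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        # Softmax-style MoEs often rescale the selected weights to sum to 1;
        # renormalize=False keeps the raw sigmoid values instead.
        gate_weights = gate_weights / gate_weights.sum(dim=-1, keepdim=True)
    return gate_weights, expert_ids

logits = torch.randn(4, 16)  # 4 tokens, 16 experts
w, ids = route(logits, topk=1, renormalize=False)
print(w.shape, ids.shape)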
@@ -434,17 +434,14 @@ class Llama4MoE(nn.Module):
         self.process_group = weights.process_group

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        from pdb import set_trace; set_trace()
         if self.shared_experts is not None:
             shared_output = self.shared_experts(x, reduce=False)
         else:
             shared_output = None

         router_logits = self.gate(x)
-        from pdb import set_trace; set_trace()

         out = self.moe_layer(x, gating_output=router_logits)
-        from pdb import set_trace; set_trace()

         if shared_output is not None:
             out = out + shared_output
@@ -452,7 +449,7 @@ class Llama4MoE(nn.Module):
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()

         return out.view(*x.shape)

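Note: with the debug traces removed, the forward pass above reduces to: optional shared-expert output (left unreduced), routed expert output from the MoE layer, their sum, then a tensor-parallel all-reduce. A standalone sketch of that structure with stand-in callables (assumptions, not TGI code):

import torch
import torch.distributed as dist

def moe_forward(x, gate, moe_layer, shared_experts=None, process_group=None):
    # The shared expert runs on every token; its partial sum is reduced together
    # with the routed output below, so it is not reduced here.
    shared_output = shared_experts(x) if shared_experts is not None else None
    router_logits = gate(x)                           # [tokens, n_experts]
    out = moe_layer(x, gating_output=router_logits)   # routed expert output
    if shared_output is not None:
        out = out + shared_output
    # Each tensor-parallel rank holds a partial sum; reduce across the group.
    if process_group is not None and process_group.size() > 1:
        dist.all_reduce(out, group=process_group)
    return out.view(*x.shape)

# Toy usage with stand-ins: a linear gate and an identity "MoE layer".
x = torch.randn(4, 512)
y = moe_forward(x, gate=torch.nn.Linear(512, 16), moe_layer=lambda x, gating_output: x)
print(y.shape)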
@@ -526,13 +523,13 @@ class Llama4Layer(nn.Module):
             max_s,
             adapter_data,
         )
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()

         # faster post attention rms norm
         normed_attn_res_output, residual = self.post_attention_layernorm(
             attn_output, residual
         )
-        from pdb import set_trace; set_trace()
+        # from pdb import set_trace; set_trace()

         output = self.mlp(normed_attn_res_output)

@@ -551,7 +548,7 @@ class Llama4Model(torch.nn.Module):
                     config,
                     weights,
                 )
-                for layer_id in range(1)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = FastRMSNorm.load(