Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Use SparseMoELayer
parent f4cadd7527
commit c07c80fac9
@@ -28,6 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 
 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
+from text_generation_server.layers.moe import SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -252,79 +253,24 @@ class FlashLlamaAttention(torch.nn.Module):
         )
 
 
-def _load_experts(config, prefix: str, mat, weights):
-    if config.quantize is not None:
-        raise NotImplementedError("Mixtral does not support weight quantization yet.")
-
-    assert mat in ["w1", "w2", "w3"]
-
-    world_size = weights.process_group.size()
-    rank = weights.process_group.rank()
-
-    assert (
-        config.intermediate_size % world_size == 0
-    ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
-
-    block_size = config.intermediate_size // world_size
-    start = rank * block_size
-    stop = (rank + 1) * block_size
-
-    tensor = torch.empty(
-        (config.num_local_experts * block_size, config.hidden_size),
-        dtype=weights.dtype,
-        device=weights.device,
-    )
-
-    for i in range(config.num_local_experts):
-        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
-
-        if mat == "w2":
-            expert_slice = slice_[:, start:stop].t().contiguous()
-        else:
-            expert_slice = slice_[start:stop]
-        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
-            dtype=weights.dtype
-        ).to(device=weights.device)
-    return tensor
-
-
 class BlockSparseMoE(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]
 
         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        self.w13 = torch.cat([w1, w3], dim=1)
-        self.w2 = (
-            _load_experts(config, f"{prefix}.experts", "w2", weights)
-            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
-            .transpose(1, 2)
-            .contiguous()
-        )
+        self.moe = SparseMoELayer(
+            prefix=f"{prefix}.experts",
+            n_experts=config.num_local_experts,
+            n_expert_group=None,
+            renormalize=True,
+            topk=config.num_experts_per_tok,
+            topk_group=None,
+            weights=weights,
+            gate_proj_name="w1",
+            up_proj_name="w3",
+            down_proj_name="w2",
+        )
 
         self.process_group = weights.process_group
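For context on what this hunk replaces: the deleted _load_experts helper sharded every expert's w1/w3 (rows) and w2 (columns) across the tensor-parallel group by hand, keeping one contiguous block of intermediate_size // world_size per rank; SparseMoELayer now takes over loading the per-rank expert weights. Below is a minimal sketch of that sharding arithmetic only, with made-up sizes for illustration (14336 is the Mixtral-8x7B intermediate size; the shard count is hypothetical):

    # Illustrative only: the per-rank slice bounds that _load_experts computed.
    intermediate_size = 14336  # hypothetical, e.g. Mixtral-8x7B
    world_size = 4             # hypothetical number of tensor-parallel shards
    assert intermediate_size % world_size == 0, "size must divide evenly across shards"

    block_size = intermediate_size // world_size  # 3584 rows (w1/w3) or columns (w2) per rank
    for rank in range(world_size):
        start = rank * block_size
        stop = (rank + 1) * block_size
        # w1/w3 are sliced along dim 0 and w2 along dim 1, so each rank holds
        # matching slices of every expert and only a later all-reduce is needed.
        print(f"rank {rank}: slice [{start}:{stop}]")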
@@ -332,15 +278,7 @@ class BlockSparseMoE(nn.Module):
     def forward(self, x, adapter_data) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(x)
-        out = fused_moe(
-            x,
-            self.w13,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
+        out = self.moe(x, gating_output=router_logits)
 
         # Reduce sum
         if self.process_group.size() > 1:
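With this hunk the forward pass collapses to gating plus a single call into the fused layer. A rough sketch of the resulting control flow, assuming the reduction under the truncated if branch is the usual torch.distributed.all_reduce over the tensor-parallel group (the tail of the hunk is not shown above):

    def forward(self, x, adapter_data) -> torch.Tensor:
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(x)
        # SparseMoELayer handles top-k selection, renormalization and the
        # expert matmuls on the locally sharded w1/w3/w2 weights.
        out = self.moe(x, gating_output=router_logits)

        # Each rank only holds a slice of every expert, so partial results
        # are summed across shards (assumed all_reduce; not shown in the hunk).
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)

        return out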