Use SparseMoELayer

This commit is contained in:
Daniël de Kok 2024-09-19 12:28:29 +00:00
parent f4cadd7527
commit c07c80fac9

View File

@ -28,6 +28,7 @@ from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.moe import SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
@ -252,79 +253,24 @@ class FlashLlamaAttention(torch.nn.Module):
)
def _load_experts(config, prefix: str, mat, weights):
if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.")
assert mat in ["w1", "w2", "w3"]
world_size = weights.process_group.size()
rank = weights.process_group.rank()
assert (
config.intermediate_size % world_size == 0
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
block_size = config.intermediate_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = torch.empty(
(config.num_local_experts * block_size, config.hidden_size),
dtype=weights.dtype,
device=weights.device,
)
for i in range(config.num_local_experts):
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
if mat == "w2":
expert_slice = slice_[:, start:stop].t().contiguous()
else:
expert_slice = slice_[start:stop]
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
dtype=weights.dtype
).to(device=weights.device)
return tensor
class BlockSparseMoE(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size // weights.process_group.size()
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
act = config.hidden_act
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
elif "silu" in act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[act]
# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
self.w13 = torch.cat([w1, w3], dim=1)
self.w2 = (
_load_experts(config, f"{prefix}.experts", "w2", weights)
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
.transpose(1, 2)
.contiguous()
self.moe = SparseMoELayer(
prefix=f"{prefix}.experts",
n_experts=config.num_local_experts,
n_expert_group=None,
renormalize=True,
topk=config.num_experts_per_tok,
topk_group=None,
weights=weights,
gate_proj_name="w1",
up_proj_name="w3",
down_proj_name="w2",
)
self.process_group = weights.process_group
@ -332,15 +278,7 @@ class BlockSparseMoE(nn.Module):
def forward(self, x, adapter_data) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
out = fused_moe(
x,
self.w13,
self.w2,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
)
out = self.moe(x, gating_output=router_logits)
# Reduce sum
if self.process_group.size() > 1: