text-generation-inference/backends/gaudi/server/text_generation_server/layers/moe/fp8.py
Wang, Yi d658b5def3
Deepseek R1 for Gaudi backend (#3211)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-05-19 16:36:39 +02:00

269 lines
8.9 KiB
Python

import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import habana_frameworks.torch as htorch

from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
    Fp8Weight,
    fp8_quantize,
    quant_dtype,
    normalize_e4m3fn_to_native_float8,
    dynamic_quant,
    dequant_block_fp8_weight_naive,
)
from text_generation_server.layers.moe.fused_moe import select_experts
class FP8SparseMoELayer(nn.Module):
    """Sparse mixture-of-experts layer that runs its experts in FP8.

    Expert weights are loaded once at construction time and stacked along a
    leading expert dimension. When the ``USE_EXPERT_PARALLEL`` environment
    variable is ``"true"`` (case-insensitive; this is the default) each rank
    loads only a contiguous run of experts starting at ``ep_offset``;
    otherwise every rank loads all experts through the column/row weight
    getters of ``Weights``.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        """Load the expert weights handled by this rank.

        Args:
            n_expert_group: Number of expert groups; ``None`` disables grouped
                top-k routing. Must be ``None`` exactly when ``topk_group`` is.
            n_experts: Total number of experts in the layer.
            prefix: Weight-name prefix; expert ``i`` is read from
                ``{prefix}.{i}.<projection name>``.
            renormalize: Forwarded to ``select_experts``.
            topk: Forwarded to ``select_experts`` as ``top_k``.
            topk_group: Forwarded to ``select_experts``.
            weights: Weight store; also supplies the process group and the
                loader's ``weight_block_size``.
            scoring_func: Forwarded to ``select_experts``.
            e_score_correction_bias: Forwarded to ``select_experts``.
            gate_proj_name: Name of the gate projection under each expert.
            up_proj_name: Name of the up projection under each expert.
            down_proj_name: Name of the down projection under each expert.
        """
        super().__init__()

        grouping_disabled = n_expert_group is None
        assert grouping_disabled == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        # Routing configuration, consumed by forward().
        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias

        # Distributed layout.
        self.world_size = weights.process_group.size()
        self.rank = weights.process_group.rank()
        self.ep_rank = self.rank
        self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"

        # With expert parallelism every rank owns ceil(n_experts / world_size)
        # consecutive experts; without it, all experts are loaded locally.
        local_expert_count = n_experts
        self.ep_offset = 0
        if self.use_ep:
            local_expert_count = (
                n_experts + self.world_size - 1
            ) // self.world_size
            self.ep_offset = self.ep_rank * local_expert_count

        loader_kwargs = dict(
            prefix=prefix,
            n_experts=local_expert_count,
            weights=weights,
            use_ep=self.use_ep,
            ep_offset=self.ep_offset,
        )

        gate_up_loaded = _load_expert_multi_weights_col(
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            **loader_kwargs,
        )
        self.gate_up_proj = gate_up_loaded[0]
        self.gate_up_proj_weight_scale = gate_up_loaded[1]
        self.gate_up_proj_input_scale = gate_up_loaded[2]

        down_loaded = _load_expert_weights_row(name=down_proj_name, **loader_kwargs)
        self.down_proj = down_loaded[0]
        self.down_proj_weight_scale = down_loaded[1]
        self.down_proj_input_scale = down_loaded[2]

        block_size = self.weight_block_size
        if block_size is not None:
            # Block-quantized checkpoints are dequantized and then quantized
            # again with dynamic_quant. The dequantized tensor is only ever a
            # temporary so it is released as soon as it has been consumed.
            def _requantize(weight, weight_scale):
                return dynamic_quant(
                    dequant_block_fp8_weight_naive(weight, weight_scale, block_size)
                )

            self.gate_up_proj, self.gate_up_proj_weight_scale = _requantize(
                self.gate_up_proj, self.gate_up_proj_weight_scale
            )
            self.down_proj, self.down_proj_weight_scale = _requantize(
                self.down_proj, self.down_proj_weight_scale
            )

            # Drop the last axis of the recomputed scales when it has size 1
            # (presumably a keep-dim axis left by dynamic_quant -- confirm).
            squeezed_gate_up_scale = self.gate_up_proj_weight_scale.squeeze(-1)
            squeezed_down_scale = self.down_proj_weight_scale.squeeze(-1)
            self.gate_up_proj_weight_scale = squeezed_gate_up_scale
            self.down_proj_weight_scale = squeezed_down_scale

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """Route ``x`` through the locally held experts.

        Args:
            x: Input hidden states.
            gating_output: Router logits; the size of the last dimension is
                taken as the total number of experts.

        Returns:
            The output of ``torch.ops.hpu.mixture_of_experts`` for the experts
            held by this rank.
        """
        routing_weights, routing_ids = select_experts(
            hidden_states=x,
            router_logits=gating_output,
            use_grouped_topk=self.n_expert_group is not None,
            top_k=self.topk,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.n_expert_group,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
        )
        total_num_experts = gating_output.size(-1)
        x_fp8, x_scale = dynamic_quant(x, single_scale=True)

        # Experts are currently processed in a single slice. Under expert
        # parallelism the slice covers this rank's share of the experts.
        moe_n_slice = 1
        divisor = self.world_size if self.use_ep else moe_n_slice
        n_expert_slice = (total_num_experts + divisor - 1) // divisor

        final_hidden_states = None
        for slice_idx in range(moe_n_slice):
            first_expert = slice_idx * n_expert_slice
            end_expert = min((slice_idx + 1) * n_expert_slice, total_num_experts)
            local_ids = range(first_expert, end_expert)

            # The HPU op takes one tensor per expert rather than the stack.
            gate_up_weights = [self.gate_up_proj[idx, ...] for idx in local_ids]
            down_weights = [self.down_proj[idx, ...] for idx in local_ids]
            gate_up_scales = [
                self.gate_up_proj_weight_scale[idx, ...] for idx in local_ids
            ]
            down_scales = [
                self.down_proj_weight_scale[idx, ...] for idx in local_ids
            ]

            # experts_min / experts_max are shifted by ep_offset so they are
            # expressed in global expert ids (upper bound appears inclusive).
            slice_output = torch.ops.hpu.mixture_of_experts(
                hidden_states=x_fp8,
                expert_routing_table=routing_ids.to(torch.int64),
                router_weights=routing_weights.to(x.dtype),
                w12=gate_up_weights,
                w3=down_weights,
                d_scale_hidden_states=x_scale,
                d_scale_w12=gate_up_scales,
                d_scale_w3=down_scales,
                permuted_weights=True,
                activation="silu",
                experts_min=first_expert + self.ep_offset,
                experts_max=end_expert + self.ep_offset - 1,
            )
            htorch.core.mark_step()

            if final_hidden_states is None:
                final_hidden_states = slice_output
            else:
                final_hidden_states.add_(slice_output)

        return final_hidden_states
def _load_expert_weights(
    get_weight_fn,
    *,
    prefix: str,
    n_experts: int,
    name: Optional[str],
    weights: Weights,
    ep_offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Load ``n_experts`` FP8 expert weights and stack them per expert.

    Args:
        get_weight_fn: Callable invoked as
            ``get_weight_fn(prefix, expert_index, name, weights)``; it must
            return an ``Fp8Weight``.
        prefix: Weight-name prefix handed to ``get_weight_fn``.
        n_experts: Number of experts to load.
        name: Projection name handed to ``get_weight_fn``. May be ``None``
            when the getter does not use it.
        weights: Weight store handed to ``get_weight_fn``.
        ep_offset: Added to the local index ``i`` to obtain the expert index
            passed to ``get_weight_fn``.

    Returns:
        A tuple ``(all_weight, all_weight_scales, max_input_scale)``:
        the stacked weights in ``quant_dtype``, the stacked float32 weight
        scales, and the largest input scale reported by
        ``normalize_e4m3fn_to_native_float8`` across experts (``None`` if no
        expert reported one).

    Raises:
        AssertionError: If a loaded weight is not an ``Fp8Weight`` or if no
            expert was loaded (``n_experts <= 0``).
    """
    all_weight = None
    all_weight_scales = None
    max_input_scale = None

    for i in range(n_experts):
        weight = get_weight_fn(prefix, i + ep_offset, name, weights)
        assert isinstance(weight, Fp8Weight)

        # Output buffers are sized from the first expert; both are always
        # allocated together.
        if all_weight is None:
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=quant_dtype,
                device=weight.weight.device,
            )
            all_weight_scales = torch.empty(
                (n_experts,) + weight.weight_scale.shape,
                dtype=torch.float32,
                device=weight.weight.device,
            )

        if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
            # Already FP8: convert to the native float8 representation.
            all_weight[i], all_weight_scales[i], current_input_scale = (
                normalize_e4m3fn_to_native_float8(
                    weight.weight, weight.weight_scale, weight.input_scale
                )
            )
            if current_input_scale is not None and (
                max_input_scale is None or current_input_scale > max_input_scale
            ):
                max_input_scale = current_input_scale
        else:
            # Not FP8 yet: quantize with a single scalar scale.
            all_weight[i], all_weight_scales[i] = fp8_quantize(
                weight.weight, scalar=True
            )

    assert all_weight is not None, "no expert weights were loaded"
    return all_weight, all_weight_scales, max_input_scale
def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
    use_ep: bool = False,
    ep_offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Load the combined gate/up projection weights of every expert.

    Args:
        prefix: Weight-name prefix; expert ``i`` is read from
            ``{prefix}.{i}.{gate_proj_name}`` and
            ``{prefix}.{i}.{up_proj_name}``.
        n_experts: Number of experts to load.
        gate_proj_name: Name of the gate projection under each expert.
        up_proj_name: Name of the up projection under each expert.
        weights: Weight store to read from.
        use_ep: When true, read through ``weights.get_multi_weights`` and
            apply ``ep_offset``; otherwise read through
            ``weights.get_multi_weights_col`` with no offset.
        ep_offset: Index of the first expert to load; ignored unless
            ``use_ep`` is true.

    Returns:
        The ``(weights, weight_scales, max_input_scale)`` tuple produced by
        ``_load_expert_weights``.
    """

    def _projection_names(prefix, i):
        # Gate first, then up: this is the order handed to the weight getter.
        return [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"]

    def get_weight_fn_sharded(prefix, i, name, weights):
        return weights.get_multi_weights_col(_projection_names(prefix, i), 0)

    def get_weight_fn(prefix, i, name, weights):
        return weights.get_multi_weights(_projection_names(prefix, i), 0)

    return _load_expert_weights(
        get_weight_fn if use_ep else get_weight_fn_sharded,
        prefix=prefix,
        n_experts=n_experts,
        # The getters above build the tensor names themselves.
        name=None,
        weights=weights,
        ep_offset=ep_offset if use_ep else 0,
    )
def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
    use_ep: bool = False,
    ep_offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Load one projection (``name``) of every expert.

    Args:
        prefix: Weight-name prefix; expert ``i`` is read from
            ``{prefix}.{i}.{name}``.
        n_experts: Number of experts to load.
        name: Name of the projection under each expert.
        weights: Weight store to read from.
        use_ep: When true, read through ``weights.get_weights`` and apply
            ``ep_offset``; otherwise read through ``weights.get_weights_row``
            with no offset.
        ep_offset: Index of the first expert to load; ignored unless
            ``use_ep`` is true.

    Returns:
        The ``(weights, weight_scales, max_input_scale)`` tuple produced by
        ``_load_expert_weights``.
    """

    def get_weight_fn_sharded(prefix, i, name, weights):
        return weights.get_weights_row(f"{prefix}.{i}.{name}")

    def get_weight_fn(prefix, i, name, weights):
        return weights.get_weights(f"{prefix}.{i}.{name}")

    return _load_expert_weights(
        get_weight_fn if use_ep else get_weight_fn_sharded,
        prefix=prefix,
        n_experts=n_experts,
        name=name,
        weights=weights,
        ep_offset=ep_offset if use_ep else 0,
    )