Add expert parallelism (EP) to improve performance

Fix a crash in the Qwen3 MoE model

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-07-06 18:27:51 -07:00
parent 433029e56f
commit 82a9435a80
5 changed files with 58 additions and 63 deletions


@@ -118,9 +118,9 @@ ENTRYPOINT ["./entrypoint.sh"]
 
 # Final image
 FROM base
-ENV HF_HUB_ENABLE_HF_TRANSFER 1
-ENV HABANA_VISIBLE_DEVICES all
-ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
+ENV HF_HUB_ENABLE_HF_TRANSFER=1
+ENV HABANA_VISIBLE_DEVICES=all
+ENV OMPI_MCA_btl_vader_single_copy_mechanism=NONE
 
 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh


@@ -51,10 +51,12 @@ class FP8SparseMoELayer(nn.Module):
         self.rank = weights.process_group.rank()
         self.ep_rank = self.rank
         self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
+        if (n_experts + self.world_size - 1) // self.world_size < 4:
+            self.use_ep = False
         if self.use_ep:
-            n_experts = (n_experts + self.world_size - 1) // self.world_size
-            self.ep_offset = self.ep_rank * n_experts
+            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
+            self.ep_offset = self.ep_rank * n_experts_per_rank
+            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
         else:
             self.ep_offset = 0
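
Both MoE layers now share this partitioning scheme: experts are split across ranks by ceiling division, the last rank keeps the remainder, and EP is disabled outright when a rank would own fewer than four experts. A standalone sketch of the arithmetic (plain Python, no TGI dependencies; the function name is ours):

def expert_partition(n_experts: int, rank: int, world_size: int):
    # Ceiling division: each rank's nominal share of the experts.
    per_rank = (n_experts + world_size - 1) // world_size
    offset = rank * per_rank
    # The last rank may own fewer experts when the split is uneven.
    n_local = min(per_rank, n_experts - offset)
    return offset, n_local

# 60 experts over 8 ranks: ranks 0-6 own 8 experts each (offsets 0, 8, ..., 48),
# rank 7 owns the remaining 4 (offset 56). Since each rank's nominal share is
# at least 4, the guard above leaves EP enabled in this case.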


@@ -7,6 +7,7 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights
 from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
 import habana_frameworks.torch as htorch
 import torch.nn.functional as F
+import os
 
 
 class UnquantizedSparseMoELayer(nn.Module):
@@ -39,6 +40,21 @@ class UnquantizedSparseMoELayer(nn.Module):
         self.weight_block_size = weights.weights_loader.weight_block_size
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+        self.rank = weights.process_group.rank()
+        self.world_size = weights.process_group.size()
+        self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
+        if (n_experts + self.world_size - 1) // self.world_size < 4:
+            self.use_ep = False
+        if self.use_ep:
+            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
+            self.ep_offset = self.rank * n_experts_per_rank
+            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
+            experts_min = self.ep_offset
+            experts_max = self.ep_offset + n_experts - 1
+        else:
+            self.ep_offset = 0
+            experts_min = 0
+            experts_max = n_experts - 1
 
         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
@@ -46,6 +62,8 @@ class UnquantizedSparseMoELayer(nn.Module):
             gate_proj_name=gate_proj_name,
             up_proj_name=up_proj_name,
             weights=weights,
+            use_ep=self.use_ep,
+            ep_offset=self.ep_offset,
         )
 
         self.down_proj = _load_expert_weights_row(
@@ -53,9 +71,11 @@ class UnquantizedSparseMoELayer(nn.Module):
             n_experts=n_experts,
             name=down_proj_name,
             weights=weights,
+            use_ep=self.use_ep,
+            ep_offset=self.ep_offset,
         )
 
-        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
+        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, experts_min, experts_max)
         for i in range(n_experts):
             self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
             self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
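
Passing experts_min and experts_max tells VllmMixtureOfExpertsOp which window of global expert IDs lives on this rank; the actual dispatch happens inside vllm_hpu_extension. A toy illustration of that contract (torch only, function name invented), where tokens routed outside the window are computed by other ranks and the partial results are combined afterwards:

import torch

def local_expert_hits(topk_ids: torch.Tensor, experts_min: int, experts_max: int) -> torch.Tensor:
    # True where a token's routed expert falls inside this rank's
    # [experts_min, experts_max] window; everything else is another
    # rank's work under expert parallelism.
    return (topk_ids >= experts_min) & (topk_ids <= experts_max)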
@@ -87,12 +107,23 @@ def _load_expert_multi_weights_col(
     gate_proj_name: str,
     up_proj_name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
     all_weight = None
     for i in range(n_experts):
-        weight = weights.get_multi_weights_col(
-            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
-        )
+        if not use_ep:
+            weight = weights.get_multi_weights_col(
+                [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+            )
+        else:
+            weight = weights.get_multi_weights(
+                [
+                    f"{prefix}.{i+ep_offset}.{gate_proj_name}",
+                    f"{prefix}.{i+ep_offset}.{up_proj_name}",
+                ],
+                0,
+            )
         assert isinstance(weight, UnquantizedWeight)
@@ -116,12 +147,19 @@ def _load_expert_weights_row(
     n_experts: int,
     name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
     all_weight = None
     for i in range(n_experts):
-        weight = weights.get_weights_row(
-            f"{prefix}.{i}.{name}",
-        )
+        if not use_ep:
+            weight = weights.get_weights_row(
+                f"{prefix}.{i}.{name}",
+            )
+        else:
+            weight = weights.get_weights(
+                f"{prefix}.{i+ep_offset}.{name}",
+            )
         assert isinstance(weight, UnquantizedWeight)
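
In both loaders the EP branch shifts the expert index by ep_offset, so local expert i reads the checkpoint tensors of global expert i + ep_offset as whole tensors (get_weights / get_multi_weights) rather than row- or column-sharded slices. A small sketch of the key mapping (helper name and example prefix are hypothetical):

def local_expert_keys(prefix: str, name: str, n_local: int, ep_offset: int):
    # Local expert i on this rank maps to the global checkpoint key
    # "{prefix}.{i + ep_offset}.{name}", so each rank reads only its
    # own slice of the expert weights.
    return [f"{prefix}.{i + ep_offset}.{name}" for i in range(n_local)]

# e.g. 2 local experts starting at global expert 8:
# local_expert_keys("model.layers.0.mlp.experts", "down_proj", 2, 8)
# -> ["model.layers.0.mlp.experts.8.down_proj",
#     "model.layers.0.mlp.experts.9.down_proj"]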


@@ -21,6 +21,7 @@ import torch.nn.functional as F
 from text_generation_server.layers.attention import (
     attention,
     paged_attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -466,6 +467,10 @@ class Qwen3MoeModel(nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
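
This is the "qwen3 moe crash fix" from the commit message: the forward pass now fills in the paged-attention block mapping, sized from the flattened inputs_embeds, before the decoder layers run, and skips the call when no metadata is supplied. A minimal sketch of the guard, assuming an invented stand-in for HPUPagedAttentionMetadata and a toy placeholder for set_block_mapping:

from typing import Optional

class Meta:
    # Invented stand-in for HPUPagedAttentionMetadata.
    block_mapping = None

def set_block_mapping_stub(meta: Meta, n_tokens: int) -> Meta:
    # Toy placeholder: the real set_block_mapping builds HPU-specific
    # block-mapping tensors for the current batch.
    meta.block_mapping = list(range(n_tokens))
    return meta

def prepare_meta(meta: Optional[Meta], n_tokens: int) -> Optional[Meta]:
    # The guarded call added above: only batches that carry
    # paged-attention metadata get a block mapping; None passes through.
    if meta is not None:
        meta = set_block_mapping_stub(meta, n_tokens)
    return meta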


@ -1,50 +0,0 @@
import os
from pathlib import Path
from loguru import logger
from text_generation_server import server
import argparse
from text_generation_server.utils.adapter import parse_lora_adapters
def main(args):
logger.info("TGIService: starting tgi service .... ")
logger.info(
"TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format(
args.model_id,
args.revision,
args.sharded,
args.speculate,
args.dtype,
args.trust_remote_code,
args.uds_path,
)
)
lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
server.serve(
model_id=args.model_id,
lora_adapters=lora_adapters,
revision=args.revision,
sharded=args.sharded,
quantize=args.quantize,
speculate=args.speculate,
dtype=args.dtype,
trust_remote_code=args.trust_remote_code,
uds_path=args.uds_path,
max_input_tokens=args.max_input_tokens,
kv_cache_dtype="auto",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str)
parser.add_argument("--revision", type=str)
parser.add_argument("--sharded", type=bool)
parser.add_argument("--speculate", type=int, default=None)
parser.add_argument("--dtype", type=str)
parser.add_argument("--trust_remote_code", type=bool)
parser.add_argument("--uds_path", type=Path)
parser.add_argument("--quantize", type=str)
parser.add_argument("--max_input_tokens", type=int)
args = parser.parse_args()
main(args)