[gaudi] DeepSeek-V2 MLA and add EP to unquantized MoE (#3287)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi authored 2025-07-07 17:29:39 +08:00, committed by GitHub
parent 778b61c0da
commit ebb26f0ccd
8 changed files with 171 additions and 107 deletions

View File

@@ -118,9 +118,9 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base
-ENV HF_HUB_ENABLE_HF_TRANSFER 1
-ENV HABANA_VISIBLE_DEVICES all
-ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
+ENV HF_HUB_ENABLE_HF_TRANSFER=1
+ENV HABANA_VISIBLE_DEVICES=all
+ENV OMPI_MCA_btl_vader_single_copy_mechanism=NONE
 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh

View File

@@ -51,10 +51,12 @@ class FP8SparseMoELayer(nn.Module):
         self.rank = weights.process_group.rank()
         self.ep_rank = self.rank
         self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
+        if (n_experts + self.world_size - 1) // self.world_size < 4:
+            self.use_ep = False
         if self.use_ep:
-            n_experts = (n_experts + self.world_size - 1) // self.world_size
-            self.ep_offset = self.ep_rank * n_experts
+            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
+            self.ep_offset = self.ep_rank * n_experts_per_rank
+            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
         else:
             self.ep_offset = 0
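
Note on the arithmetic above: experts are partitioned by ceiling division, so every rank but possibly the last owns the same number of contiguous experts, and expert parallelism is skipped when it would leave fewer than four experts per rank. A minimal standalone sketch (helper name and example figures are illustrative, not from the patch):

```python
# Hypothetical helper mirroring the partitioning logic above (not part of the patch).
def partition_experts(n_experts: int, world_size: int, rank: int):
    n_per_rank = (n_experts + world_size - 1) // world_size  # ceil division
    if n_per_rank < 4:
        # Too few experts per rank: fall back to tensor parallelism.
        return None
    ep_offset = rank * n_per_rank
    n_local = min(n_per_rank, n_experts - ep_offset)  # last rank may get fewer
    return ep_offset, n_local

# e.g. 160 routed experts on 8 ranks -> 20 experts per rank, offsets 0, 20, ..., 140
for r in range(8):
    print(r, partition_experts(160, 8, r))
```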

View File

@@ -7,6 +7,7 @@ from text_generation_server.utils.weights import UnquantizedWeight, Weights
 from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
 import habana_frameworks.torch as htorch
 import torch.nn.functional as F
+import os


 class UnquantizedSparseMoELayer(nn.Module):
@@ -39,6 +40,21 @@ class UnquantizedSparseMoELayer(nn.Module):
         self.weight_block_size = weights.weights_loader.weight_block_size
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+        self.rank = weights.process_group.rank()
+        self.world_size = weights.process_group.size()
+        self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
+        if (n_experts + self.world_size - 1) // self.world_size < 4:
+            self.use_ep = False
+        if self.use_ep:
+            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
+            self.ep_offset = self.rank * n_experts_per_rank
+            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
+            experts_min = self.ep_offset
+            experts_max = self.ep_offset + n_experts - 1
+        else:
+            self.ep_offset = 0
+            experts_min = 0
+            experts_max = n_experts - 1

         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
@@ -46,6 +62,8 @@ class UnquantizedSparseMoELayer(nn.Module):
             gate_proj_name=gate_proj_name,
             up_proj_name=up_proj_name,
             weights=weights,
+            use_ep=self.use_ep,
+            ep_offset=self.ep_offset,
         )

         self.down_proj = _load_expert_weights_row(
@@ -53,9 +71,11 @@ class UnquantizedSparseMoELayer(nn.Module):
             n_experts=n_experts,
             name=down_proj_name,
             weights=weights,
+            use_ep=self.use_ep,
+            ep_offset=self.ep_offset,
         )

-        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
+        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, experts_min, experts_max)
         for i in range(n_experts):
             self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
             self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
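
`VllmMixtureOfExpertsOp` now receives the global expert-id range `[experts_min, experts_max]` owned by this rank instead of always `[0, n_experts - 1]`. Presumably this lets the op match router outputs, which use global expert ids, against locally held weights. A sketch of that selection idea, under that assumption (helper and shapes are illustrative, not the op's actual implementation):

```python
import torch

# Illustrative only: keep the tokens whose routed (global) expert id falls in
# this rank's [experts_min, experts_max] window, then translate to local ids.
def local_expert_assignments(topk_ids: torch.Tensor, experts_min: int, experts_max: int):
    mask = (topk_ids >= experts_min) & (topk_ids <= experts_max)
    local_ids = topk_ids - experts_min  # global -> local expert index
    return mask, local_ids

topk_ids = torch.tensor([[0, 5], [12, 3]])  # router output, global ids
mask, local_ids = local_expert_assignments(topk_ids, experts_min=4, experts_max=7)
print(mask)       # tensor([[False,  True], [False, False]])
print(local_ids)  # local ids are only meaningful where mask is True
```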
@@ -87,12 +107,23 @@ def _load_expert_multi_weights_col(
     gate_proj_name: str,
     up_proj_name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
     all_weight = None
     for i in range(n_experts):
-        weight = weights.get_multi_weights_col(
-            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
-        )
+        if not use_ep:
+            weight = weights.get_multi_weights_col(
+                [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+            )
+        else:
+            weight = weights.get_multi_weights(
+                [
+                    f"{prefix}.{i+ep_offset}.{gate_proj_name}",
+                    f"{prefix}.{i+ep_offset}.{up_proj_name}",
+                ],
+                0,
+            )

         assert isinstance(weight, UnquantizedWeight)
@@ -116,12 +147,19 @@ def _load_expert_weights_row(
     n_experts: int,
     name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
     all_weight = None
     for i in range(n_experts):
-        weight = weights.get_weights_row(
-            f"{prefix}.{i}.{name}",
-        )
+        if not use_ep:
+            weight = weights.get_weights_row(
+                f"{prefix}.{i}.{name}",
+            )
+        else:
+            weight = weights.get_weights(
+                f"{prefix}.{i+ep_offset}.{name}",
+            )

         assert isinstance(weight, UnquantizedWeight)
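
The two loading paths differ in how expert weights land on a rank: the tensor-parallel path (`get_multi_weights_col` / `get_weights_row`) shards every expert's matrices across all ranks, while the expert-parallel path fetches whole tensors and remaps the expert index, so local expert `i` loads global expert `i + ep_offset`. A sketch of that name remapping (prefix and figures are made up for illustration):

```python
# Illustrative only: how local expert indices map onto checkpoint tensor names
# under expert parallelism (prefix and names are hypothetical, not from the patch).
prefix = "model.layers.3.mlp.experts"
ep_offset, n_local = 20, 20  # e.g. rank 1 of 8 with 160 experts

for i in range(3):  # first few local experts
    print(f"local expert {i} <- {prefix}.{i + ep_offset}.down_proj.weight")
# local expert 0 <- model.layers.3.mlp.experts.20.down_proj.weight
# local expert 1 <- model.layers.3.mlp.experts.21.down_proj.weight
# local expert 2 <- model.layers.3.mlp.experts.22.down_proj.weight
```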

View File

@@ -28,11 +28,12 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     get_linear,
+    Fp8Linear,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
     attention,
-    paged_attention,
+    paged_attention_mla,
     set_block_mapping,
     HPUPagedAttentionMetadata,
 )
@@ -44,6 +45,18 @@ from text_generation_server.utils.weights import Weights
 import habana_frameworks.torch as htorch


+def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
+    if isinstance(layer, Fp8Linear):
+        eye = torch.eye(
+            layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
+        )
+        dequant_weights = layer(eye)
+        del eye
+        # standardize to (output, input)
+        return dequant_weights.T
+    return layer.weight
+
+
 class DeepseekV2Config(PretrainedConfig):
     def __init__(
         self,
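
The helper above recovers an effective weight matrix from a quantized linear layer by feeding it an identity matrix: for a bias-free linear layer, `layer(I) == I @ W.T == W.T`. A minimal demonstration with a plain `nn.Linear` standing in for the quantized layer:

```python
import torch

# Minimal demonstration of the identity trick, with a plain bias-free nn.Linear
# standing in for the quantized layer (illustrative, not the Fp8 path itself).
layer = torch.nn.Linear(8, 16, bias=False)
eye = torch.eye(layer.in_features)

recovered = layer(eye).T  # layer(I) == I @ W.T == W.T; transpose back to (out, in)
assert torch.allclose(recovered, layer.weight)
```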
@@ -246,6 +259,45 @@ class DeepseekV2Attention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)

+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
+        kv_b_proj_weight = kv_b_proj_weight.view(
+            self.kv_lora_rank,
+            self.num_heads,
+            self.qk_nope_head_dim + self.value_head_size,
+        )
+        W_UK, W_UV = kv_b_proj_weight.split(
+            [self.qk_nope_head_dim, self.value_head_size], dim=-1
+        )
+        # Convert from (L, N, V) to (N, L, V)
+        self.W_UV = W_UV.transpose(0, 1)
+        # Convert from (L, N, P) to (N, P, L)
+        self.W_UK_T = W_UK.permute(1, 2, 0)
+
+    def _q_proj_and_k_up_proj(self, x):
+        q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+        q_nope, q_pe = (
+            q_proj(x)
+            .view(-1, self.num_heads, self.head_size)
+            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        )
+
+        # Convert from (B, N, P) to (N, B, P)
+        q_nope = q_nope.transpose(0, 1)
+        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+        ql_nope = torch.bmm(q_nope, self.W_UK_T)
+        # Convert from (N, B, L) to (B, N, L)
+        return ql_nope.transpose(0, 1), q_pe
+
+    def _v_up_proj_and_o_proj(self, x):
+        # Convert from (B, N, L) to (N, B, L)
+        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+        x = torch.bmm(x, self.W_UV)
+        # Convert from (N, B, V) to (B, N * V)
+        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
+        return self.o_proj(x)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
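
The absorption trick behind `_q_proj_and_k_up_proj` rests on a per-head identity: with cached latent `c` and `k_nope = c @ W_UK`, the attention score satisfies `q_nope . (c @ W_UK) = (q_nope @ W_UK.T) . c`, so `W_UK` can be folded into the query and decode can attend directly over cached latents; symmetrically, `W_UV` is applied after attention in `_v_up_proj_and_o_proj`. A quick numerical check of the identity (dimensions chosen to resemble a DeepSeek-V2 head; illustrative only):

```python
import torch

# Numerical check of the weight-absorption identity used by _q_proj_and_k_up_proj:
# per head, q_nope . (c @ W_UK) == (q_nope @ W_UK.T) . c
P, L = 128, 512  # qk_nope_head_dim, kv_lora_rank (DeepSeek-V2-like figures)
W_UK = torch.randn(L, P, dtype=torch.float64)  # one head's slice of kv_b_proj
q_nope = torch.randn(P, dtype=torch.float64)
c = torch.randn(L, dtype=torch.float64)        # cached compressed latent

score_up_project = q_nope @ (c @ W_UK)  # up-project the latent to k_nope first
score_absorbed = (q_nope @ W_UK.T) @ c  # fold W_UK into the query instead
assert torch.allclose(score_up_project, score_absorbed)
```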
@@ -258,14 +310,9 @@ class DeepseekV2Attention(torch.nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         if self.q_lora_rank is None:
-            query = self.q_proj(hidden_states)
+            hidden_states_or_q_c = hidden_states
         else:
-            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
-        query = query.view(-1, self.num_heads, self.head_size)
-
-        _, query_pe = torch.split(
-            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
-        )
+            hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]

         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, key_pe = torch.split(
@@ -273,13 +320,18 @@ class DeepseekV2Attention(torch.nn.Module):
         )

         key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
-        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
-        )
-
-        key_nope, value = torch.split(
-            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
-        )
+        kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
+
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+            query = q_proj(hidden_states_or_q_c)
+            query = query.view(-1, self.num_heads, self.head_size)
+            query_nope, query_pe = torch.split(
+                query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+        else:
+            query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)

         batch_size, heads, head_dim = query_pe.shape
         query_pe = (
@@ -294,33 +346,47 @@ class DeepseekV2Attention(torch.nn.Module):
             .reshape(batch_size, heads, head_dim)
         )
         self.rotary_emb(query_pe, key_pe, cos, sin)

+        latent_vec_k = torch.concat(
+            (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
+        )
+        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
-        query[..., self.qk_nope_head_dim :] = query_pe
-        key = torch.empty_like(query)
-        key[..., : self.qk_nope_head_dim] = key_nope
-        key[..., self.qk_nope_head_dim :] = key_pe
-
-        # We need to pad the heads because Flash Attention does not support
-        # qk and v with different head sizes.
-        query = torch.nn.functional.pad(
-            query, (0, self.head_pad_size - self.head_size), value=0
-        )
-        key = torch.nn.functional.pad(
-            key, (0, self.head_pad_size - self.head_size), value=0
-        )
-        value = torch.nn.functional.pad(
-            value, (0, self.head_pad_size - self.value_head_size), value=0
-        )
+        latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))

         kv_cache.store(
-            key=key,
-            value=value,
+            key=latent_vec_k,
+            value=None,
             slots=slots,
             kv_scales=self.kv_scales,
         )

         # Prefill
         if cu_seqlen_prefill is not None:
+            kv = self.kv_b_proj(kv_c_normed).view(
+                -1,
+                self.num_key_value_heads,
+                self.qk_nope_head_dim + self.value_head_size,
+            )
+            key_nope, value = torch.split(
+                kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+            )
+            query[..., self.qk_nope_head_dim :] = query_pe
+            key = torch.empty_like(query)
+            key[..., : self.qk_nope_head_dim] = key_nope
+            key[..., self.qk_nope_head_dim :] = key_pe
+            # We need to pad the heads because Flash Attention does not support
+            # qk and v with different head sizes.
+            query = torch.nn.functional.pad(
+                query, (0, self.head_pad_size - self.head_size), value=0
+            )
+            key = torch.nn.functional.pad(
+                key, (0, self.head_pad_size - self.head_size), value=0
+            )
+            value = torch.nn.functional.pad(
+                value, (0, self.head_pad_size - self.value_head_size), value=0
+            )
             # flash attention
             attn_output = attention(
                 query=query,
@@ -331,9 +397,15 @@ class DeepseekV2Attention(torch.nn.Module):
                 seqlen=seqlen,
                 softmax_scale=self.softmax_scale,
             )
-        # Decode
+            attn_output = attn_output[..., : self.value_head_size]
+
+            return self.o_proj(
+                attn_output.reshape(-1, self.num_heads * self.value_head_size)
+            )
         else:
-            attn_output = paged_attention(
+            # Decode
+            query = torch.cat([query_nope, query_pe], dim=-1)
+            attn_output = paged_attention_mla(
                 query,
                 kv_cache,
                 self.kv_head_mapping,
@@ -341,14 +413,10 @@ class DeepseekV2Attention(torch.nn.Module):
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                kv_lora_rank=self.kv_lora_rank,
             )
-
-        # Remove padding.
-        attn_output = attn_output[..., : self.value_head_size]
-
-        return self.o_proj(
-            attn_output.reshape(-1, self.num_heads * self.value_head_size)
-        )
+            attn_output = self._v_up_proj_and_o_proj(attn_output)
+            return attn_output


 class DeepseekV2MLP(nn.Module):
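
Taken together, the forward pass now stores a single latent vector of width `kv_lora_rank + qk_rope_head_dim` per token (passed to `kv_cache.store` as `key`, with `value=None`); the prefill branch up-projects it back to full K/V for flash attention, while decode attends in latent space via `paged_attention_mla`. For DeepSeek-V2-like dimensions the cache saving is substantial; a back-of-envelope comparison (figures assumed from the public DeepSeek-V2 config, used only for illustration):

```python
# Rough per-token, per-layer KV-cache width, assuming DeepSeek-V2-like dims.
num_heads = 128
qk_head_dim = 192          # qk_nope_head_dim (128) + qk_rope_head_dim (64)
value_head_size = 128
kv_lora_rank = 512
qk_rope_head_dim = 64

full_kv = num_heads * (qk_head_dim + value_head_size)  # 40960 elements
latent = kv_lora_rank + qk_rope_head_dim               # 576 elements
print(f"latent cache is ~{full_kv / latent:.0f}x smaller per token")  # ~71x
```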

View File

@@ -21,6 +21,7 @@ import torch.nn.functional as F
 from text_generation_server.layers.attention import (
     attention,
     paged_attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -466,6 +467,10 @@ class Qwen3MoeModel(nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds

View File

@@ -1606,7 +1606,7 @@ class FlashCausalLM(Model):
     ):
         self.kv_cache = []
         empty_cache()
-        if self.config.model_type == "deepseek_v3":
+        if self.config.model_type in ["deepseek_v3", "deepseek_v2"]:
             self.kv_cache = [
                 KVCompressCache(
                     num_blocks=num_blocks,
@@ -1646,7 +1646,7 @@ class FlashCausalLM(Model):
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
-        if self.config.model_type == "deepseek_v3":
+        if self.config.model_type in ["deepseek_v3", "deepseek_v2"]:
             cache_block_size = BLOCK_SIZE * (
                 self.config.kv_lora_rank + self.config.qk_rope_head_dim
             )
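
For these MLA models a cache block therefore holds one latent vector per slot, with no separate value half. A sketch of the resulting block-count arithmetic (the `BLOCK_SIZE` value, layer count, and free-memory figure are assumed for illustration, not taken from the patch):

```python
import torch

# Illustrative block-count arithmetic for the MLA cache sizing above.
BLOCK_SIZE = 128
kv_lora_rank, qk_rope_head_dim = 512, 64
num_layers = 60
dtype_size = torch.tensor([], dtype=torch.bfloat16).element_size()  # 2 bytes

cache_block_size = BLOCK_SIZE * (kv_lora_rank + qk_rope_head_dim)  # one latent per slot
total_cache_size = num_layers * cache_block_size * dtype_size      # bytes per block, all layers

free_memory = 20 * 2**30  # assume 20 GiB left for the cache
num_blocks = int(free_memory // total_cache_size)
print(num_blocks, "blocks ->", num_blocks * BLOCK_SIZE, "cacheable tokens")
```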

View File

@@ -1,50 +0,0 @@
-import os
-from pathlib import Path
-from loguru import logger
-from text_generation_server import server
-import argparse
-from text_generation_server.utils.adapter import parse_lora_adapters
-
-
-def main(args):
-    logger.info("TGIService: starting tgi service .... ")
-    logger.info(
-        "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format(
-            args.model_id,
-            args.revision,
-            args.sharded,
-            args.speculate,
-            args.dtype,
-            args.trust_remote_code,
-            args.uds_path,
-        )
-    )
-    lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
-    server.serve(
-        model_id=args.model_id,
-        lora_adapters=lora_adapters,
-        revision=args.revision,
-        sharded=args.sharded,
-        quantize=args.quantize,
-        speculate=args.speculate,
-        dtype=args.dtype,
-        trust_remote_code=args.trust_remote_code,
-        uds_path=args.uds_path,
-        max_input_tokens=args.max_input_tokens,
-        kv_cache_dtype="auto",
-    )
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_id", type=str)
-    parser.add_argument("--revision", type=str)
-    parser.add_argument("--sharded", type=bool)
-    parser.add_argument("--speculate", type=int, default=None)
-    parser.add_argument("--dtype", type=str)
-    parser.add_argument("--trust_remote_code", type=bool)
-    parser.add_argument("--uds_path", type=Path)
-    parser.add_argument("--quantize", type=str)
-    parser.add_argument("--max_input_tokens", type=int)
-    args = parser.parse_args()
-    main(args)

View File

@@ -184,6 +184,7 @@ Text Generation Inference enables serving optimized models on Gaudi hardware.
 **Large Language Models (LLMs)**
 - [deepseek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)
+- [deepseek-v2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
 - [Llama2](https://huggingface.co/collections/meta-llama/llama-2-family-661da1f90a9d678b6f55773b)
 - [Llama3](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
 - [CodeLlama](https://huggingface.co/codellama/CodeLlama-13b-hf)