Add deepseekv3 (#2968)

* Add fp8 support moe models

add deepseekv3

format codfe'

update dockerfile

update doc

* Small modifications.

* Moe kernels 0.8.1

* Upgrade to 0.8.1

* Fixing moe import.

* Black.

* Apply suggestions from code review

Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>

* Fixing Mixtral + Nits.

* Put link to ref.

* Fix other call locations.

* Scoring func `softmax` is the only one that works.

---------

Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
This commit is contained in:
Nicolas Patry 2025-01-30 16:40:25 +01:00 committed by GitHub
parent 80e7d98f88
commit cb747b33da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 864 additions and 26 deletions

View File

@ -279,7 +279,7 @@ RUN git clone https://github.com/danieldk/marlin-kernels.git && \
FROM kernel-builder AS moe-kernels
WORKDIR /usr/src
ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd
ENV MOE_KERNELS_BRANCH=d7e042bf9f7aff10c631212fc71b24895d66eb59
ENV VLLM_TARGET_DEVICE=rocm
RUN git clone https://github.com/danieldk/moe-kernels.git && \
cd moe-kernels && \

View File

@ -4,6 +4,7 @@
Text Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Deepseek V3](https://huggingface.co/deepseek-ai/DeepSeek-V3)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)

View File

@ -978,16 +978,16 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1738163501,
"narHash": "sha256-MW+HVo3Kjr/W8ra7qyeG2nW/Z6fsZ7nDfWs3Uvw9Xko=",
"lastModified": 1738229197,
"narHash": "sha256-K/YJSFhzP0vN23GMfM1HVMtSzaM488hh12ggsMtKMG0=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "bfdd9594c7d99cf8442e06f3bb2b4ab08185affe",
"rev": "cfcddaf3044f59c3fbd335935ac3c0e9f458d824",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "moe-kernels-0.8.0",
"ref": "moe_0_8_1",
"repo": "text-generation-inference-nix",
"type": "github"
}

View File

@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/moe-kernels-0.8.0";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/moe_0_8_1";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {

View File

@ -1635,6 +1635,7 @@ enum Gpu {
A40,
H100,
A100,
H200,
Unknown(String),
}
@ -1661,6 +1662,7 @@ impl From<&str> for Gpu {
"nvidia-a100-sxm4-40gb" => Gpu::A100,
"nvidia-a100-80gb-pcie" => Gpu::A100,
"nvidia-a100" => Gpu::A100,
"nvidia-h200" => Gpu::H200,
card => Gpu::Unknown(card.to_string()),
}
}
@ -1678,6 +1680,7 @@ impl std::fmt::Display for Gpu {
Gpu::A40 => write!(f, "nvidia-a40"),
Gpu::H100 => write!(f, "nvidia-h100-80fb-hbm3"),
Gpu::A100 => write!(f, "nvida-a100-sxm4-80gb"),
Gpu::H200 => write!(f, "nvida-h200"),
Gpu::Unknown(card) => write!(f, "{}", card),
}
}
@ -1702,11 +1705,13 @@ impl ComputeType {
// https://www.nvidia.com/en-us/data-center/a40/
// https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf
Gpu::A40 => Some(149 * 10u64.pow(12)),
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
Gpu::A100 => Some(312 * 10u64.pow(12)),
// https://www.nvidia.com/en-us/data-center/h100/
// https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
Gpu::H100 => Some(900 * 10u64.pow(12)),
// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
Gpu::A100 => Some(312 * 10u64.pow(12)),
// https://www.nvidia.com/en-us/data-center/h200/
Gpu::H200 => Some(989 * 10u64.pow(12)),
Gpu::Unknown(card) => {
tracing::warn!("Unkown compute for card {card}");
None

View File

@ -224,6 +224,8 @@ pub enum Config {
Qwen2,
Opt,
T5,
DeepseekV2,
DeepseekV3,
}
#[derive(Clone, Debug, Serialize, Deserialize)]

View File

@ -75,7 +75,7 @@ marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.7/marlin_kernels-0.3.7+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
]
moe-kernels.url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.0/moe_kernels-0.8.0+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
moe-kernels.url = "https://github.com/danieldk/moe-kernels/releases/download/v0.8.1/moe_kernels-0.8.1+cu123torch2.5-cp39-abi3-linux_x86_64.whl"
[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

View File

@ -19,6 +19,12 @@ try:
except ImportError:
marlin_kernels = None
try:
from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8
except ImportError:
w8a8_block_fp8_matmul = None
per_token_group_quant_fp8 = None
quant_dtype: torch.dtype = (
torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
)
@ -38,7 +44,6 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
"""
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
# Marlin is W8A16, use it when:
#
@ -180,14 +185,29 @@ def fp8_quantize(
class HybridFP8UnquantLoader(WeightsLoader):
"""Weight loader that loads FP8 and unquantized Torch tensors."""
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
def __init__(
self,
activation_scale_ub: Optional[float],
to_fp8: bool,
weight_block_size: Optional[List[int]] = None,
):
self.activation_scale_ub = activation_scale_ub
self.to_fp8 = to_fp8
self.weight_block_size = weight_block_size
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
@ -276,6 +296,21 @@ class HybridFP8UnquantLoader(WeightsLoader):
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = [
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
for p in prefixes
]
scale = torch.cat(scale, dim=dim)
scale = scale.to(weights.device)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
@ -321,6 +356,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
# XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if SYSTEM == "cuda":
@ -355,6 +402,7 @@ class Fp8Weight(Weight):
input_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None
force_w8a16: bool = False
weight_block_size: Optional[List[int]] = None
def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None:
@ -371,6 +419,7 @@ class Fp8Weight(Weight):
bias=bias,
input_scale=self.input_scale,
scale_upper_bound=self.activation_scale_ub,
weight_block_size=self.weight_block_size,
)
@ -385,6 +434,7 @@ class Fp8Linear(torch.nn.Module):
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[float] = None,
weight_block_size: Optional[List[int]] = None,
) -> None:
super().__init__()
if CUTLASS_FP8_AVAILABLE:
@ -398,6 +448,7 @@ class Fp8Linear(torch.nn.Module):
self.qweight = qweight
self.scale = scale.float()
self.input_scale = input_scale.float() if input_scale is not None else None
self.weight_block_size = weight_block_size
if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
self.scale_upper_bound = torch.tensor(
@ -431,6 +482,7 @@ class Fp8Linear(torch.nn.Module):
) -> "Fp8Linear":
input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None)
weight_block_size = kwargs.get("weight_block_size", None)
return cls(
qweight=weight,
@ -439,6 +491,7 @@ class Fp8Linear(torch.nn.Module):
scale_upper_bound=scale_upper_bound,
bias=bias,
dtype=dtype,
weight_block_size=weight_block_size,
)
@classmethod
@ -450,6 +503,25 @@ class Fp8Linear(torch.nn.Module):
return cls._device_identity_cache[device]
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight_block_size is not None:
# https://arxiv.org/pdf/2412.19437
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
# channels).
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
output = w8a8_block_fp8_matmul(
qinput,
self.qweight,
scale,
self.scale,
self.weight_block_size,
output_dtype=input.dtype,
)
if self.bias is not None:
output = output + self.bias
return output.to(dtype=input.dtype)
if CUTLASS_FP8_AVAILABLE:
# cutlass FP8 supports per-token scales, so get non-scalar scales.
qinput, scale = fp8_quantize(

View File

@ -52,6 +52,8 @@ class MoELayer(Protocol):
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
): ...
def forward(
@ -81,9 +83,14 @@ class DenseMoELayer(nn.Module):
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
hidden_act: str = "silu",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
):
super().__init__()
assert scoring_func is None, "scoring func is not handled"
assert e_score_correction_bias is None, "scoring correction bias is not handled"
log_once(
logger.info,
"No fused layers are available for this model type, using (slower) dense MoE layer",
@ -199,21 +206,24 @@ class SparseMoELayer(nn.Module):
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
if isinstance(weights.loader, DefaultWeightsLoader) and isinstance(
weights.loader.weight_class, UnquantizedWeight
):
cls = UnquantizedSparseMoELayer
elif isinstance(weights.loader, HybridFP8UnquantLoader):
cls = (
FP8SparseMoELayer
if weights.loader.to_fp8
else UnquantizedSparseMoELayer
)
if (
isinstance(weights.loader, DefaultWeightsLoader)
and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader):
if (
isinstance(weights.loader, HybridFP8UnquantLoader)
and weights.loader.to_fp8
):
cls = FP8SparseMoELayer
else:
cls = UnquantizedSparseMoELayer
elif isinstance(
weights.loader, GPTQMarlinWeightsLoader
) and can_use_marlin_moe_gemm(
@ -240,6 +250,8 @@ class SparseMoELayer(nn.Module):
topk=topk,
topk_group=topk_group,
weights=weights,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
down_proj_name=down_proj_name,

View File

@ -28,6 +28,8 @@ class FP8SparseMoELayer(nn.Module):
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
@ -42,6 +44,9 @@ class FP8SparseMoELayer(nn.Module):
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
(
self.gate_up_proj,
@ -76,6 +81,8 @@ class FP8SparseMoELayer(nn.Module):
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
use_fp8_w8a8=True,
w1_scale=self.gate_up_proj_weight_scale,
w2_scale=self.down_proj_weight_scale,
@ -109,7 +116,7 @@ def _load_expert_weights(
)
if all_weight_scales is None:
all_weight_scales = torch.empty(
(n_experts,),
(n_experts,) + weight.weight_scale.shape,
dtype=torch.float32,
device=weight.weight.device,
)

View File

@ -69,7 +69,11 @@ class GPTQMarlinSparseMoELayer(nn.Module):
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
scoring_func: Optional[str] = None,
e_score_correction_bias: Optional[float] = None,
):
assert scoring_func == "softmax", f"scoring func {scoring_func} is not handled"
assert e_score_correction_bias is None, "scoring correction bias is not handled"
super().__init__()
if not (

View File

@ -23,6 +23,8 @@ class UnquantizedSparseMoELayer(nn.Module):
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
@ -37,6 +39,9 @@ class UnquantizedSparseMoELayer(nn.Module):
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix,
@ -68,7 +73,6 @@ class UnquantizedSparseMoELayer(nn.Module):
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
return fused_moe(
x,
w1=self.gate_up_proj,
@ -80,6 +84,8 @@ class UnquantizedSparseMoELayer(nn.Module):
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
)

View File

@ -89,6 +89,10 @@ try:
FlashDeepseekV2ForCausalLM,
DeepseekV2Config,
)
from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
FlashDeepseekV3ForCausalLM,
DeepseekV3Config,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
@ -195,6 +199,11 @@ class ModelType(enum.Enum):
"name": "Deepseek V2",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
}
DEEPSEEK_V3 = {
"type": "deepseek_v3",
"name": "Deepseek V3",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
}
IDEFICS2 = {
"type": "idefics2",
"name": "Idefics 2",
@ -642,6 +651,40 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == DEEPSEEK_V3:
if FLASH_ATTENTION:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV3Config,
head_size=head_size,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V3")
)
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == MAMBA:
return Mamba(
model_id,

View File

@ -0,0 +1,676 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Type
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from text_generation_server.layers import (
FastLinear,
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
if SYSTEM == "rocm":
try:
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")
class DeepseekV3Config(PretrainedConfig):
def __init__(
self,
vocab_size=102400,
hidden_size=4096,
intermediate_size=11008,
moe_intermediate_size=1407,
num_hidden_layers=30,
num_attention_heads=32,
num_key_value_heads=32,
n_shared_experts=2,
n_routed_experts=160,
ep_size=1,
routed_scaling_factor=1.0,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method="gready",
n_group=8,
topk_group=3,
num_experts_per_tok=6,
moe_layer_freq=1,
first_k_dense_replace=0,
norm_topk_prob=False,
scoring_func="softmax",
aux_loss_alpha=0.001,
seq_aux=True,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=100000,
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Deepseek V2 models."
)
if ep_size != 1:
raise ValueError(
f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class DeepseekV3Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights: Weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.kv_lora_rank = config.kv_lora_rank
self.q_lora_rank = config.q_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
)
self.softmax_scale = self.head_size**-0.5 * mscale * mscale
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
if self.q_lora_rank is None:
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=config.attention_bias,
)
else:
self.q_a_proj = get_linear(
weight=weights.get_weights(f"{prefix}.q_a_proj"),
bias=(
weights.get_tensor(f"{prefix}.q_a_proj.bias")
if config.attention_bias
else None
),
)
self.q_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.q_a_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.q_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.kv_a_proj_with_mqa = get_linear(
weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
bias=(
weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
if config.attention_bias
else None
),
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.kv_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.kv_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: KVCache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
else:
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
query = query.view(-1, self.num_heads, self.head_size)
_, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
batch_size, heads, head_dim = query_pe.shape
query_pe = (
query_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
batch_size, heads, head_dim = key_pe.shape
key_pe = (
key_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
self.rotary_emb(query_pe, key_pe, cos, sin)
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
# Remove padding.
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
class DeepseekV3MLP(nn.Module):
def __init__(self, prefix: str, config, weights, intermediate_size: int):
super().__init__()
self.hidden_act = config.hidden_act
if self.hidden_act != "silu":
# Bail out because MoE only supports silu.
raise NotImplementedError(
"Currently only `silu` is supported as an activation for Deepseek V2."
)
self.act = ACT2FN[self.hidden_act]
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = intermediate_size // weights.process_group.size()
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out, reduce=reduce)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
class DeepseekV3MoE(nn.Module):
def __init__(
self,
prefix,
config: DeepseekV3Config,
moe_layer_cls: Type[MoELayer],
weights,
):
super().__init__()
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = (
config.moe_intermediate_size // weights.process_group.size()
)
self.routed_scaling_factor = config.routed_scaling_factor
# Gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = torch.zeros(
config.n_routed_experts, device=weights.device
)
else:
self.gate.e_score_correction_bias = None
self.moe_layer = moe_layer_cls(
prefix=f"{prefix}.experts",
n_experts=config.n_routed_experts,
n_expert_group=config.n_group,
renormalize=config.norm_topk_prob,
topk=config.num_experts_per_tok,
topk_group=config.topk_group,
weights=weights,
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
)
assert isinstance(self.moe_layer, MoELayer)
if config.n_shared_experts is not None:
self.shared_experts = DeepseekV3MLP(
prefix=f"{prefix}.shared_experts",
config=config,
weights=weights,
intermediate_size=config.moe_intermediate_size
* config.n_shared_experts,
)
else:
self.shared_experts = None
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.shared_experts is not None:
shared_output = self.shared_experts(x, reduce=False)
else:
shared_output = None
router_logits = self.gate(x)
out = self.moe_layer(x, gating_output=router_logits)
if shared_output is not None:
out = out + shared_output
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class DeepseekV3Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = DeepseekV3Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
if (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
moe_layer_cls = (
SparseMoELayer
if SparseMoELayer.is_supported(weights)
else DenseMoELayer
)
self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
else:
self.mlp = DeepseekV3MLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
intermediate_size=config.intermediate_size,
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
)
# faster post attention rms norm
normed_attn_res_output, residual = self.post_attention_layernorm(
attn_output, residual
)
output = self.mlp(normed_attn_res_output)
return output, residual
class DeepseekV3Model(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
DeepseekV3Layer(
prefix,
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashDeepseekV3ForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.model = DeepseekV3Model(
"model" if not prefix else f"{prefix}.model", config, weights
)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -81,12 +81,14 @@ def initialize_torch_distributed():
pg_options=options,
)
else:
device = torch.device(f"cuda:{RANK}")
torch.distributed.init_process_group(
backend=backend,
world_size=WORLD_SIZE,
rank=RANK,
timeout=timedelta(seconds=120),
pg_options=options,
device_id=device,
)
else:
logger.warning("torch.distributed is already initialized.")

View File

@ -1,7 +1,7 @@
import json
import os
from dataclasses import dataclass
from typing import Optional
from typing import Optional, List
from huggingface_hub import hf_hub_download
from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
@ -20,6 +20,7 @@ class _QuantizerConfig:
groupsize: int
quant_method: str
sym: bool
weight_block_size: Optional[List[int]]
@dataclass
@ -49,16 +50,17 @@ def _get_quantizer_config(model_id, revision):
checkpoint_format = None
sym = False
desc_act = False
weight_block_size = None
filename = "config.json"
try:
data = _get_config_json(model_id, revision, filename)
# FP8 config
if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
return _FP8QuantizerConfig(
activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
)
weight_block_size = data["quantization_config"].get("weight_block_size", None)
if "zero_point" in data["quantization_config"]:
sym = not data["quantization_config"]["zero_point"]
@ -107,6 +109,7 @@ def _get_quantizer_config(model_id, revision):
checkpoint_format=checkpoint_format,
sym=sym,
desc_act=desc_act,
weight_block_size=weight_block_size,
)
@ -196,9 +199,14 @@ def get_loader(
# Since the default for the quantize config is _QuantizerConfig,
# we need to add this check to not get an attribute error
activation_scale_ub = None
weight_block_size = quantizer_config.weight_block_size
if isinstance(quantizer_config, _FP8QuantizerConfig):
activation_scale_ub = quantizer_config.activation_scale_ub
return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
return HybridFP8UnquantLoader(
activation_scale_ub,
to_fp8=quantize == "fp8",
weight_block_size=weight_block_size,
)
else:
raise ValueError(f"Unknown quantization method: {quantize}")