Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-06 01:12:06 +00:00

commit 57de05b0dd (parent 16162602c2)

    add deepseekv3
@@ -279,9 +279,9 @@ RUN git clone https://github.com/danieldk/marlin-kernels.git && \

 FROM kernel-builder AS moe-kernels
 WORKDIR /usr/src
-ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd
+ENV MOE_KERNELS_BRANCH=0c8f8ed941635025e277d6dc6f7324827855cc40
 ENV VLLM_TARGET_DEVICE=rocm
-RUN git clone https://github.com/danieldk/moe-kernels.git && \
+RUN git clone https://github.com/mht-sharma/moe-kernels.git && \
     cd moe-kernels && \
     git checkout ${MOE_KERNELS_BRANCH} && \
     python setup.py install
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 import os
 from typing import Optional, Tuple, Type, Union, List
+from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8

 import torch
 from loguru import logger
@@ -170,14 +171,31 @@ def fp8_quantize(
 class HybridFP8UnquantLoader(WeightsLoader):
     """Weight loader that loads FP8 and unquantized Torch tensors."""

-    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
+    def __init__(
+        self,
+        activation_scale_ub: Optional[float],
+        to_fp8: bool,
+        weight_block_size: Optional[List[int]] = None,
+    ):
         self.activation_scale_ub = activation_scale_ub
         self.to_fp8 = to_fp8
+        self.weight_block_size = weight_block_size

     def get_weights(self, weights: "Weights", prefix: str):
         w = weights.get_tensor(f"{prefix}.weight")

         if w.dtype == torch.float8_e4m3fn:
+            if self.weight_block_size is not None:
+                scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
+                if scale.device == torch.device("cpu"):
+                    scale = scale.to(weights.device)
+                return Fp8Weight(
+                    weight=w,
+                    weight_scale=scale,
+                    activation_scale_ub=self.activation_scale_ub,
+                    dtype=weights.dtype,
+                    weight_block_size=self.weight_block_size,
+                )
            # FP8 branch
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)

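For orientation, block-quantized FP8 checkpoints store one scale per weight tile in `weight_scale_inv`, which the loader above forwards as the weight scale. A minimal dequantization sketch (hypothetical shapes, assuming a [128, 128] block size of the kind DeepSeek-V3 ships):

    import torch

    out_features, in_features = 512, 384   # hypothetical layer shape
    block_n, block_k = 128, 128            # weight_block_size

    w = torch.randn(out_features, in_features).to(torch.float8_e4m3fn)
    # One scale per (block_n x block_k) tile of the weight matrix.
    scale = torch.rand(out_features // block_n, in_features // block_k)

    # Broadcast each tile scale over its block and dequantize.
    dequant = w.to(torch.float32) * scale.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
    print(dequant.shape)  # torch.Size([512, 384])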
@@ -266,6 +284,22 @@ class HybridFP8UnquantLoader(WeightsLoader):

        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
+            if self.weight_block_size is not None:
+                scale = [
+                    weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
+                    for p in prefixes
+                ]
+                scale = torch.cat(scale, dim=dim)
+                if scale.device == torch.device("cpu"):
+                    scale = scale.to(weights.device)
+                return Fp8Weight(
+                    weight=w,
+                    weight_scale=scale,
+                    activation_scale_ub=self.activation_scale_ub,
+                    dtype=weights.dtype,
+                    weight_block_size=self.weight_block_size,
+                )
+
            scale = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
@@ -311,6 +345,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        # FP8 branch
        if w.dtype == torch.float8_e4m3fn:
+            if self.weight_block_size is not None:
+                scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
+                if scale.device == torch.device("cpu"):
+                    scale = scale.to(weights.device)
+
+                return Fp8Weight(
+                    weight=w,
+                    weight_scale=scale,
+                    activation_scale_ub=self.activation_scale_ub,
+                    dtype=weights.dtype,
+                    weight_block_size=self.weight_block_size,
+                )
+
            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)

            if SYSTEM == "cuda":
@@ -345,6 +392,7 @@ class Fp8Weight(Weight):
    input_scale: Optional[torch.Tensor] = None
    activation_scale_ub: Optional[float] = None
    force_w8a16: bool = False
+    weight_block_size: Optional[List[int]] = None

    def get_linear(self, bias: torch.Tensor):
        if self.weight_scale is None:
@@ -361,6 +409,7 @@ class Fp8Weight(Weight):
            bias=bias,
            input_scale=self.input_scale,
            scale_upper_bound=self.activation_scale_ub,
+            weight_block_size=self.weight_block_size,
        )


@@ -375,6 +424,7 @@ class Fp8Linear(torch.nn.Module):
        bias: Optional[torch.Tensor] = None,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
+        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        if CUTLASS_FP8_AVAILABLE:
@@ -388,6 +438,7 @@ class Fp8Linear(torch.nn.Module):
        self.qweight = qweight
        self.scale = scale.float()
        self.input_scale = input_scale.float() if input_scale is not None else None
+        self.weight_block_size = weight_block_size

        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
            self.scale_upper_bound = torch.tensor(
@@ -421,6 +472,7 @@ class Fp8Linear(torch.nn.Module):
    ) -> "Fp8Linear":
        input_scale = kwargs.get("input_scale", None)
        scale_upper_bound = kwargs.get("scale_upper_bound", None)
+        weight_block_size = kwargs.get("weight_block_size", None)

        return cls(
            qweight=weight,
@@ -429,6 +481,7 @@ class Fp8Linear(torch.nn.Module):
            scale_upper_bound=scale_upper_bound,
            bias=bias,
            dtype=dtype,
+            weight_block_size=weight_block_size,
        )

    @classmethod
@@ -440,6 +493,21 @@ class Fp8Linear(torch.nn.Module):
        return cls._device_identity_cache[device]

    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.weight_block_size is not None:
+            qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
+            # logger.info(f"qinput: {qinput.shape} {scale.shape} {self.qweight.shape} {self.scale.shape} {self.weight_block_size}")
+            output = w8a8_block_fp8_matmul(
+                qinput,
+                self.qweight,
+                scale,
+                self.scale,
+                self.weight_block_size,
+                output_dtype=input.dtype,
+            )
+
+            if self.bias is not None:
+                output = output + self.bias
+            return output.to(dtype=input.dtype)
        if CUTLASS_FP8_AVAILABLE:
            # cutlass FP8 supports per-token scales, so get non-scalar scales.
            qinput, scale = fp8_quantize(
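As a rough, pure-PyTorch reference for what the two kernel calls above compute (a sketch only, not the `moe_kernels` implementation, which fuses quantization and the blocked matmul):

    import torch

    FP8_MAX = 448.0  # finfo max of float8_e4m3fn

    def per_token_group_quant_ref(x, group_size):
        # One scale per `group_size` consecutive features of each token.
        xg = x.float().view(x.shape[0], -1, group_size)
        scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
        q = (xg / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        return q.view(x.shape), scale.squeeze(-1)

    def w8a8_block_matmul_ref(qx, x_scale, qw, w_scale, block_size, output_dtype):
        # Dequantize both operands block-wise, then fall back to a plain matmul.
        block_n, block_k = block_size
        x = qx.float() * x_scale.repeat_interleave(block_k, dim=1)[:, : qx.shape[1]]
        w = qw.float() * w_scale.repeat_interleave(block_n, dim=0).repeat_interleave(
            block_k, dim=1
        )[: qw.shape[0], : qw.shape[1]]
        return (x @ w.t()).to(output_dtype)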
@@ -214,6 +214,8 @@ class SparseMoELayer(nn.Module):
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
+        scoring_func: Optional[str] = "softmax",
+        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
@@ -256,6 +258,8 @@ class SparseMoELayer(nn.Module):
            topk=topk,
            topk_group=topk_group,
            weights=weights,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            down_proj_name=down_proj_name,
@@ -24,6 +24,8 @@ class FP8SparseMoELayer(nn.Module):
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
+        scoring_func: Optional[str] = "softmax",
+        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
@@ -38,6 +40,9 @@ class FP8SparseMoELayer(nn.Module):
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
+        self.weight_block_size = weights.weights_loader.weight_block_size
+        self.scoring_func = scoring_func
+        self.e_score_correction_bias = e_score_correction_bias

        (
            self.gate_up_proj,
@@ -72,6 +77,8 @@ class FP8SparseMoELayer(nn.Module):
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
            use_fp8_w8a8=True,
            w1_scale=self.gate_up_proj_weight_scale,
            w2_scale=self.down_proj_weight_scale,
@@ -105,7 +112,7 @@ def _load_expert_weights(
        )
        if all_weight_scales is None:
            all_weight_scales = torch.empty(
-                (n_experts,),
+                (n_experts,) + weight.weight_scale.shape,
                dtype=torch.float32,
                device=weight.weight.device,
            )
@@ -23,6 +23,8 @@ class UnquantizedSparseMoELayer(nn.Module):
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
+        scoring_func: Optional[str] = "softmax",
+        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
@@ -37,6 +39,9 @@ class UnquantizedSparseMoELayer(nn.Module):
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
+        self.weight_block_size = weights.weights_loader.weight_block_size
+        self.scoring_func = scoring_func
+        self.e_score_correction_bias = e_score_correction_bias

        self.gate_up_proj = _load_expert_multi_weights_col(
            prefix=prefix,
@@ -68,7 +73,8 @@ class UnquantizedSparseMoELayer(nn.Module):
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
        )
-
+        # from loguru import logger
+        # logger.info("Fused MoE is used here")
        return fused_moe(
            x,
            w1=self.gate_up_proj,
@@ -80,6 +86,8 @@ class UnquantizedSparseMoELayer(nn.Module):
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
+            scoring_func=self.scoring_func,
+            e_score_correction_bias=self.e_score_correction_bias,
        )


@@ -87,6 +87,10 @@ try:
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
+    from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
+        FlashDeepseekV3ForCausalLM,
+        DeepseekV3Config,
+    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
@@ -185,6 +189,11 @@ class ModelType(enum.Enum):
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
+    DEEPSEEK_V3 = {
+        "type": "deepseek_v3",
+        "name": "Deepseek V3",
+        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
+    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
@@ -632,6 +641,40 @@ def get_model(
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
+    elif model_type == DEEPSEEK_V3:
+        if FLASH_ATTENTION:
+            head_size = max(
+                config_dict.get("qk_nope_dim", 128)
+                + config_dict.get("qk_rope_dim", 64),
+                config_dict.get("v_head_dim", 128),
+            )
+            return FlashCausalLM(
+                model_id=model_id,
+                model_class=FlashDeepseekV3ForCausalLM,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                default_dtype=torch.bfloat16,
+                dtype=dtype,
+                kv_cache_dtype=kv_cache_dtype,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+                config_class=DeepseekV3Config,
+                head_size=head_size,
+            )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V3")
+            )
+        else:
+            return CausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
    elif model_type == MAMBA:
        return Mamba(
            model_id,
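For reference, with DeepSeek-V3's published head dimensions the head size above resolves to:

    # qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128 in DeepSeek-V3
    head_size = max(128 + 64, 128)  # = 192; q, k and v are later padded to this width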
@@ -0,0 +1,676 @@
+# coding=utf-8
+# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.distributed
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+
+from text_generation_server.layers import (
+    FastLinear,
+    SpeculativeHead,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
+    get_linear,
+)
+from text_generation_server.layers.attention import (
+    Seqlen,
+    attention,
+    paged_attention,
+)
+from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
+from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
+from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.weights import Weights
+
+if SYSTEM == "rocm":
+    try:
+        import vllm._custom_ops as ops
+    except Exception as e:
+        raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}")
+
+
+class DeepseekV3Config(PretrainedConfig):
+    def __init__(
+        self,
+        vocab_size=102400,
+        hidden_size=4096,
+        intermediate_size=11008,
+        moe_intermediate_size=1407,
+        num_hidden_layers=30,
+        num_attention_heads=32,
+        num_key_value_heads=32,
+        n_shared_experts=2,
+        n_routed_experts=160,
+        ep_size=1,
+        routed_scaling_factor=1.0,
+        kv_lora_rank=512,
+        q_lora_rank=1536,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        qk_nope_head_dim=128,
+        topk_method="gready",
+        n_group=8,
+        topk_group=3,
+        num_experts_per_tok=6,
+        moe_layer_freq=1,
+        first_k_dense_replace=0,
+        norm_topk_prob=False,
+        scoring_func="softmax",
+        aux_loss_alpha=0.001,
+        seq_aux=True,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=100000,
+        eos_token_id=100001,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.moe_intermediate_size = moe_intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.n_shared_experts = n_shared_experts
+        self.n_routed_experts = n_routed_experts
+        self.ep_size = ep_size
+        self.routed_scaling_factor = routed_scaling_factor
+        self.kv_lora_rank = kv_lora_rank
+        self.q_lora_rank = q_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.topk_method = topk_method
+        self.n_group = n_group
+        self.topk_group = topk_group
+        self.num_experts_per_tok = num_experts_per_tok
+        self.moe_layer_freq = moe_layer_freq
+        self.first_k_dense_replace = first_k_dense_replace
+        self.norm_topk_prob = norm_topk_prob
+        self.scoring_func = scoring_func
+        self.aux_loss_alpha = aux_loss_alpha
+        self.seq_aux = seq_aux
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+
+        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
+        if tie_word_embeddings:
+            raise ValueError(
+                "tie_word_embeddings is not supported for Deepseek V3 models."
+            )
+
+        if ep_size != 1:
+            raise ValueError(
+                f"Currently only ep_size == 1 is supported for Deepseek V3 models, was {ep_size}"
+            )
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+class DeepseekV3Attention(torch.nn.Module):
+    def __init__(
+        self,
+        prefix: str,
+        config,
+        weights: Weights,
+    ):
+        super().__init__()
+        self.num_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.kv_lora_rank = config.kv_lora_rank
+        self.q_lora_rank = config.q_lora_rank
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
+        self.value_head_size = config.v_head_dim
+        self.head_pad_size = max(self.head_size, self.value_head_size)
+
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=self.qk_rope_head_dim,
+            base=config.rope_theta,
+            device=weights.device,
+        )
+
+        mscale = get_mscale(
+            self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
+        )
+        self.softmax_scale = self.head_size**-0.5 * mscale * mscale
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.num_key_value_heads = (
+            config.num_key_value_heads // weights.process_group.size()
+        )
+
+        if self.q_lora_rank is None:
+            self.q_proj = TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.q_proj",
+                weights=weights,
+                bias=config.attention_bias,
+            )
+        else:
+            self.q_a_proj = get_linear(
+                weight=weights.get_weights(f"{prefix}.q_a_proj"),
+                bias=(
+                    weights.get_tensor(f"{prefix}.q_a_proj.bias")
+                    if config.attention_bias
+                    else None
+                ),
+            )
+            self.q_a_layernorm = FastRMSNorm.load(
+                prefix=f"{prefix}.q_a_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
+            self.q_b_proj = TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.q_b_proj",
+                weights=weights,
+                bias=config.attention_bias,
+            )
+
+        self.kv_a_proj_with_mqa = get_linear(
+            weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
+            bias=(
+                weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
+                if config.attention_bias
+                else None
+            ),
+        )
+
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
+
+        self.kv_a_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
+        )
+
+        self.kv_b_proj = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.kv_b_proj",
+            weights=weights,
+            bias=config.attention_bias,
+        )
+
+        self.o_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.o_proj",
+            weights=weights,
+            bias=False,
+        )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        cu_seqlen_prefill: torch.Tensor,
+        kv_cache: KVCache,
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        max_s: int,
+    ):
+        if self.q_lora_rank is None:
+            query = self.q_proj(hidden_states)
+        else:
+            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
+        query = query.view(-1, self.num_heads, self.head_size)
+
+        _, query_pe = torch.split(
+            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, key_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+
+        key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
+        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
+            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
+        )
+
+        key_nope, value = torch.split(
+            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+        )
+
+        batch_size, heads, head_dim = query_pe.shape
+        query_pe = (
+            query_pe.view(batch_size, heads, head_dim // 2, 2)
+            .transpose(2, 3)
+            .reshape(batch_size, heads, head_dim)
+        )
+        batch_size, heads, head_dim = key_pe.shape
+        key_pe = (
+            key_pe.view(batch_size, heads, head_dim // 2, 2)
+            .transpose(2, 3)
+            .reshape(batch_size, heads, head_dim)
+        )
+        self.rotary_emb(query_pe, key_pe, cos, sin)
+
+        query[..., self.qk_nope_head_dim :] = query_pe
+        key = torch.empty_like(query)
+        key[..., : self.qk_nope_head_dim] = key_nope
+        key[..., self.qk_nope_head_dim :] = key_pe
+
+        # We need to pad the heads because Flash Attention does not support
+        # qk and v with different head sizes.
+        query = torch.nn.functional.pad(
+            query, (0, self.head_pad_size - self.head_size), value=0
+        )
+        key = torch.nn.functional.pad(
+            key, (0, self.head_pad_size - self.head_size), value=0
+        )
+        value = torch.nn.functional.pad(
+            value, (0, self.head_pad_size - self.value_head_size), value=0
+        )
+
+        kv_cache.store(
+            key=key,
+            value=value,
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )
+
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            # flash attention
+            attn_output = attention(
+                query=query,
+                key=key,
+                value=value,
+                kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
+            )
+        # Decode
+        else:
+            attn_output = paged_attention(
+                query,
+                kv_cache,
+                self.kv_head_mapping,
+                self.softmax_scale,
+                block_tables,
+                seqlen,
+                max_s,
+                kv_scales=self.kv_scales,
+            )
+
+        # Remove padding.
+        attn_output = attn_output[..., : self.value_head_size]
+
+        return self.o_proj(
+            attn_output.reshape(-1, self.num_heads * self.value_head_size)
+        )
+
+
+class DeepseekV3MLP(nn.Module):
+    def __init__(self, prefix: str, config, weights, intermediate_size: int):
+        super().__init__()
+        self.hidden_act = config.hidden_act
+        if self.hidden_act != "silu":
+            # Bail out because MoE only supports silu.
+            raise NotImplementedError(
+                "Currently only `silu` is supported as an activation for Deepseek V3."
+            )
+        self.act = ACT2FN[self.hidden_act]
+
+        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            weights=weights,
+            dim=0,
+            bias=False,
+        )
+
+        self.down_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.down_proj",
+            weights=weights,
+            bias=False,
+        )
+
+        self.intermediate_size = intermediate_size // weights.process_group.size()
+
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
+    def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
+        if (
+            SYSTEM == "rocm"
+            and self.hidden_act == "silu"
+            and hidden_states.dtype == torch.float16
+            and hidden_states.shape[0] == 1
+            and not self.quantize
+        ):
+            out = torch.empty(
+                hidden_states.shape[0],
+                self.intermediate_size,
+                dtype=hidden_states.dtype,
+                device="cuda",
+            )
+            ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
+            return self.down_proj(out, reduce=reduce)
+        else:
+            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
+            )
+
+
+class DeepseekV3MoE(nn.Module):
+    def __init__(
+        self,
+        prefix,
+        config: DeepseekV3Config,
+        moe_layer_cls: Type[MoELayer],
+        weights,
+    ):
+        super().__init__()
+
+        self.hidden_dim = config.hidden_size
+        self.moe_intermediate_size = (
+            config.moe_intermediate_size // weights.process_group.size()
+        )
+        self.routed_scaling_factor = config.routed_scaling_factor
+
+        # Gating
+        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
+
+        if config.topk_method == "noaux_tc":
+            self.gate.e_score_correction_bias = torch.zeros(
+                config.n_routed_experts, device=weights.device
+            )
+        else:
+            self.gate.e_score_correction_bias = None
+
+        self.moe_layer = moe_layer_cls(
+            prefix=f"{prefix}.experts",
+            n_experts=config.n_routed_experts,
+            n_expert_group=config.n_group,
+            renormalize=config.norm_topk_prob,
+            topk=config.num_experts_per_tok,
+            topk_group=config.topk_group,
+            weights=weights,
+            scoring_func=config.scoring_func,
+            e_score_correction_bias=self.gate.e_score_correction_bias,
+        )
+        assert isinstance(self.moe_layer, MoELayer)
+
+        if config.n_shared_experts is not None:
+            self.shared_experts = DeepseekV3MLP(
+                prefix=f"{prefix}.shared_experts",
+                config=config,
+                weights=weights,
+                intermediate_size=config.moe_intermediate_size
+                * config.n_shared_experts,
+            )
+        else:
+            self.shared_experts = None
+
+        self.process_group = weights.process_group
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.shared_experts is not None:
+            shared_output = self.shared_experts(x, reduce=False)
+        else:
+            shared_output = None
+
+        router_logits = self.gate(x)
+
+        out = self.moe_layer(x, gating_output=router_logits)
+
+        if shared_output is not None:
+            out = out + shared_output
+
+        # Reduce sum
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out.view(*x.shape)
+
+
+class DeepseekV3Layer(nn.Module):
+    def __init__(self, prefix, layer_id, config, weights):
+        super().__init__()
+        prefix = f"{prefix}.layers.{layer_id}"
+
+        self.self_attn = DeepseekV3Attention(
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            weights=weights,
+        )
+
+        if (
+            config.n_routed_experts is not None
+            and layer_id >= config.first_k_dense_replace
+            and layer_id % config.moe_layer_freq == 0
+        ):
+            moe_layer_cls = (
+                SparseMoELayer
+                if SparseMoELayer.is_supported(weights)
+                else DenseMoELayer
+            )
+            self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
+        else:
+            self.mlp = DeepseekV3MLP(
+                prefix=f"{prefix}.mlp",
+                config=config,
+                weights=weights,
+                intermediate_size=config.intermediate_size,
+            )
+
+        self.input_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.post_attention_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        cu_seqlen_prefill: torch.Tensor,
+        kv_cache,
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        max_s: int,
+    ):
+        normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Self Attention
+        attn_output = self.self_attn(
+            normed_hidden_states,
+            cos,
+            sin,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            seqlen,
+            max_s,
+        )
+
+        # faster post attention rms norm
+        normed_attn_res_output, residual = self.post_attention_layernorm(
+            attn_output, residual
+        )
+
+        output = self.mlp(normed_attn_res_output)
+
+        return output, residual
+
+
+class DeepseekV3Model(torch.nn.Module):
+    def __init__(self, prefix: str, config, weights: Weights):
+        super().__init__()
+
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=f"{prefix}.embed_tokens", weights=weights
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                DeepseekV3Layer(
+                    prefix,
+                    layer_id,
+                    config,
+                    weights,
+                )
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = FastRMSNorm.load(
+            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
+        )
+
+        self.head_size = self.layers[0].self_attn.head_size
+        self.num_heads = self.layers[0].self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        max_s: int,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+
+        # Get rotary cos and sin for this forward
+        # Avoid to index in each layer
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+            position_ids, max_s, hidden_states.dtype
+        )
+
+        residual = None
+        for i, layer in enumerate(self.layers):
+            hidden_states, residual = layer(
+                hidden_states,
+                residual,
+                cos,
+                sin,
+                cu_seqlen_prefill,
+                kv_cache[i],
+                block_tables,
+                slots,
+                seqlen,
+                max_s,
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        return hidden_states
+
+
+class FlashDeepseekV3ForCausalLM(torch.nn.Module):
+    def __init__(self, prefix: str, config, weights: Weights):
+        super().__init__()
+
+        self.model = DeepseekV3Model(
+            "model" if not prefix else f"{prefix}.model", config, weights
+        )
+        self.lm_head = SpeculativeHead.load(
+            config,
+            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+            weights=weights,
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+        lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        hidden_states = self.model(
+            input_ids,
+            position_ids,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            seqlen,
+            max_s,
+        )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
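One subtlety in the attention module above is the `view(..., d // 2, 2).transpose(2, 3).reshape(...)` step applied to the rotary slices before calling `rotary_emb`; a small sketch of the reordering it performs (illustrative values only):

    import torch

    # One head of width 6, stored interleaved as (x0, y0, x1, y1, x2, y2).
    pe = torch.tensor([[[0.0, 10.0, 1.0, 11.0, 2.0, 12.0]]])
    b, h, d = pe.shape
    half_split = pe.view(b, h, d // 2, 2).transpose(2, 3).reshape(b, h, d)
    print(half_split)  # tensor([[[ 0.,  1.,  2., 10., 11., 12.]]])

i.e. the interleaved pair layout is regrouped into a half-split layout, presumably to match what the rotary kernel expects.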
@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, List

 from huggingface_hub import hf_hub_download
 from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
@@ -20,6 +20,7 @@ class _QuantizerConfig:
    groupsize: int
    quant_method: str
    sym: bool
+    weight_block_size: Optional[List[int]]


 @dataclass
@@ -49,11 +50,11 @@ def _get_quantizer_config(model_id, revision):
    checkpoint_format = None
    sym = False
    desc_act = False
+    weight_block_size = None

    filename = "config.json"
    try:
        data = _get_config_json(model_id, revision, filename)
-
        # FP8 config
        if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
            return _FP8QuantizerConfig(
@@ -66,12 +67,16 @@ def _get_quantizer_config(model_id, revision):
        elif "sym" in data["quantization_config"]:
            sym = data["quantization_config"]["sym"]

-        bits = data["quantization_config"]["bits"]
-        groupsize = data["quantization_config"]["group_size"]
+        if "bits" in data["quantization_config"]:
+            bits = data["quantization_config"]["bits"]
+        if "group_size" in data["quantization_config"]:
+            groupsize = data["quantization_config"]["group_size"]
        # Order is important here, desc_act is missing on some real models
        quant_method = data["quantization_config"]["quant_method"]
-        checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        desc_act = data["quantization_config"]["desc_act"]
+        checkpoint_format = data["quantization_config"].get("checkpoint_format", None)
+        if "desc_act" in data["quantization_config"]:
+            desc_act = data["quantization_config"]["desc_act"]
+        weight_block_size = data["quantization_config"].get("weight_block_size", None)
    except Exception:
        filename = "quantize_config.json"
        try:
@@ -107,6 +112,7 @@ def _get_quantizer_config(model_id, revision):
        checkpoint_format=checkpoint_format,
        sym=sym,
        desc_act=desc_act,
+        weight_block_size=weight_block_size,
    )


@@ -196,9 +202,14 @@ def get_loader(
        # Since the default for the quantize config is _QuantizerConfig,
        # we need to add this check to not get an attribute error
        activation_scale_ub = None
+        weight_block_size = quantizer_config.weight_block_size
        if isinstance(quantizer_config, _FP8QuantizerConfig):
            activation_scale_ub = quantizer_config.activation_scale_ub

-        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
+        return HybridFP8UnquantLoader(
+            activation_scale_ub,
+            to_fp8=quantize == "fp8",
+            weight_block_size=weight_block_size,
+        )
    else:
        raise ValueError(f"Unknown quantization method: {quantize}")
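For context, a block-quantized FP8 checkpoint advertises its block shape in `config.json`; a hypothetical `quantization_config` of the kind this parser now picks up (field values are illustrative, modeled on DeepSeek-V3's published config):

    quantization_config = {
        "quant_method": "fp8",
        "fmt": "e4m3",
        "activation_scheme": "dynamic",
        "weight_block_size": [128, 128],  # [block_n, block_k]
    }

    # Absent keys leave weight_block_size as None, so scalar-scale FP8 and
    # GPTQ-style configs keep taking the existing code paths.
    weight_block_size = quantization_config.get("weight_block_size", None)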
@@ -149,6 +149,10 @@ class Weights:
    ):
        routing = {}
        for filename in filenames:
+            # if filename.as_posix().endswith("l.safetensors"):
+            #     from loguru import logger
+            #     logger.info(f"Skipping {filename}")
+            #     continue
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing: