From 57de05b0dd1d6de67fcf8af457b5a8cb64396fe4 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 28 Jan 2025 17:05:53 +0000 Subject: [PATCH] add deepseekv3 --- Dockerfile_amd | 4 +- server/text_generation_server/layers/fp8.py | 70 +- .../layers/moe/__init__.py | 4 + .../text_generation_server/layers/moe/fp8.py | 9 +- .../layers/moe/unquantized.py | 10 +- .../text_generation_server/models/__init__.py | 43 ++ .../flash_deepseek_v3_modeling.py | 676 ++++++++++++++++++ .../utils/quantization.py | 25 +- .../text_generation_server/utils/weights.py | 4 + 9 files changed, 833 insertions(+), 12 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py diff --git a/Dockerfile_amd b/Dockerfile_amd index 293cac2d..10001012 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -279,9 +279,9 @@ RUN git clone https://github.com/danieldk/marlin-kernels.git && \ FROM kernel-builder AS moe-kernels WORKDIR /usr/src -ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd +ENV MOE_KERNELS_BRANCH=0c8f8ed941635025e277d6dc6f7324827855cc40 ENV VLLM_TARGET_DEVICE=rocm -RUN git clone https://github.com/danieldk/moe-kernels.git && \ +RUN git clone https://github.com/mht-sharma/moe-kernels.git && \ cd moe-kernels && \ git checkout ${MOE_KERNELS_BRANCH} && \ python setup.py install diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index c4df0213..8328ad30 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import os from typing import Optional, Tuple, Type, Union, List +from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8 import torch from loguru import logger @@ -170,14 +171,31 @@ def fp8_quantize( class HybridFP8UnquantLoader(WeightsLoader): """Weight loader that loads FP8 and unquantized Torch tensors.""" - def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): + def __init__( + self, + activation_scale_ub: Optional[float], + to_fp8: bool, + weight_block_size: Optional[List[int]] = None, + ): self.activation_scale_ub = activation_scale_ub self.to_fp8 = to_fp8 + self.weight_block_size = weight_block_size def get_weights(self, weights: "Weights", prefix: str): w = weights.get_tensor(f"{prefix}.weight") if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = weights.get_tensor(f"{prefix}.weight_scale_inv") + if scale.device == torch.device("cpu"): + scale = scale.to(weights.device) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) # FP8 branch scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) @@ -266,6 +284,22 @@ class HybridFP8UnquantLoader(WeightsLoader): # FP8 branch if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = [ + weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=dim) + if scale.device == torch.device("cpu"): + scale = scale.to(weights.device) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) + scale = [ _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) for p, shape in zip(prefixes, shapes) @@ -311,6 +345,19 @@ class 
HybridFP8UnquantLoader(WeightsLoader): w = weights.get_sharded(f"{prefix}.weight", dim=1) # FP8 branch if w.dtype == torch.float8_e4m3fn: + if self.weight_block_size is not None: + scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1) + if scale.device == torch.device("cpu"): + scale = scale.to(weights.device) + + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + weight_block_size=self.weight_block_size, + ) + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) if SYSTEM == "cuda": @@ -345,6 +392,7 @@ class Fp8Weight(Weight): input_scale: Optional[torch.Tensor] = None activation_scale_ub: Optional[float] = None force_w8a16: bool = False + weight_block_size: Optional[List[int]] = None def get_linear(self, bias: torch.Tensor): if self.weight_scale is None: @@ -361,6 +409,7 @@ class Fp8Weight(Weight): bias=bias, input_scale=self.input_scale, scale_upper_bound=self.activation_scale_ub, + weight_block_size=self.weight_block_size, ) @@ -375,6 +424,7 @@ class Fp8Linear(torch.nn.Module): bias: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None, scale_upper_bound: Optional[float] = None, + weight_block_size: Optional[List[int]] = None, ) -> None: super().__init__() if CUTLASS_FP8_AVAILABLE: @@ -388,6 +438,7 @@ class Fp8Linear(torch.nn.Module): self.qweight = qweight self.scale = scale.float() self.input_scale = input_scale.float() if input_scale is not None else None + self.weight_block_size = weight_block_size if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None: self.scale_upper_bound = torch.tensor( @@ -421,6 +472,7 @@ class Fp8Linear(torch.nn.Module): ) -> "Fp8Linear": input_scale = kwargs.get("input_scale", None) scale_upper_bound = kwargs.get("scale_upper_bound", None) + weight_block_size = kwargs.get("weight_block_size", None) return cls( qweight=weight, @@ -429,6 +481,7 @@ class Fp8Linear(torch.nn.Module): scale_upper_bound=scale_upper_bound, bias=bias, dtype=dtype, + weight_block_size=weight_block_size, ) @classmethod @@ -440,6 +493,21 @@ class Fp8Linear(torch.nn.Module): return cls._device_identity_cache[device] def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.weight_block_size is not None: + qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1]) + # logger.info(f"qinput: {qinput.shape} {scale.shape} {self.qweight.shape} {self.scale.shape} {self.weight_block_size}") + output = w8a8_block_fp8_matmul( + qinput, + self.qweight, + scale, + self.scale, + self.weight_block_size, + output_dtype=input.dtype, + ) + + if self.bias is not None: + output = output + self.bias + return output.to(dtype=input.dtype) if CUTLASS_FP8_AVAILABLE: # cutlass FP8 supports per-token scales, so get non-scalar scales. 
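The block-quantized branches added above pick up a per-block `weight_scale_inv` tensor next to each FP8 weight, and `forward` routes the matmul through `per_token_group_quant_fp8` and `w8a8_block_fp8_matmul` from moe-kernels. As a rough reference for what the weight side of that fused path computes (a sketch with assumed shapes and a hypothetical helper name, not the moe-kernels API):

import torch

def block_fp8_reference_matmul(
    x: torch.Tensor,                 # [M, K] activations (bf16/fp16)
    w_fp8: torch.Tensor,             # [N, K] weight stored as float8_e4m3fn
    weight_scale_inv: torch.Tensor,  # [ceil(N/bn), ceil(K/bk)] per-block scales
    block_size=(128, 128),           # DeepSeek-V3 checkpoints use 128x128 blocks
) -> torch.Tensor:
    bn, bk = block_size
    n, k = w_fp8.shape
    # Broadcast each block scale over its bn x bk tile and dequantize the weight.
    # Despite the `_inv` suffix, dequantization multiplies by this tensor in the
    # checkpoints this path targets (an assumption of this sketch).
    scales = (
        weight_scale_inv.to(torch.float32)
        .repeat_interleave(bn, dim=0)[:n]
        .repeat_interleave(bk, dim=1)[:, :k]
    )
    w = w_fp8.to(torch.float32) * scales
    return (x.to(torch.float32) @ w.T).to(x.dtype)

The real kernel never materializes the dequantized weight, and it additionally quantizes the activations to FP8 in token groups of `block_size[1]` columns, which is what the `per_token_group_quant_fp8` call above supplies.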
qinput, scale = fp8_quantize( diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py index ae9ca6fc..c0d0b38d 100644 --- a/server/text_generation_server/layers/moe/__init__.py +++ b/server/text_generation_server/layers/moe/__init__.py @@ -214,6 +214,8 @@ class SparseMoELayer(nn.Module): topk: int, topk_group: Optional[int], weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, gate_proj_name: str = "gate_proj", up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", @@ -256,6 +258,8 @@ class SparseMoELayer(nn.Module): topk=topk, topk_group=topk_group, weights=weights, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, gate_proj_name=gate_proj_name, up_proj_name=up_proj_name, down_proj_name=down_proj_name, diff --git a/server/text_generation_server/layers/moe/fp8.py b/server/text_generation_server/layers/moe/fp8.py index 4d516fd4..ec438f73 100644 --- a/server/text_generation_server/layers/moe/fp8.py +++ b/server/text_generation_server/layers/moe/fp8.py @@ -24,6 +24,8 @@ class FP8SparseMoELayer(nn.Module): topk: int, topk_group: Optional[int], weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, gate_proj_name: str = "gate_proj", up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", @@ -38,6 +40,9 @@ class FP8SparseMoELayer(nn.Module): self.topk = topk self.topk_group = topk_group self.renormalize = renormalize + self.weight_block_size = weights.weights_loader.weight_block_size + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias ( self.gate_up_proj, @@ -72,6 +77,8 @@ class FP8SparseMoELayer(nn.Module): use_grouped_topk=self.n_expert_group is not None, num_expert_group=self.n_expert_group, topk_group=self.topk_group, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, use_fp8_w8a8=True, w1_scale=self.gate_up_proj_weight_scale, w2_scale=self.down_proj_weight_scale, @@ -105,7 +112,7 @@ def _load_expert_weights( ) if all_weight_scales is None: all_weight_scales = torch.empty( - (n_experts,), + (n_experts,) + weight.weight_scale.shape, dtype=torch.float32, device=weight.weight.device, ) diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py index 32326653..a104ca5d 100644 --- a/server/text_generation_server/layers/moe/unquantized.py +++ b/server/text_generation_server/layers/moe/unquantized.py @@ -23,6 +23,8 @@ class UnquantizedSparseMoELayer(nn.Module): topk: int, topk_group: Optional[int], weights: Weights, + scoring_func: Optional[str] = "softmax", + e_score_correction_bias: Optional[float] = None, gate_proj_name: str = "gate_proj", up_proj_name: str = "up_proj", down_proj_name: str = "down_proj", @@ -37,6 +39,9 @@ class UnquantizedSparseMoELayer(nn.Module): self.topk = topk self.topk_group = topk_group self.renormalize = renormalize + self.weight_block_size = weights.weights_loader.weight_block_size + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias self.gate_up_proj = _load_expert_multi_weights_col( prefix=prefix, @@ -68,7 +73,8 @@ class UnquantizedSparseMoELayer(nn.Module): num_expert_group=self.n_expert_group, topk_group=self.topk_group, ) - + # from loguru import logger + # logger.info("Fused MoE is used here") return fused_moe( x, w1=self.gate_up_proj, @@ -80,6 +86,8 @@ 
class UnquantizedSparseMoELayer(nn.Module): use_grouped_topk=self.n_expert_group is not None, num_expert_group=self.n_expert_group, topk_group=self.topk_group, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, ) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index e2d24643..613932a5 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -87,6 +87,10 @@ try: FlashDeepseekV2ForCausalLM, DeepseekV2Config, ) + from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import ( + FlashDeepseekV3ForCausalLM, + DeepseekV3Config, + ) from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, ) @@ -185,6 +189,11 @@ class ModelType(enum.Enum): "name": "Deepseek V2", "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2", } + DEEPSEEK_V3 = { + "type": "deepseek_v3", + "name": "Deepseek V3", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V3", + } IDEFICS2 = { "type": "idefics2", "name": "Idefics 2", @@ -632,6 +641,40 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif model_type == DEEPSEEK_V3: + if FLASH_ATTENTION: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV3ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV3Config, + head_size=head_size, + ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V3") + ) + else: + return CausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) elif model_type == MAMBA: return Mamba( model_id, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py new file mode 100644 index 00000000..25fb3cc2 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -0,0 +1,676 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
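The `scoring_func` and `e_score_correction_bias` parameters threaded through `SparseMoELayer`, `FP8SparseMoELayer`, and `UnquantizedSparseMoELayer` above are forwarded untouched to `fused_moe`, which implements DeepSeek-V3's `noaux_tc` expert selection. The sketch below is only a rough, unfused illustration of the rule those parameters control (hypothetical function name, example defaults, sigmoid scoring as the V3 checkpoint configures it); the optimized version lives in moe-kernels.

import torch

def grouped_topk_with_bias(
    router_logits: torch.Tensor,    # [tokens, n_experts]
    correction_bias: torch.Tensor,  # [n_experts], the gate's e_score_correction_bias
    n_group: int = 8,
    topk_group: int = 4,
    topk: int = 8,
    renormalize: bool = True,
):
    scores = router_logits.sigmoid()       # scoring_func="sigmoid"
    biased = scores + correction_bias      # the bias only steers selection
    t, e = biased.shape

    # Rank expert groups by the sum of their two largest biased scores and keep
    # only the experts that live in the best `topk_group` groups.
    grouped = biased.view(t, n_group, e // n_group)
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)   # [tokens, n_group]
    keep = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter(1, keep, 1.0).bool()
    masked = grouped.masked_fill(~group_mask.unsqueeze(-1), float("-inf")).reshape(t, e)

    # Final top-k over the surviving experts; routing weights use the unbiased scores.
    topk_ids = masked.topk(topk, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

Note that the correction bias only affects which experts are chosen; the weights that scale each expert's output come from the unbiased scores, which is why the bias is passed to the kernel separately rather than folded into the router logits.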
+ +from typing import List, Optional, Tuple, Type + +import torch +import torch.distributed +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +from text_generation_server.layers import ( + FastLinear, + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import ( + Seqlen, + attention, + paged_attention, +) +from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer +from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import Weights + +if SYSTEM == "rocm": + try: + import vllm._custom_ops as ops + except Exception as e: + raise ImportError(f"Could not load `vllm._custom_ops`. Full error: {e}") + + +class DeepseekV3Config(PretrainedConfig): + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=2, + n_routed_experts=160, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=8, + topk_group=3, + num_experts_per_tok=6, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + 
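As a worked example (an editorial aside, not part of the file): the head-dimension defaults above match the published DeepSeek-V3 checkpoint, and the attention head sizes used throughout this model resolve to:

qk_nope_head_dim = 128      # per-head "content" dimension of Q/K
qk_rope_head_dim = 64       # per-head rotary dimension of Q/K
v_head_dim = 128            # per-head dimension of V

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim   # 192, the Q/K head size
head_size = max(qk_head_dim, v_head_dim)            # 192, mirrored in models/__init__.py
# The attention module below pads V heads (128) up to this size so that flash
# attention sees a single uniform head size, then slices the padding off the output.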
self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Deepseek V2 models." + ) + + if ep_size != 1: + raise ValueError( + f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}" + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class DeepseekV3Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights: Weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.kv_lora_rank = config.kv_lora_rank + self.q_lora_rank = config.q_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim + self.value_head_size = config.v_head_dim + self.head_pad_size = max(self.head_size, self.value_head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) + + mscale = get_mscale( + self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim + ) + self.softmax_scale = self.head_size**-0.5 * mscale * mscale + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + if self.q_lora_rank is None: + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=config.attention_bias, + ) + else: + self.q_a_proj = get_linear( + weight=weights.get_weights(f"{prefix}.q_a_proj"), + bias=( + weights.get_tensor(f"{prefix}.q_a_proj.bias") + if config.attention_bias + else None + ), + ) + self.q_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.q_a_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.q_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.kv_a_proj_with_mqa = get_linear( + weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"), + bias=( + weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias") + if config.attention_bias + else None + ), + ) + + self.kv_scales = get_kv_scales(weights, f"{prefix}") + + self.kv_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps + ) + + self.kv_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.kv_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache: KVCache, + block_tables: 
torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ): + if self.q_lora_rank is None: + query = self.q_proj(hidden_states) + else: + query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) + query = query.view(-1, self.num_heads, self.head_size) + + _, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, key_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( + -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + + batch_size, heads, head_dim = query_pe.shape + query_pe = ( + query_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + batch_size, heads, head_dim = key_pe.shape + key_pe = ( + key_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + self.rotary_emb(query_pe, key_pe, cos, sin) + + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. + query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query=query, + key=key, + value=value, + kv_cache=kv_cache, + kv_scales=self.kv_scales, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache, + self.kv_head_mapping, + self.softmax_scale, + block_tables, + seqlen, + max_s, + kv_scales=self.kv_scales, + ) + + # Remove padding. + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) + + +class DeepseekV3MLP(nn.Module): + def __init__(self, prefix: str, config, weights, intermediate_size: int): + super().__init__() + self.hidden_act = config.hidden_act + if self.hidden_act != "silu": + # Bail out because MoE only supports silu. + raise NotImplementedError( + "Currently only `silu` is supported as an activation for Deepseek V2." + ) + self.act = ACT2FN[self.hidden_act] + + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + # TODO: This is a hotfix to be removed & properly refactored. 
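In the attention forward above, `query_pe` and `key_pe` are reshuffled with a view/transpose/reshape before the rotary embedding is applied: interleaved pairs are regrouped into the half-split layout that the rotary kernel here operates on. On a toy tensor the permutation looks like this:

import torch

pe = torch.arange(8).view(1, 1, 8)   # one token, one head, rope dim 8
b, h, d = pe.shape
out = pe.view(b, h, d // 2, 2).transpose(2, 3).reshape(b, h, d)
print(out)   # tensor([[[0, 2, 4, 6, 1, 3, 5, 7]]])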
+ self.quantize = config.quantize + + def forward(self, hidden_states: torch.Tensor, reduce: bool = True): + if ( + SYSTEM == "rocm" + and self.hidden_act == "silu" + and hidden_states.dtype == torch.float16 + and hidden_states.shape[0] == 1 + and not self.quantize + ): + out = torch.empty( + hidden_states.shape[0], + self.intermediate_size, + dtype=hidden_states.dtype, + device="cuda", + ) + ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + return self.down_proj(out, reduce=reduce) + else: + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) + + +class DeepseekV3MoE(nn.Module): + def __init__( + self, + prefix, + config: DeepseekV3Config, + moe_layer_cls: Type[MoELayer], + weights, + ): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = ( + config.moe_intermediate_size // weights.process_group.size() + ) + self.routed_scaling_factor = config.routed_scaling_factor + + # Gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = torch.zeros( + config.n_routed_experts, device=weights.device + ) + else: + self.gate.e_score_correction_bias = None + + self.moe_layer = moe_layer_cls( + prefix=f"{prefix}.experts", + n_experts=config.n_routed_experts, + n_expert_group=config.n_group, + renormalize=config.norm_topk_prob, + topk=config.num_experts_per_tok, + topk_group=config.topk_group, + weights=weights, + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + ) + assert isinstance(self.moe_layer, MoELayer) + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV3MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + router_logits = self.gate(x) + + out = self.moe_layer(x, gating_output=router_logits) + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class DeepseekV3Layer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + + self.self_attn = DeepseekV3Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + moe_layer_cls = ( + SparseMoELayer + if SparseMoELayer.is_supported(weights) + else DenseMoELayer + ) + self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights) + else: + self.mlp = DeepseekV3MLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + intermediate_size=config.intermediate_size, + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + 
prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache, + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ): + normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, residual = self.post_attention_layernorm( + attn_output, residual + ) + + output = self.mlp(normed_attn_res_output) + + return output, residual + + +class DeepseekV3Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + + self.layers = nn.ModuleList( + [ + DeepseekV3Layer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + seqlen, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashDeepseekV3ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.model = DeepseekV3Model( + "model" if not prefix else f"{prefix}.model", config, weights + ) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head" if not prefix else f"{prefix}.lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + seqlen: Seqlen, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/server/text_generation_server/utils/quantization.py 
b/server/text_generation_server/utils/quantization.py
index 0d894939..764ffab4 100644
--- a/server/text_generation_server/utils/quantization.py
+++ b/server/text_generation_server/utils/quantization.py
@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, List

 from huggingface_hub import hf_hub_download
 from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
@@ -20,6 +20,7 @@ class _QuantizerConfig:
     groupsize: int
     quant_method: str
     sym: bool
+    weight_block_size: Optional[List[int]]


 @dataclass
@@ -49,11 +50,11 @@ def _get_quantizer_config(model_id, revision):
     checkpoint_format = None
     sym = False
     desc_act = False
+    weight_block_size = None

     filename = "config.json"
     try:
         data = _get_config_json(model_id, revision, filename)
-
         # FP8 config
         if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
             return _FP8QuantizerConfig(
@@ -66,12 +67,16 @@ def _get_quantizer_config(model_id, revision):
         elif "sym" in data["quantization_config"]:
             sym = data["quantization_config"]["sym"]

-        bits = data["quantization_config"]["bits"]
-        groupsize = data["quantization_config"]["group_size"]
+        if "bits" in data["quantization_config"]:
+            bits = data["quantization_config"]["bits"]
+        if "group_size" in data["quantization_config"]:
+            groupsize = data["quantization_config"]["group_size"]
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
-        checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        desc_act = data["quantization_config"]["desc_act"]
+        checkpoint_format = data["quantization_config"].get("checkpoint_format", None)
+        if "desc_act" in data["quantization_config"]:
+            desc_act = data["quantization_config"]["desc_act"]
+        weight_block_size = data["quantization_config"].get("weight_block_size", None)
     except Exception:
         filename = "quantize_config.json"
         try:
@@ -107,6 +112,7 @@ def _get_quantizer_config(model_id, revision):
         checkpoint_format=checkpoint_format,
         sym=sym,
         desc_act=desc_act,
+        weight_block_size=weight_block_size,
     )


@@ -196,9 +202,14 @@ def get_loader(
         # Since the default for the quantize config is _QuantizerConfig,
         # we need to add this check to not get an attribute error
         activation_scale_ub = None
+        weight_block_size = quantizer_config.weight_block_size
         if isinstance(quantizer_config, _FP8QuantizerConfig):
             activation_scale_ub = quantizer_config.activation_scale_ub

-        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
+        return HybridFP8UnquantLoader(
+            activation_scale_ub,
+            to_fp8=quantize == "fp8",
+            weight_block_size=weight_block_size,
+        )
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index c03dd2b0..e1c9ab33 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -149,6 +149,10 @@ class Weights:
     ):
         routing = {}
         for filename in filenames:
+            # if filename.as_posix().endswith("l.safetensors"):
+            #     from loguru import logger
+            #     logger.info(f"Skipping {filename}")
+            #     continue
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
                     if k in routing:
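For context, the `weight_block_size` parsed above comes straight from the checkpoint's `quantization_config`; a block-quantized FP8 checkpoint in the DeepSeek-V3 style carries an entry roughly like the one sketched below (illustrative values). Such configs define no `bits` or `group_size` keys, which is why those lookups were made conditional.

# Illustrative excerpt of what config.json provides and what _get_quantizer_config
# extracts from it; the exact field set varies by checkpoint.
quantization_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128],  # block sizes along the output and input dims
}

weight_block_size = quantization_config.get("weight_block_size", None)  # -> [128, 128]
# HybridFP8UnquantLoader then loads a `weight_scale_inv` tensor of shape roughly
# [ceil(out_features / 128), ceil(in_features / 128)] alongside each FP8 weight.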