Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)

enable deepseek_r1

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

parent 329f612e55
commit debf477ba4
@@ -12,11 +12,151 @@ from text_generation_server.utils.weights import (
 
 from vllm_hpu_extension.ops import scaled_fp8_quant
 from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
-import habana_frameworks.torch.utils.experimental as htexp
 
-w8a8_block_fp8_matmul = None
-per_token_group_quant_fp8 = None
 quant_dtype: torch.dtype = torch.float8_e4m3fn
+FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
+if is_hpu_gaudi2():
+    FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
+
+
+def pad_weight(weight, block_size):
+    """Pads a matrix to make its dimensions multiples of block_size."""
+    M, N = weight.shape[-2:]
+    block_size_m, block_size_n = block_size
+    pad_M = (block_size_m - M % block_size_m) % block_size_m
+    pad_N = (block_size_n - N % block_size_n) % block_size_n
+
+    if pad_M == 0 and pad_N == 0:
+        return weight, M, N  # No padding needed
+    padded_weight = torch.nn.functional.pad(
+        weight, (0, pad_N, 0, pad_M), mode="constant", value=0
+    )
+    return padded_weight, M, N  # Return original dimensions for unpadding
+
+
+def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
+    """Removes padding from the matrix to restore its original shape."""
+    if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N):
+        return weight
+    if keep_first_dim:
+        return weight[:, :original_M, :original_N]
+    else:
+        return weight[:original_M, :original_N]
+
+
+def pad_block_fp8_weight_naive(weight, weight_scale, block_size):
+
+    assert len(block_size) == 2
+
+    block_size_m, block_size_n = block_size
+    weight_scale_m, weight_scale_n = weight_scale.shape[-2:]
+
+    weight, orig_M, orig_N = pad_weight(weight, block_size)
+    M, N = weight.shape[-2:]
+
+    assert weight_scale_m == M // block_size_m
+    assert weight_scale_n == N // block_size_n
+
+    return weight, orig_M, orig_N
+
+
+def dynamic_quant(data, single_scale=False):
+    if single_scale:
+        scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX
+    else:
+        scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX
+        scale = scale.unsqueeze(-1)
+    data_fp8 = torch.ops.hpu.cast_to_fp8_v2(
+        data, 1.0 / scale, False, False, torch.float8_e4m3fn
+    )[0]
+    return data_fp8, scale.float()
+
+
+def dequant_block_fp8_weight_naive(
+    weight,
+    weight_scale,
+    block_size,
+    dtype=torch.bfloat16,
+    original_M=None,
+    original_N=None,
+    do_unpad=False,
+):
+    if weight_scale is None:
+        return weight
+    assert len(block_size) == 2
+
+    weight_shape_len = len(weight.shape)
+
+    block_size_m, block_size_n = block_size
+
+    # mul scale
+    if weight_shape_len == 2:
+        weight_scale_m, weight_scale_n = weight_scale.shape
+        weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1)
+        weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n)
+        if is_hpu_gaudi2():
+            fake_weight = weight.cpu().to(dtype).to(weight.device)
+            dequant_weight = fake_weight * weight_scale.to(dtype)
+        else:
+            dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
+        dequant_weight = dequant_weight.view(
+            weight_scale_m * block_size_m, weight_scale_n * block_size_n
+        )
+        keep_first_dim = False
+    elif weight_shape_len == 3:
+        fd, weight_scale_m, weight_scale_n = weight_scale.shape
+        weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1)
+        weight = weight.view(
+            fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n
+        )
+        if is_hpu_gaudi2():
+            fake_weight = weight.cpu().to(dtype).to(weight.device)
+            dequant_weight = fake_weight * weight_scale.to(dtype)
+        else:
+            dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
+        dequant_weight = dequant_weight.view(
+            fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n
+        )
+        keep_first_dim = True
+    else:
+        raise ValueError("Only support original weight shape is either 2 or 3")
+
+    if do_unpad:
+        dequant_weight = unpad_weight(
+            dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim
+        )
+
+    return dequant_weight
+
+
+def apply_block_fp8_linear_hpu_dynamic(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    x_fp8, x_scale = dynamic_quant(input_2d)
+
+    output = torch.ops.hpu.fp8_gemm_v2(
+        x_fp8,
+        False,
+        weight,
+        True,
+        None,
+        torch.bfloat16,
+        x_scale,
+        weight_scale,
+        None,
+        False,
+    )
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input.dtype).view(*output_shape)
 
 
 def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
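
For orientation, here is a minimal CPU-only sketch of the block-scale layout these helpers assume (128x128 blocks, plain float32 math; the HPU-specific torch.ops.hpu casts and GEMM are deliberately left out, and the shapes below are made up):

    import torch

    block_size = (128, 128)
    M, N = 300, 500                      # deliberately not multiples of 128
    weight = torch.randn(M, N)

    bm, bn = block_size
    pad_m = (bm - M % bm) % bm
    pad_n = (bn - N % bn) % bn
    padded = torch.nn.functional.pad(weight, (0, pad_n, 0, pad_m))  # -> 384 x 512

    # one scale per 128x128 block, broadcast back over the block grid
    scale = torch.rand(padded.shape[0] // bm, padded.shape[1] // bn)  # -> 3 x 4
    blocks = padded.reshape(scale.shape[0], bm, scale.shape[1], bn)
    dequant = blocks * scale.reshape(scale.shape[0], 1, scale.shape[1], 1)
    dequant = dequant.reshape(padded.shape)

    # unpad back to the original 300 x 500, mirroring unpad_weight()
    dequant = dequant[:M, :N]
    print(dequant.shape)  # torch.Size([300, 500])

In the real path, apply_block_fp8_linear_hpu_dynamic then re-quantizes activations per row with dynamic_quant before the fp8 GEMM, so the block scales only survive as per-channel weight scales.
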
@@ -42,7 +182,7 @@ def per_tensor_dequantize(
 ) -> torch.Tensor:
     device = tensor.device
     dtype = torch.bfloat16
-    if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
+    if is_hpu_gaudi2():
         # dequant on cpu to avoid nan on gaudi2
         tensor = tensor.to("cpu")
 
@@ -389,6 +529,22 @@ class Fp8Linear(torch.nn.Module):
         scale_upper_bound = kwargs.get("scale_upper_bound", None)
         weight_block_size = kwargs.get("weight_block_size", None)
 
+        if weight_block_size is not None:
+            weight, orig_M, orig_N = pad_block_fp8_weight_naive(
+                weight, scale, weight_block_size
+            )
+            weight, scale = dynamic_quant(
+                dequant_block_fp8_weight_naive(
+                    weight,
+                    scale,
+                    weight_block_size,
+                    original_M=orig_M,
+                    original_N=orig_N,
+                    do_unpad=True,
+                )
+            )
+            scale = scale.squeeze(-1)
+
         return cls(
             qweight=weight,
             scale=scale,
@@ -409,25 +565,10 @@ class Fp8Linear(torch.nn.Module):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.weight_block_size is not None:
-            # https://arxiv.org/pdf/2412.19437
-            # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
-            # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
-            # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
-            # channels).
-            qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
-            output = w8a8_block_fp8_matmul(
-                qinput,
-                self.qweight,
-                scale,
-                self.scale,
-                self.weight_block_size,
-                output_dtype=input.dtype,
+            return apply_block_fp8_linear_hpu_dynamic(
+                input, self.qweight, self.scale, self.input_scale, self.bias
             )
-
-            if self.bias is not None:
-                output = output + self.bias
-            return output.to(dtype=input.dtype)
 
         qinput, scale = fp8_quantize(
             input,
             self.input_scale,
@@ -4,7 +4,12 @@ from typing import List, Optional, Union
 import torch
 from loguru import logger
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+from text_generation_server.utils.weights import (
+    Weight,
+    Weights,
+    WeightsLoader,
+    DefaultWeightsLoader,
+)
 
 
 from .hpu import QuantLinear
@@ -72,6 +77,7 @@ class GPTQWeightsLoader(WeightsLoader):
         quant_method: str,
         quantize: str,
         sym: bool,
+        modules_to_not_convert: List[str],
     ):
         self.bits = bits
         self.desc_act = desc_act
@@ -79,6 +85,12 @@ class GPTQWeightsLoader(WeightsLoader):
         self.quant_method = quant_method
         self.quantize = quantize
         self.sym = sym
+        self.modules_to_not_convert = modules_to_not_convert
 
+    def is_layer_skipped_quantization(
+        self, prefix: str, modules_to_not_convert: List[str]
+    ):
+        return any(module_name in prefix for module_name in modules_to_not_convert)
+
     def get_weights(self, weights: Weights, prefix: str):
         self._get_gptq_params(weights)
@@ -91,6 +103,9 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights(weights, prefix)
+
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
         except RuntimeError:
@@ -145,6 +160,10 @@ class GPTQWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights_col_packed(
+                weights, prefix, block_sizes
+            )
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -196,6 +215,8 @@ class GPTQWeightsLoader(WeightsLoader):
         )
 
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -263,6 +284,9 @@ class GPTQWeightsLoader(WeightsLoader):
         if self.bits != 4:
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights_row(weights, prefix)
+
         if self.desc_act:
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
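
The skip check added above is a plain substring match over the weight prefix, so a single entry can exclude a whole module subtree from GPTQ quantization. A small standalone illustration (the prefixes and list entries below are hypothetical):

    modules_to_not_convert = ["lm_head", "model.layers.61"]

    def is_skipped(prefix):
        # same rule as GPTQWeightsLoader.is_layer_skipped_quantization
        return any(m in prefix for m in modules_to_not_convert)

    print(is_skipped("model.layers.61.mlp.gate_proj"))    # True  -> DefaultWeightsLoader path
    print(is_skipped("model.layers.3.self_attn.q_proj"))  # False -> GPTQ path
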
@@ -9,12 +9,11 @@ from text_generation_server.layers.fp8 import (
     fp8_quantize,
     quant_dtype,
     normalize_e4m3fn_to_native_float8,
+    dynamic_quant,
+    dequant_block_fp8_weight_naive,
 )
-
-try:
-    from .unquantized import fused_moe
-except Exception:
-    fused_moe = None
+from text_generation_server.layers.moe.fused_moe import select_experts
+import habana_frameworks.torch as htorch
 
 
 class FP8SparseMoELayer(nn.Module):
@@ -68,27 +67,78 @@ class FP8SparseMoELayer(nn.Module):
                 weights=weights,
             )
         )
+        if self.weight_block_size is not None:
+            self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(
+                dequant_block_fp8_weight_naive(
+                    self.gate_up_proj,
+                    self.gate_up_proj_weight_scale,
+                    self.weight_block_size,
+                )
+            )
+            self.down_proj, self.down_proj_weight_scale = dynamic_quant(
+                dequant_block_fp8_weight_naive(
+                    self.down_proj, self.down_proj_weight_scale, self.weight_block_size
+                )
+            )
+            self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (
+                self.gate_up_proj_weight_scale.squeeze(-1),
+                self.down_proj_weight_scale.squeeze(-1),
+            )
 
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        return fused_moe(
-            x,
-            w1=self.gate_up_proj,
-            w2=self.down_proj,
-            gating_output=gating_output,
-            topk=self.topk,
-            renormalize=self.renormalize,
-            inplace=True,
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=gating_output,
             use_grouped_topk=self.n_expert_group is not None,
-            num_expert_group=self.n_expert_group,
+            top_k=self.topk,
+            renormalize=self.renormalize,
             topk_group=self.topk_group,
+            num_expert_group=self.n_expert_group,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
-            use_fp8_w8a8=True,
-            w1_scale=self.gate_up_proj_weight_scale,
-            w2_scale=self.down_proj_weight_scale,
-            a1_scale=self.gate_up_proj_input_scale,
-            a2_scale=self.down_proj_input_scale,
         )
+        total_num_experts = gating_output.size(-1)
+        x_fp8, x_scale = dynamic_quant(x, single_scale=True)
+        moe_n_slice = (total_num_experts + 31) // 32
+        n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
+        for i in range(moe_n_slice):
+            min_expert = i * n_expert_slice
+            max_expert = min((i + 1) * n_expert_slice, total_num_experts)
+            w13_list_slice = [
+                self.gate_up_proj[j, ...] for j in range(min_expert, max_expert)
+            ]
+            w2_list_slice = [
+                self.down_proj[j, ...] for j in range(min_expert, max_expert)
+            ]
+            w13_weight_scale = [
+                self.gate_up_proj_weight_scale[j, ...]
+                for j in range(min_expert, max_expert)
+            ]
+            w2_weight_scale = [
+                self.down_proj_weight_scale[j, ...]
+                for j in range(min_expert, max_expert)
+            ]
+
+            current_hidden_states = torch.ops.hpu.mixture_of_experts(
+                hidden_states=x_fp8,
+                expert_routing_table=topk_ids.to(torch.int64),
+                router_weights=topk_weights.to(x.dtype),
+                w12=w13_list_slice,
+                w3=w2_list_slice,
+                d_scale_hidden_states=x_scale,
+                d_scale_w12=w13_weight_scale,
+                d_scale_w3=w2_weight_scale,
+                permuted_weights=True,
+                activation="silu",
+                experts_min=min_expert,
+                experts_max=max_expert - 1,
+            )
+            htorch.core.mark_step()
+            if i == 0:
+                final_hidden_states = current_hidden_states
+            else:
+                final_hidden_states.add_(current_hidden_states)
+        return final_hidden_states
 
 
 def _load_expert_weights(
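
A standalone check of the expert-slicing arithmetic used in the forward pass above: experts are processed in chunks of at most 32 so that each torch.ops.hpu.mixture_of_experts call only sees a slice of the expert weights (256 routed experts matches the DeepSeek-R1 case; the number is illustrative):

    total_num_experts = 256
    moe_n_slice = (total_num_experts + 31) // 32                           # 8 slices
    n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice  # 32 experts each
    for i in range(moe_n_slice):
        min_expert = i * n_expert_slice
        max_expert = min((i + 1) * n_expert_slice, total_num_experts)
        print(i, min_expert, max_expert - 1)  # inclusive expert-id range for this call
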
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Tuple, Optional
 
 import torch
 
@@ -25,12 +25,36 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    scores = torch.softmax(gating_output, dim=-1)
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    gating_output = gating_output.float()
+    if e_score_correction_bias is not None:
+        e_score_correction_bias = e_score_correction_bias.float()
+
+    if scoring_func == "softmax":
+        scores = torch.softmax(gating_output, dim=-1)
+    elif scoring_func == "sigmoid":
+        scores = gating_output.sigmoid()
+    else:
+        raise ValueError(f"Unsupported scoring function: {scoring_func}")
+
     num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
+    if e_score_correction_bias is not None:
+        # Store original scores before applying correction bias. We use biased
+        # scores for expert selection but original scores for routing weights
+        original_scores = scores
+        scores = scores + e_score_correction_bias.unsqueeze(0)
+        group_scores = (
+            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
+        )
+    else:
+        group_scores = (
+            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+        )  # [n, n_group]
+
     group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
         1
     ]  # [n, top_k_group]
@@ -41,13 +65,19 @@ def grouped_topk(
         .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
         .reshape(num_token, -1)
     )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]
+    if e_score_correction_bias is not None:
+        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
+        # Use original unbiased scores for the routing weights
+        topk_weights = original_scores.gather(1, topk_ids)
+    else:
+        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
 
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    return topk_weights, topk_ids
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
 def fused_topk(
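
A tiny numeric check of the correction-bias behaviour implemented above: the bias only influences which experts are selected, while the routing weights still come from the unbiased scores (the numbers are made up):

    import torch

    scores = torch.tensor([[0.30, 0.18, 0.28, 0.24]])  # unbiased router scores
    bias = torch.tensor([0.0, 0.15, 0.0, 0.0])         # e_score_correction_bias

    biased = scores + bias                             # [[0.30, 0.33, 0.28, 0.24]]
    topk_ids = torch.topk(biased, k=2, dim=-1)[1]      # experts 1 and 0 win selection
    topk_weights = scores.gather(1, topk_ids)          # but weights stay unbiased
    print(topk_ids, topk_weights)                      # tensor([[1, 0]]) tensor([[0.1800, 0.3000]])
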
@@ -63,3 +93,39 @@ def fused_topk(
     if renormalize:
         topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
     return topk_weights, topk_ids
+
+
+def select_experts(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
+):
+
+    # DeekSeekv2 uses grouped_top_k
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+    else:
+        topk_weights, topk_ids = fused_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    return topk_weights, topk_ids
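
A hedged usage sketch of the new select_experts entry point with a DeepSeek-R1-style routing configuration (256 routed experts, 8 groups, top-4 groups, top-8 experts, sigmoid scoring). It assumes the Gaudi backend's server package is importable; the routing itself uses only plain torch ops, so it runs on CPU, and the hidden size below is illustrative:

    import torch
    from text_generation_server.layers.moe.fused_moe import select_experts

    router_logits = torch.randn(4, 256)      # 4 tokens, 256 routed experts
    topk_weights, topk_ids = select_experts(
        hidden_states=torch.randn(4, 7168),
        router_logits=router_logits,
        top_k=8,
        use_grouped_topk=True,
        renormalize=True,
        topk_group=4,
        num_expert_group=8,
        scoring_func="sigmoid",
        e_score_correction_bias=torch.zeros(256),
    )
    print(topk_weights.shape, topk_ids.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
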
@@ -470,9 +470,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
         mscale_all_dim: float,
     ):
         inv_freq = _create_inv_freq(dim, base, device)
-        super().__init__(
-            inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
-        )
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
@@ -487,6 +484,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
             / get_mscale(self.scaling_factor, mscale_all_dim)
             * self.attn_factor
         )  # Get n-d magnitude scaling corrected for interpolation
+        super().__init__(inv_freq, scaling_factor, max_position_embeddings)
 
     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # Reset the tables if the sequence length has changed,
@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, List
 
 from huggingface_hub import hf_hub_download
 from text_generation_server.utils.weights import (
@@ -18,6 +18,8 @@ class _QuantizerConfig:
     groupsize: int
     quant_method: str
     sym: bool
+    weight_block_size: Optional[List[int]]
+    modules_to_not_convert: List[str]
 
 
 @dataclass
@@ -25,7 +27,20 @@ class _FP8QuantizerConfig:
     activation_scale_ub: float
 
 
-# We should probably do this with Pytantic JSON deserialization,
+def _get_config_json(model_id: str, revision: Optional[str], filename: str):
+    if os.path.exists(
+        os.path.join(
+            model_id,
+        )
+    ):
+        filename = os.path.join(model_id, filename)
+    else:
+        filename = hf_hub_download(model_id, filename=filename, revision=revision)
+    with open(filename, "r") as f:
+        return json.load(f)
+
+
+# We should probably do this with Pydantic JSON deserialization,
 # but for now we'll stay close to the old _set_gptq_params.
 def _get_quantizer_config(model_id, revision):
     bits = 4
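
A hedged sketch of how the new _get_config_json helper is used by the loaders below: a local checkout is read directly, otherwise the file is pulled from the Hub, and for a DeepSeek-R1-style FP8 checkpoint the quantization_config is expected to carry a [128, 128] weight_block_size (the model id is just an example):

    data = _get_config_json("deepseek-ai/DeepSeek-R1", revision=None, filename="config.json")
    print(data["quantization_config"].get("weight_block_size"))  # expected: [128, 128]
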
@@ -34,21 +49,18 @@ def _get_quantizer_config(model_id, revision):
     checkpoint_format = None
     sym = False
     desc_act = False
+    weight_block_size = None
+    modules_to_not_convert = []
 
     filename = "config.json"
     try:
-        if os.path.exists(os.path.join(model_id, filename)):
-            filename = os.path.join(model_id, filename)
-        else:
-            filename = hf_hub_download(model_id, filename=filename, revision=revision)
-        with open(filename, "r") as f:
-            data = json.load(f)
-
+        data = _get_config_json(model_id, revision, filename)
         # FP8 config
         if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
             return _FP8QuantizerConfig(
                 activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
             )
+        weight_block_size = data["quantization_config"].get("weight_block_size", None)
 
         if "zero_point" in data["quantization_config"]:
             sym = not data["quantization_config"]["zero_point"]
@@ -61,18 +73,16 @@ def _get_quantizer_config(model_id, revision):
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        desc_act = data["quantization_config"]["desc_act"]
+        desc_act = data["quantization_config"].get("desc_act", False)
+        modules_to_not_convert = data["quantization_config"].get(
+            "modules_to_not_convert", []
+        )
+        if modules_to_not_convert is None:
+            modules_to_not_convert = []
     except Exception:
         filename = "quantize_config.json"
         try:
-            if os.path.exists(os.path.join(model_id, filename)):
-                filename = os.path.join(model_id, filename)
-            else:
-                filename = hf_hub_download(
-                    model_id, filename=filename, revision=revision
-                )
-            with open(filename, "r") as f:
-                data = json.load(f)
+            data = _get_config_json(model_id, revision, filename)
             bits = data["bits"]
             groupsize = data["group_size"]
 
@@ -88,14 +98,7 @@ def _get_quantizer_config(model_id, revision):
     except Exception:
         filename = "quant_config.json"
         try:
-            if os.path.exists(os.path.join(model_id, filename)):
-                filename = os.path.join(model_id, filename)
-            else:
-                filename = hf_hub_download(
-                    model_id, filename=filename, revision=revision
-                )
-            with open(filename, "r") as f:
-                data = json.load(f)
+            data = _get_config_json(model_id, revision, filename)
             bits = data["w_bit"]
             groupsize = data["q_group_size"]
             desc_act = data["desc_act"]
@@ -111,6 +114,8 @@ def _get_quantizer_config(model_id, revision):
         checkpoint_format=checkpoint_format,
         sym=sym,
         desc_act=desc_act,
+        weight_block_size=weight_block_size,
+        modules_to_not_convert=modules_to_not_convert,
     )
 
 
@@ -134,6 +139,7 @@ def get_loader(
             quant_method=quantizer_config.quant_method,
             quantize=quantize,
             sym=quantizer_config.sym,
+            modules_to_not_convert=quantizer_config.modules_to_not_convert,
         )
     elif quantize == "fp8" or quantize is None:
         from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
@@ -141,9 +147,14 @@ def get_loader(
         # Since the default for the quantize config is _QuantizerConfig,
         # we need to add this check to not get an attribute error
        activation_scale_ub = None
+        weight_block_size = quantizer_config.weight_block_size
         if isinstance(quantizer_config, _FP8QuantizerConfig):
             activation_scale_ub = quantizer_config.activation_scale_ub
 
-        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
+        return HybridFP8UnquantLoader(
+            activation_scale_ub,
+            to_fp8=quantize == "fp8",
+            weight_block_size=weight_block_size,
+        )
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")