enable deepseek_r1

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-04-28 23:07:34 -07:00
parent 329f612e55
commit debf477ba4
6 changed files with 369 additions and 79 deletions


@@ -12,11 +12,151 @@ from text_generation_server.utils.weights import (
 from vllm_hpu_extension.ops import scaled_fp8_quant
 from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
-import habana_frameworks.torch.utils.experimental as htexp
-
-w8a8_block_fp8_matmul = None
-per_token_group_quant_fp8 = None
 quant_dtype: torch.dtype = torch.float8_e4m3fn
+FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
+if is_hpu_gaudi2():
+    FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
+
+
+def pad_weight(weight, block_size):
+    """Pads a matrix to make its dimensions multiples of block_size."""
+    M, N = weight.shape[-2:]
+    block_size_m, block_size_n = block_size
+    pad_M = (block_size_m - M % block_size_m) % block_size_m
+    pad_N = (block_size_n - N % block_size_n) % block_size_n
+    if pad_M == 0 and pad_N == 0:
+        return weight, M, N  # No padding needed
+    padded_weight = torch.nn.functional.pad(
+        weight, (0, pad_N, 0, pad_M), mode="constant", value=0
+    )
+    return padded_weight, M, N  # Return original dimensions for unpadding
+
+
+def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
+    """Removes padding from the matrix to restore its original shape."""
+    if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N):
+        return weight
+    if keep_first_dim:
+        return weight[:, :original_M, :original_N]
+    else:
+        return weight[:original_M, :original_N]
+
+
+def pad_block_fp8_weight_naive(weight, weight_scale, block_size):
+    assert len(block_size) == 2
+
+    block_size_m, block_size_n = block_size
+    weight_scale_m, weight_scale_n = weight_scale.shape[-2:]
+
+    weight, orig_M, orig_N = pad_weight(weight, block_size)
+    M, N = weight.shape[-2:]
+
+    assert weight_scale_m == M // block_size_m
+    assert weight_scale_n == N // block_size_n
+
+    return weight, orig_M, orig_N
+
+
+def dynamic_quant(data, single_scale=False):
+    if single_scale:
+        scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX
+    else:
+        scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX
+        scale = scale.unsqueeze(-1)
+    data_fp8 = torch.ops.hpu.cast_to_fp8_v2(
+        data, 1.0 / scale, False, False, torch.float8_e4m3fn
+    )[0]
+    return data_fp8, scale.float()
+
+
+def dequant_block_fp8_weight_naive(
+    weight,
+    weight_scale,
+    block_size,
+    dtype=torch.bfloat16,
+    original_M=None,
+    original_N=None,
+    do_unpad=False,
+):
+    if weight_scale is None:
+        return weight
+    assert len(block_size) == 2
+
+    weight_shape_len = len(weight.shape)
+    block_size_m, block_size_n = block_size
+
+    # mul scale
+    if weight_shape_len == 2:
+        weight_scale_m, weight_scale_n = weight_scale.shape
+        weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1)
+        weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n)
+        if is_hpu_gaudi2():
+            fake_weight = weight.cpu().to(dtype).to(weight.device)
+            dequant_weight = fake_weight * weight_scale.to(dtype)
+        else:
+            dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
+        dequant_weight = dequant_weight.view(
+            weight_scale_m * block_size_m, weight_scale_n * block_size_n
+        )
+        keep_first_dim = False
+    elif weight_shape_len == 3:
+        fd, weight_scale_m, weight_scale_n = weight_scale.shape
+        weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1)
+        weight = weight.view(
+            fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n
+        )
+        if is_hpu_gaudi2():
+            fake_weight = weight.cpu().to(dtype).to(weight.device)
+            dequant_weight = fake_weight * weight_scale.to(dtype)
+        else:
+            dequant_weight = weight.to(dtype) * weight_scale.to(dtype)
+        dequant_weight = dequant_weight.view(
+            fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n
+        )
+        keep_first_dim = True
+    else:
+        raise ValueError("Only support original weight shape is either 2 or 3")
+    if do_unpad:
+        dequant_weight = unpad_weight(
+            dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim
+        )
+    return dequant_weight
+
+
+def apply_block_fp8_linear_hpu_dynamic(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    x_fp8, x_scale = dynamic_quant(input_2d)
+
+    output = torch.ops.hpu.fp8_gemm_v2(
+        x_fp8,
+        False,
+        weight,
+        True,
+        None,
+        torch.bfloat16,
+        x_scale,
+        weight_scale,
+        None,
+        False,
+    )
+    if bias is not None:
+        output = output + bias
+
+    return output.to(dtype=input.dtype).view(*output_shape)
+
+
 def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
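To make the new load path concrete: the block scales are broadcast over their blocks to dequantize the stored weight, and dynamic_quant then derives one scale per output channel from the result. A minimal CPU-only sketch of that arithmetic, assuming a 128x128 weight_block_size and made-up shapes; bfloat16 stands in for the FP8 block storage and no HPU ops are involved:

import torch

block_m, block_n = 128, 128
M, N = 256, 384  # already multiples of the block size, so no padding step here
weight = torch.randn(M, N, dtype=torch.bfloat16)  # stand-in for the stored FP8 blocks
scale = torch.rand(M // block_m, N // block_n) + 0.5  # one scale per 128x128 block

# dequant_block_fp8_weight_naive (2-D case): broadcast each block scale over its block.
dequant = (
    weight.view(M // block_m, block_m, N // block_n, block_n)
    * scale.view(M // block_m, 1, N // block_n, 1).to(torch.bfloat16)
).view(M, N)

# dynamic_quant: one scale per output row, sized so the row just fits the FP8 range.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
row_scale = (dequant.abs().max(dim=-1).values + 1e-8) / FP8_MAX
requant = (dequant / row_scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
print(requant.shape, row_scale.shape)  # torch.Size([256, 384]) torch.Size([256])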
@@ -42,7 +182,7 @@ def per_tensor_dequantize(
 ) -> torch.Tensor:
     device = tensor.device
     dtype = torch.bfloat16
-    if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
+    if is_hpu_gaudi2():
         # dequant on cpu to avoid nan on gaudi2
         tensor = tensor.to("cpu")
@@ -389,6 +529,22 @@ class Fp8Linear(torch.nn.Module):
         scale_upper_bound = kwargs.get("scale_upper_bound", None)
         weight_block_size = kwargs.get("weight_block_size", None)
+        if weight_block_size is not None:
+            weight, orig_M, orig_N = pad_block_fp8_weight_naive(
+                weight, scale, weight_block_size
+            )
+            weight, scale = dynamic_quant(
+                dequant_block_fp8_weight_naive(
+                    weight,
+                    scale,
+                    weight_block_size,
+                    original_M=orig_M,
+                    original_N=orig_N,
+                    do_unpad=True,
+                )
+            )
+            scale = scale.squeeze(-1)
 
         return cls(
             qweight=weight,
             scale=scale,
@@ -409,25 +565,10 @@ class Fp8Linear(torch.nn.Module):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.weight_block_size is not None:
-            # https://arxiv.org/pdf/2412.19437
-            # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
-            # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
-            # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
-            # channels).
-            qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
-            output = w8a8_block_fp8_matmul(
-                qinput,
-                self.qweight,
-                scale,
-                self.scale,
-                self.weight_block_size,
-                output_dtype=input.dtype,
+            return apply_block_fp8_linear_hpu_dynamic(
+                input, self.qweight, self.scale, self.input_scale, self.bias
             )
-            if self.bias is not None:
-                output = output + self.bias
-            return output.to(dtype=input.dtype)
 
         qinput, scale = fp8_quantize(
             input,
             self.input_scale,
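The old block-GEMM path in forward() is replaced by per-token dynamic quantization feeding torch.ops.hpu.fp8_gemm_v2. A plain-PyTorch emulation of the same math, assuming the HPU GEMM can be approximated by an explicit dequantize-and-matmul (tensor names and sizes below are illustrative only):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def dynamic_quant_ref(x: torch.Tensor):
    # One scale per row (per token), mirroring dynamic_quant(..., single_scale=False).
    scale = (x.abs().max(dim=-1).values + 1e-8) / FP8_MAX
    x_fp8 = (x / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return x_fp8, scale.float()


def fp8_linear_ref(x, w_fp8, w_scale, bias=None):
    x_fp8, x_scale = dynamic_quant_ref(x.view(-1, x.shape[-1]))
    # Dequantize both operands and run the GEMM in bfloat16 (stand-in for fp8_gemm_v2).
    y = (x_fp8.to(torch.bfloat16) * x_scale.unsqueeze(-1).to(torch.bfloat16)) @ (
        w_fp8.to(torch.bfloat16) * w_scale.unsqueeze(-1).to(torch.bfloat16)
    ).t()
    if bias is not None:
        y = y + bias
    return y.to(x.dtype).view(*x.shape[:-1], w_fp8.shape[0])


x = torch.randn(2, 4, 512, dtype=torch.bfloat16)
w = torch.randn(1024, 512, dtype=torch.bfloat16)
w_fp8, w_scale = dynamic_quant_ref(w)  # per-output-channel weight scale
print(fp8_linear_ref(x, w_fp8, w_scale).shape)  # torch.Size([2, 4, 1024])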


@@ -4,7 +4,12 @@ from typing import List, Optional, Union
 import torch
 from loguru import logger
 from text_generation_server.utils.log import log_once
-from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
+from text_generation_server.utils.weights import (
+    Weight,
+    Weights,
+    WeightsLoader,
+    DefaultWeightsLoader,
+)
 from .hpu import QuantLinear
@@ -72,6 +77,7 @@ class GPTQWeightsLoader(WeightsLoader):
         quant_method: str,
         quantize: str,
         sym: bool,
+        modules_to_not_convert: List[str],
     ):
         self.bits = bits
         self.desc_act = desc_act
@@ -79,6 +85,12 @@ class GPTQWeightsLoader(WeightsLoader):
         self.quant_method = quant_method
         self.quantize = quantize
         self.sym = sym
+        self.modules_to_not_convert = modules_to_not_convert
+
+    def is_layer_skipped_quantization(
+        self, prefix: str, modules_to_not_convert: List[str]
+    ):
+        return any(module_name in prefix for module_name in modules_to_not_convert)
 
     def get_weights(self, weights: Weights, prefix: str):
         self._get_gptq_params(weights)
@@ -91,6 +103,9 @@ class GPTQWeightsLoader(WeightsLoader):
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights(weights, prefix)
+
         try:
             qweight = weights.get_tensor(f"{prefix}.qweight")
         except RuntimeError:
@@ -145,6 +160,10 @@ class GPTQWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights_col_packed(
+                weights, prefix, block_sizes
+            )
         try:
             qweight = weights.get_packed_sharded(
                 f"{prefix}.qweight", dim=1, block_sizes=block_sizes
@@ -196,6 +215,8 @@ class GPTQWeightsLoader(WeightsLoader):
         )
 
     def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
+        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)
         try:
             qweight = torch.cat(
                 [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
@@ -263,6 +284,9 @@ class GPTQWeightsLoader(WeightsLoader):
         if self.bits != 4:
             use_exllama = False
 
+        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
+            return DefaultWeightsLoader.get_weights_row(weights, prefix)
+
         if self.desc_act:
             log_once(logger.warning, "Disabling exllama because desc_act=True")
             use_exllama = False
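The modules_to_not_convert plumbing above is a plain substring match on the weight prefix; any layer whose prefix contains one of the configured names falls back to the default (unquantized) loader instead of the GPTQ path. A small illustration with made-up module names and prefixes:

modules_to_not_convert = ["lm_head", "mlp.gate"]


def is_layer_skipped(prefix: str, skip_list) -> bool:
    # Same test as is_layer_skipped_quantization above.
    return any(name in prefix for name in skip_list)


print(is_layer_skipped("model.layers.3.mlp.gate", modules_to_not_convert))  # True: unquantized load
print(is_layer_skipped("model.layers.3.self_attn.q_proj", modules_to_not_convert))  # False: GPTQ path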


@@ -9,12 +9,11 @@ from text_generation_server.layers.fp8 import (
     fp8_quantize,
     quant_dtype,
     normalize_e4m3fn_to_native_float8,
+    dynamic_quant,
+    dequant_block_fp8_weight_naive,
 )
+from text_generation_server.layers.moe.fused_moe import select_experts
+import habana_frameworks.torch as htorch
 
-try:
-    from .unquantized import fused_moe
-except Exception:
-    fused_moe = None
 
 
 class FP8SparseMoELayer(nn.Module):
@@ -68,27 +67,78 @@ class FP8SparseMoELayer(nn.Module):
                 weights=weights,
             )
         )
+        if self.weight_block_size is not None:
+            self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(
+                dequant_block_fp8_weight_naive(
+                    self.gate_up_proj,
+                    self.gate_up_proj_weight_scale,
+                    self.weight_block_size,
+                )
+            )
+            self.down_proj, self.down_proj_weight_scale = dynamic_quant(
+                dequant_block_fp8_weight_naive(
+                    self.down_proj, self.down_proj_weight_scale, self.weight_block_size
+                )
+            )
+            self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (
+                self.gate_up_proj_weight_scale.squeeze(-1),
+                self.down_proj_weight_scale.squeeze(-1),
+            )
 
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        return fused_moe(
-            x,
-            w1=self.gate_up_proj,
-            w2=self.down_proj,
-            gating_output=gating_output,
-            topk=self.topk,
-            renormalize=self.renormalize,
-            inplace=True,
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=gating_output,
             use_grouped_topk=self.n_expert_group is not None,
-            num_expert_group=self.n_expert_group,
+            top_k=self.topk,
+            renormalize=self.renormalize,
             topk_group=self.topk_group,
+            num_expert_group=self.n_expert_group,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
-            use_fp8_w8a8=True,
-            w1_scale=self.gate_up_proj_weight_scale,
-            w2_scale=self.down_proj_weight_scale,
-            a1_scale=self.gate_up_proj_input_scale,
-            a2_scale=self.down_proj_input_scale,
         )
+
+        total_num_experts = gating_output.size(-1)
+        x_fp8, x_scale = dynamic_quant(x, single_scale=True)
+
+        moe_n_slice = (total_num_experts + 31) // 32
+        n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice
+        for i in range(moe_n_slice):
+            min_expert = i * n_expert_slice
+            max_expert = min((i + 1) * n_expert_slice, total_num_experts)
+
+            w13_list_slice = [
+                self.gate_up_proj[j, ...] for j in range(min_expert, max_expert)
+            ]
+            w2_list_slice = [
+                self.down_proj[j, ...] for j in range(min_expert, max_expert)
+            ]
+            w13_weight_scale = [
+                self.gate_up_proj_weight_scale[j, ...]
+                for j in range(min_expert, max_expert)
+            ]
+            w2_weight_scale = [
+                self.down_proj_weight_scale[j, ...]
+                for j in range(min_expert, max_expert)
+            ]
+
+            current_hidden_states = torch.ops.hpu.mixture_of_experts(
+                hidden_states=x_fp8,
+                expert_routing_table=topk_ids.to(torch.int64),
+                router_weights=topk_weights.to(x.dtype),
+                w12=w13_list_slice,
+                w3=w2_list_slice,
+                d_scale_hidden_states=x_scale,
+                d_scale_w12=w13_weight_scale,
+                d_scale_w3=w2_weight_scale,
+                permuted_weights=True,
+                activation="silu",
+                experts_min=min_expert,
+                experts_max=max_expert - 1,
+            )
+            htorch.core.mark_step()
+            if i == 0:
+                final_hidden_states = current_hidden_states
+            else:
+                final_hidden_states.add_(current_hidden_states)
+
+        return final_hidden_states
 
     def _load_expert_weights(
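The forward pass above runs torch.ops.hpu.mixture_of_experts over slices of at most 32 experts and sums the partial results, with a mark_step between kernel calls. The slicing arithmetic, shown standalone and assuming 256 routed experts (DeepSeek-R1's configuration) purely for illustration:

total_num_experts = 256
moe_n_slice = (total_num_experts + 31) // 32  # 8 kernel invocations
n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice  # 32 experts each
for i in range(moe_n_slice):
    min_expert = i * n_expert_slice
    max_expert = min((i + 1) * n_expert_slice, total_num_experts)
    print(i, min_expert, max_expert - 1)  # inclusive expert range handled by slice i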


@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Tuple, Optional
 
 import torch
@@ -25,12 +25,36 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    scores = torch.softmax(gating_output, dim=-1)
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    gating_output = gating_output.float()
+    if e_score_correction_bias is not None:
+        e_score_correction_bias = e_score_correction_bias.float()
+
+    if scoring_func == "softmax":
+        scores = torch.softmax(gating_output, dim=-1)
+    elif scoring_func == "sigmoid":
+        scores = gating_output.sigmoid()
+    else:
+        raise ValueError(f"Unsupported scoring function: {scoring_func}")
+
     num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
+    if e_score_correction_bias is not None:
+        # Store original scores before applying correction bias. We use biased
+        # scores for expert selection but original scores for routing weights
+        original_scores = scores
+        scores = scores + e_score_correction_bias.unsqueeze(0)
+        group_scores = (
+            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
+        )
+    else:
+        group_scores = (
+            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+        )  # [n, n_group]
     group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
         1
     ]  # [n, top_k_group]
@@ -41,13 +65,19 @@ def grouped_topk(
         .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
         .reshape(num_token, -1)
     )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]
 
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    if e_score_correction_bias is not None:
+        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
+        # Use original unbiased scores for the routing weights
+        topk_weights = original_scores.gather(1, topk_ids)
+    else:
+        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
 
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    return topk_weights, topk_ids
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
 def fused_topk(
@@ -63,3 +93,39 @@ def fused_topk(
     if renormalize:
         topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
     return topk_weights, topk_ids
+
+
+def select_experts(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
+):
+    # DeekSeekv2 uses grouped_top_k
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+    else:
+        topk_weights, topk_ids = fused_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    return topk_weights, topk_ids
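The grouped_topk/select_experts changes implement DeepSeek-style routing: scores may come from a sigmoid, and the e_score_correction_bias influences only which experts are selected, while the returned routing weights are taken from the unbiased scores. A toy, self-contained illustration of that selection rule (the expert-group masking step is omitted, and the sizes are made up):

import torch

logits = torch.tensor([[2.0, 1.5, 0.2, -1.0]])  # 1 token, 4 experts
bias = torch.tensor([0.0, 0.0, 3.0, 0.0])  # pushes expert 2 into the top-k

scores = logits.sigmoid()  # scoring_func="sigmoid"
original_scores = scores
biased = scores + bias.unsqueeze(0)

topk_ids = torch.topk(biased, k=2, dim=-1, sorted=False)[1]
topk_weights = original_scores.gather(1, topk_ids)  # weights from the unbiased scores
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize

print(topk_ids.tolist(), topk_weights.tolist())  # expert 2 selected, but weighted by its raw score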


@@ -470,9 +470,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
         mscale_all_dim: float,
     ):
         inv_freq = _create_inv_freq(dim, base, device)
-        super().__init__(
-            inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
-        )
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
@@ -487,6 +484,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
             / get_mscale(self.scaling_factor, mscale_all_dim)
             * self.attn_factor
         )  # Get n-d magnitude scaling corrected for interpolation
+        super().__init__(inv_freq, scaling_factor, max_position_embeddings)
 
     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # Reset the tables if the sequence length has changed,
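For context on the mscale expression in the surviving lines above, the DeepSeek-style helper it builds on typically has the following shape; this is a sketch for orientation, assuming the usual YaRN attenuation formula rather than quoting the file's own get_mscale:

import math


def get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# self.mscale = get_mscale(s, mscale) / get_mscale(s, mscale_all_dim) * attn_factor
print(get_mscale(40.0))  # ~1.369 for a 40x context extension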


@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, List
 
 from huggingface_hub import hf_hub_download
 from text_generation_server.utils.weights import (
@@ -18,6 +18,8 @@ class _QuantizerConfig:
     groupsize: int
     quant_method: str
     sym: bool
+    weight_block_size: Optional[List[int]]
+    modules_to_not_convert: List[str]
 
 
 @dataclass
@@ -25,7 +27,20 @@ class _FP8QuantizerConfig:
     activation_scale_ub: float
 
 
-# We should probably do this with Pytantic JSON deserialization,
+def _get_config_json(model_id: str, revision: Optional[str], filename: str):
+    if os.path.exists(
+        os.path.join(
+            model_id,
+        )
+    ):
+        filename = os.path.join(model_id, filename)
+    else:
+        filename = hf_hub_download(model_id, filename=filename, revision=revision)
+    with open(filename, "r") as f:
+        return json.load(f)
+
+
+# We should probably do this with Pydantic JSON deserialization,
 # but for now we'll stay close to the old _set_gptq_params.
 def _get_quantizer_config(model_id, revision):
     bits = 4
@@ -34,21 +49,18 @@ def _get_quantizer_config(model_id, revision):
     checkpoint_format = None
     sym = False
     desc_act = False
+    weight_block_size = None
+    modules_to_not_convert = []
 
     filename = "config.json"
     try:
-        if os.path.exists(os.path.join(model_id, filename)):
-            filename = os.path.join(model_id, filename)
-        else:
-            filename = hf_hub_download(model_id, filename=filename, revision=revision)
-        with open(filename, "r") as f:
-            data = json.load(f)
+        data = _get_config_json(model_id, revision, filename)
         # FP8 config
         if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
             return _FP8QuantizerConfig(
                 activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
             )
+        weight_block_size = data["quantization_config"].get("weight_block_size", None)
 
         if "zero_point" in data["quantization_config"]:
             sym = not data["quantization_config"]["zero_point"]
@@ -61,18 +73,16 @@ def _get_quantizer_config(model_id, revision):
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        desc_act = data["quantization_config"]["desc_act"]
+        desc_act = data["quantization_config"].get("desc_act", False)
+        modules_to_not_convert = data["quantization_config"].get(
+            "modules_to_not_convert", []
+        )
+        if modules_to_not_convert is None:
+            modules_to_not_convert = []
 
     except Exception:
         filename = "quantize_config.json"
         try:
-            if os.path.exists(os.path.join(model_id, filename)):
-                filename = os.path.join(model_id, filename)
-            else:
-                filename = hf_hub_download(
-                    model_id, filename=filename, revision=revision
-                )
-            with open(filename, "r") as f:
-                data = json.load(f)
+            data = _get_config_json(model_id, revision, filename)
 
             bits = data["bits"]
             groupsize = data["group_size"]
@@ -88,14 +98,7 @@ def _get_quantizer_config(model_id, revision):
         except Exception:
             filename = "quant_config.json"
             try:
-                if os.path.exists(os.path.join(model_id, filename)):
-                    filename = os.path.join(model_id, filename)
-                else:
-                    filename = hf_hub_download(
-                        model_id, filename=filename, revision=revision
-                    )
-                with open(filename, "r") as f:
-                    data = json.load(f)
+                data = _get_config_json(model_id, revision, filename)
 
                 bits = data["w_bit"]
                 groupsize = data["q_group_size"]
                 desc_act = data["desc_act"]
@@ -111,6 +114,8 @@ def _get_quantizer_config(model_id, revision):
         checkpoint_format=checkpoint_format,
         sym=sym,
         desc_act=desc_act,
+        weight_block_size=weight_block_size,
+        modules_to_not_convert=modules_to_not_convert,
     )
@@ -134,6 +139,7 @@ def get_loader(
             quant_method=quantizer_config.quant_method,
             quantize=quantize,
             sym=quantizer_config.sym,
+            modules_to_not_convert=quantizer_config.modules_to_not_convert,
         )
     elif quantize == "fp8" or quantize is None:
         from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
@@ -141,9 +147,14 @@ def get_loader(
         # Since the default for the quantize config is _QuantizerConfig,
         # we need to add this check to not get an attribute error
         activation_scale_ub = None
+        weight_block_size = quantizer_config.weight_block_size
         if isinstance(quantizer_config, _FP8QuantizerConfig):
             activation_scale_ub = quantizer_config.activation_scale_ub
-        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
+        return HybridFP8UnquantLoader(
+            activation_scale_ub,
+            to_fp8=quantize == "fp8",
+            weight_block_size=weight_block_size,
+        )
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")
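End to end, these quantization-config changes let DeepSeek-style FP8 checkpoints advertise their block size (and an optional skip list) through config.json and have the values reach the weight loaders. A sketch using a trimmed, illustrative quantization_config dict rather than any real model's file:

config = {
    "quantization_config": {
        "quant_method": "fp8",
        "weight_block_size": [128, 128],  # ends up in _QuantizerConfig.weight_block_size
    }
}

weight_block_size = config["quantization_config"].get("weight_block_size", None)
modules_to_not_convert = config["quantization_config"].get("modules_to_not_convert", []) or []

# These two values are what now flow into HybridFP8UnquantLoader / GPTQWeightsLoader.
print(weight_block_size)  # [128, 128]
print(modules_to_not_convert)  # []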