diff --git a/backends/gaudi/server/text_generation_server/layers/fp8.py b/backends/gaudi/server/text_generation_server/layers/fp8.py index 0dc5cdaf..78f03511 100644 --- a/backends/gaudi/server/text_generation_server/layers/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/fp8.py @@ -12,11 +12,151 @@ from text_generation_server.utils.weights import ( from vllm_hpu_extension.ops import scaled_fp8_quant from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2 -import habana_frameworks.torch.utils.experimental as htexp -w8a8_block_fp8_matmul = None -per_token_group_quant_fp8 = None quant_dtype: torch.dtype = torch.float8_e4m3fn +FP8_MAX = torch.finfo(torch.float8_e4m3fn).max +if is_hpu_gaudi2(): + FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max + + +def pad_weight(weight, block_size): + """Pads a matrix to make its dimensions multiples of block_size.""" + M, N = weight.shape[-2:] + block_size_m, block_size_n = block_size + pad_M = (block_size_m - M % block_size_m) % block_size_m + pad_N = (block_size_n - N % block_size_n) % block_size_n + + if pad_M == 0 and pad_N == 0: + return weight, M, N # No padding needed + padded_weight = torch.nn.functional.pad( + weight, (0, pad_N, 0, pad_M), mode="constant", value=0 + ) + return padded_weight, M, N # Return original dimensions for unpadding + + +def unpad_weight(weight, original_M, original_N, keep_first_dim=False): + """Removes padding from the matrix to restore its original shape.""" + if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N): + return weight + if keep_first_dim: + return weight[:, :original_M, :original_N] + else: + return weight[:original_M, :original_N] + + +def pad_block_fp8_weight_naive(weight, weight_scale, block_size): + + assert len(block_size) == 2 + + block_size_m, block_size_n = block_size + weight_scale_m, weight_scale_n = weight_scale.shape[-2:] + + weight, orig_M, orig_N = pad_weight(weight, block_size) + M, N = weight.shape[-2:] + + assert weight_scale_m == M // block_size_m + assert weight_scale_n == N // block_size_n + + return weight, orig_M, orig_N + + +def dynamic_quant(data, single_scale=False): + if single_scale: + scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX + else: + scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX + scale = scale.unsqueeze(-1) + data_fp8 = torch.ops.hpu.cast_to_fp8_v2( + data, 1.0 / scale, False, False, torch.float8_e4m3fn + )[0] + return data_fp8, scale.float() + + +def dequant_block_fp8_weight_naive( + weight, + weight_scale, + block_size, + dtype=torch.bfloat16, + original_M=None, + original_N=None, + do_unpad=False, +): + if weight_scale is None: + return weight + assert len(block_size) == 2 + + weight_shape_len = len(weight.shape) + + block_size_m, block_size_n = block_size + + # mul scale + if weight_shape_len == 2: + weight_scale_m, weight_scale_n = weight_scale.shape + weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1) + weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n) + if is_hpu_gaudi2(): + fake_weight = weight.cpu().to(dtype).to(weight.device) + dequant_weight = fake_weight * weight_scale.to(dtype) + else: + dequant_weight = weight.to(dtype) * weight_scale.to(dtype) + dequant_weight = dequant_weight.view( + weight_scale_m * block_size_m, weight_scale_n * block_size_n + ) + keep_first_dim = False + elif weight_shape_len == 3: + fd, weight_scale_m, weight_scale_n = weight_scale.shape + weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1) + weight = weight.view( + fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n + ) + if is_hpu_gaudi2(): + fake_weight = weight.cpu().to(dtype).to(weight.device) + dequant_weight = fake_weight * weight_scale.to(dtype) + else: + dequant_weight = weight.to(dtype) * weight_scale.to(dtype) + dequant_weight = dequant_weight.view( + fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n + ) + keep_first_dim = True + else: + raise ValueError("Only support original weight shape is either 2 or 3") + + if do_unpad: + dequant_weight = unpad_weight( + dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim + ) + + return dequant_weight + + +def apply_block_fp8_linear_hpu_dynamic( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + x_fp8, x_scale = dynamic_quant(input_2d) + + output = torch.ops.hpu.fp8_gemm_v2( + x_fp8, + False, + weight, + True, + None, + torch.bfloat16, + x_scale, + weight_scale, + None, + False, + ) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]: @@ -42,7 +182,7 @@ def per_tensor_dequantize( ) -> torch.Tensor: device = tensor.device dtype = torch.bfloat16 - if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + if is_hpu_gaudi2(): # dequant on cpu to avoid nan on gaudi2 tensor = tensor.to("cpu") @@ -389,6 +529,22 @@ class Fp8Linear(torch.nn.Module): scale_upper_bound = kwargs.get("scale_upper_bound", None) weight_block_size = kwargs.get("weight_block_size", None) + if weight_block_size is not None: + weight, orig_M, orig_N = pad_block_fp8_weight_naive( + weight, scale, weight_block_size + ) + weight, scale = dynamic_quant( + dequant_block_fp8_weight_naive( + weight, + scale, + weight_block_size, + original_M=orig_M, + original_N=orig_N, + do_unpad=True, + ) + ) + scale = scale.squeeze(-1) + return cls( qweight=weight, scale=scale, @@ -409,25 +565,10 @@ class Fp8Linear(torch.nn.Module): def forward(self, input: torch.Tensor) -> torch.Tensor: if self.weight_block_size is not None: - # https://arxiv.org/pdf/2412.19437 - # At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and - # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we - # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output - # channels). - qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1]) - output = w8a8_block_fp8_matmul( - qinput, - self.qweight, - scale, - self.scale, - self.weight_block_size, - output_dtype=input.dtype, + return apply_block_fp8_linear_hpu_dynamic( + input, self.qweight, self.scale, self.input_scale, self.bias ) - if self.bias is not None: - output = output + self.bias - return output.to(dtype=input.dtype) - qinput, scale = fp8_quantize( input, self.input_scale, diff --git a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py index 90b8f692..babf3d4b 100644 --- a/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/gptq/__init__.py @@ -4,7 +4,12 @@ from typing import List, Optional, Union import torch from loguru import logger from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weight, Weights, WeightsLoader +from text_generation_server.utils.weights import ( + Weight, + Weights, + WeightsLoader, + DefaultWeightsLoader, +) from .hpu import QuantLinear @@ -72,6 +77,7 @@ class GPTQWeightsLoader(WeightsLoader): quant_method: str, quantize: str, sym: bool, + modules_to_not_convert: List[str], ): self.bits = bits self.desc_act = desc_act @@ -79,6 +85,12 @@ class GPTQWeightsLoader(WeightsLoader): self.quant_method = quant_method self.quantize = quantize self.sym = sym + self.modules_to_not_convert = modules_to_not_convert + + def is_layer_skipped_quantization( + self, prefix: str, modules_to_not_convert: List[str] + ): + return any(module_name in prefix for module_name in modules_to_not_convert) def get_weights(self, weights: Weights, prefix: str): self._get_gptq_params(weights) @@ -91,6 +103,9 @@ class GPTQWeightsLoader(WeightsLoader): log_once(logger.warning, "Disabling exllama because desc_act=True") use_exllama = False + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights(weights, prefix) + try: qweight = weights.get_tensor(f"{prefix}.qweight") except RuntimeError: @@ -145,6 +160,10 @@ class GPTQWeightsLoader(WeightsLoader): prefix: str, block_sizes: Union[int, List[int]], ): + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights_col_packed( + weights, prefix, block_sizes + ) try: qweight = weights.get_packed_sharded( f"{prefix}.qweight", dim=1, block_sizes=block_sizes @@ -196,6 +215,8 @@ class GPTQWeightsLoader(WeightsLoader): ) def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert): + return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim) try: qweight = torch.cat( [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 @@ -263,6 +284,9 @@ class GPTQWeightsLoader(WeightsLoader): if self.bits != 4: use_exllama = False + if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert): + return DefaultWeightsLoader.get_weights_row(weights, prefix) + if self.desc_act: log_once(logger.warning, "Disabling exllama because desc_act=True") use_exllama = False diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py index 071b2abe..5362e8de 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/fp8.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/fp8.py @@ -9,12 +9,11 @@ from text_generation_server.layers.fp8 import ( fp8_quantize, quant_dtype, normalize_e4m3fn_to_native_float8, + dynamic_quant, + dequant_block_fp8_weight_naive, ) - -try: - from .unquantized import fused_moe -except Exception: - fused_moe = None +from text_generation_server.layers.moe.fused_moe import select_experts +import habana_frameworks.torch as htorch class FP8SparseMoELayer(nn.Module): @@ -68,27 +67,78 @@ class FP8SparseMoELayer(nn.Module): weights=weights, ) ) + if self.weight_block_size is not None: + self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant( + dequant_block_fp8_weight_naive( + self.gate_up_proj, + self.gate_up_proj_weight_scale, + self.weight_block_size, + ) + ) + self.down_proj, self.down_proj_weight_scale = dynamic_quant( + dequant_block_fp8_weight_naive( + self.down_proj, self.down_proj_weight_scale, self.weight_block_size + ) + ) + self.gate_up_proj_weight_scale, self.down_proj_weight_scale = ( + self.gate_up_proj_weight_scale.squeeze(-1), + self.down_proj_weight_scale.squeeze(-1), + ) def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: - return fused_moe( - x, - w1=self.gate_up_proj, - w2=self.down_proj, - gating_output=gating_output, - topk=self.topk, - renormalize=self.renormalize, - inplace=True, + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=gating_output, use_grouped_topk=self.n_expert_group is not None, - num_expert_group=self.n_expert_group, + top_k=self.topk, + renormalize=self.renormalize, topk_group=self.topk_group, + num_expert_group=self.n_expert_group, scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, - use_fp8_w8a8=True, - w1_scale=self.gate_up_proj_weight_scale, - w2_scale=self.down_proj_weight_scale, - a1_scale=self.gate_up_proj_input_scale, - a2_scale=self.down_proj_input_scale, ) + total_num_experts = gating_output.size(-1) + x_fp8, x_scale = dynamic_quant(x, single_scale=True) + moe_n_slice = (total_num_experts + 31) // 32 + n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice + for i in range(moe_n_slice): + min_expert = i * n_expert_slice + max_expert = min((i + 1) * n_expert_slice, total_num_experts) + w13_list_slice = [ + self.gate_up_proj[j, ...] for j in range(min_expert, max_expert) + ] + w2_list_slice = [ + self.down_proj[j, ...] for j in range(min_expert, max_expert) + ] + w13_weight_scale = [ + self.gate_up_proj_weight_scale[j, ...] + for j in range(min_expert, max_expert) + ] + w2_weight_scale = [ + self.down_proj_weight_scale[j, ...] + for j in range(min_expert, max_expert) + ] + + current_hidden_states = torch.ops.hpu.mixture_of_experts( + hidden_states=x_fp8, + expert_routing_table=topk_ids.to(torch.int64), + router_weights=topk_weights.to(x.dtype), + w12=w13_list_slice, + w3=w2_list_slice, + d_scale_hidden_states=x_scale, + d_scale_w12=w13_weight_scale, + d_scale_w3=w2_weight_scale, + permuted_weights=True, + activation="silu", + experts_min=min_expert, + experts_max=max_expert - 1, + ) + htorch.core.mark_step() + if i == 0: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + return final_hidden_states def _load_expert_weights( diff --git a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py index e26ff877..1987f0ed 100644 --- a/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py +++ b/backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Tuple, Optional import torch @@ -25,12 +25,36 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - scores = torch.softmax(gating_output, dim=-1) + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + gating_output = gating_output.float() + if e_score_correction_bias is not None: + e_score_correction_bias = e_score_correction_bias.float() + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + num_token = scores.shape[0] - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, n_group] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) + else: + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ 1 ] # [n, top_k_group] @@ -41,13 +65,19 @@ def grouped_topk( .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) .reshape(num_token, -1) ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) def fused_topk( @@ -63,3 +93,39 @@ def fused_topk( if renormalize: topk_weights /= topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +): + + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + else: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + return topk_weights, topk_ids diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index 6a83d6a5..7e740e5f 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -470,9 +470,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) - super().__init__( - inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor - ) self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -487,6 +484,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): / get_mscale(self.scaling_factor, mscale_all_dim) * self.attn_factor ) # Get n-d magnitude scaling corrected for interpolation + super().__init__(inv_freq, scaling_factor, max_position_embeddings) def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, diff --git a/backends/gaudi/server/text_generation_server/utils/quantization.py b/backends/gaudi/server/text_generation_server/utils/quantization.py index a8faf4a5..022a4897 100644 --- a/backends/gaudi/server/text_generation_server/utils/quantization.py +++ b/backends/gaudi/server/text_generation_server/utils/quantization.py @@ -1,7 +1,7 @@ import json import os from dataclasses import dataclass -from typing import Optional +from typing import Optional, List from huggingface_hub import hf_hub_download from text_generation_server.utils.weights import ( @@ -18,6 +18,8 @@ class _QuantizerConfig: groupsize: int quant_method: str sym: bool + weight_block_size: Optional[List[int]] + modules_to_not_convert: List[str] @dataclass @@ -25,7 +27,20 @@ class _FP8QuantizerConfig: activation_scale_ub: float -# We should probably do this with Pytantic JSON deserialization, +def _get_config_json(model_id: str, revision: Optional[str], filename: str): + if os.path.exists( + os.path.join( + model_id, + ) + ): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename, revision=revision) + with open(filename, "r") as f: + return json.load(f) + + +# We should probably do this with Pydantic JSON deserialization, # but for now we'll stay close to the old _set_gptq_params. def _get_quantizer_config(model_id, revision): bits = 4 @@ -34,21 +49,18 @@ def _get_quantizer_config(model_id, revision): checkpoint_format = None sym = False desc_act = False + weight_block_size = None + modules_to_not_convert = [] filename = "config.json" try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download(model_id, filename=filename, revision=revision) - with open(filename, "r") as f: - data = json.load(f) - + data = _get_config_json(model_id, revision, filename) # FP8 config if data["quantization_config"]["quant_method"] == "fbgemm_fp8": return _FP8QuantizerConfig( activation_scale_ub=data["quantization_config"]["activation_scale_ub"] ) + weight_block_size = data["quantization_config"].get("weight_block_size", None) if "zero_point" in data["quantization_config"]: sym = not data["quantization_config"]["zero_point"] @@ -61,18 +73,16 @@ def _get_quantizer_config(model_id, revision): # Order is important here, desc_act is missing on some real models quant_method = data["quantization_config"]["quant_method"] checkpoint_format = data["quantization_config"].get("checkpoint_format") - desc_act = data["quantization_config"]["desc_act"] + desc_act = data["quantization_config"].get("desc_act", False) + modules_to_not_convert = data["quantization_config"].get( + "modules_to_not_convert", [] + ) + if modules_to_not_convert is None: + modules_to_not_convert = [] except Exception: filename = "quantize_config.json" try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) + data = _get_config_json(model_id, revision, filename) bits = data["bits"] groupsize = data["group_size"] @@ -88,14 +98,7 @@ def _get_quantizer_config(model_id, revision): except Exception: filename = "quant_config.json" try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) + data = _get_config_json(model_id, revision, filename) bits = data["w_bit"] groupsize = data["q_group_size"] desc_act = data["desc_act"] @@ -111,6 +114,8 @@ def _get_quantizer_config(model_id, revision): checkpoint_format=checkpoint_format, sym=sym, desc_act=desc_act, + weight_block_size=weight_block_size, + modules_to_not_convert=modules_to_not_convert, ) @@ -134,6 +139,7 @@ def get_loader( quant_method=quantizer_config.quant_method, quantize=quantize, sym=quantizer_config.sym, + modules_to_not_convert=quantizer_config.modules_to_not_convert, ) elif quantize == "fp8" or quantize is None: from text_generation_server.layers.fp8 import HybridFP8UnquantLoader @@ -141,9 +147,14 @@ def get_loader( # Since the default for the quantize config is _QuantizerConfig, # we need to add this check to not get an attribute error activation_scale_ub = None + weight_block_size = quantizer_config.weight_block_size if isinstance(quantizer_config, _FP8QuantizerConfig): activation_scale_ub = quantizer_config.activation_scale_ub - return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8") + return HybridFP8UnquantLoader( + activation_scale_ub, + to_fp8=quantize == "fp8", + weight_block_size=weight_block_size, + ) else: raise ValueError(f"Unknown quantization method: {quantize}")