diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 5deff5ca..5b5cb45e 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -19,7 +19,6 @@ mod env_runtime;
 enum Quantization {
     Bitsandbytes,
     Gptq,
-    Gptq_cuda,
 }
 
 impl std::fmt::Display for Quantization {
@@ -32,9 +31,6 @@ impl std::fmt::Display for Quantization {
             Quantization::Gptq => {
                 write!(f, "gptq")
             }
-            Quantization::Gptq_cuda => {
-                write!(f, "gptq-cuda")
-            }
         }
     }
 }
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index e5c47467..3463049a 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -14,7 +14,6 @@ app = typer.Typer()
 class Quantization(str, Enum):
     bitsandbytes = "bitsandbytes"
     gptq = "gptq"
-    gptq_cuda = "gptq-cuda"
 
 
 class Dtype(str, Enum):
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index a22c44dd..fd97f8b1 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -286,7 +286,7 @@ def get_model(
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")
-        if quantize in ["gptq", "gptq-cuda"]:
+        if quantize == "gptq":
             raise ValueError(
                 "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
             )
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 9283a838..dfc42a5e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -27,7 +27,7 @@ from custom_kernels.exllama import prepare_buffers, set_tuning_params
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
-    if config.quantize in ["gptq", "gptq-cuda"]:
+    if config.quantize == "gptq":
         layer = _load_multi_mqa_gptq(
             config, prefix, weights, bias, head_size, num_heads, hidden_size
         )
@@ -79,7 +79,10 @@ def _load_multi_mqa_gptq(
         bits = weights.get_tensor("gptq_bits").item()
         groupsize = weights.get_tensor("gptq_groupsize").item()
 
-        weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+        qweight = qweight.to(weights.device)
+        qzeros = qzeros.to(weights.device)
+        scales = scales.to(weights.device)
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
 
         if bias:
             slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
@@ -93,7 +96,9 @@ def _load_multi_mqa_gptq(
             kv_tensor = slice_[-2 * head_size :]
             bias = torch.cat([q_tensor, kv_tensor], dim=0)
-        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))
+            bias = bias.to(weights.device)
+
+        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
     else:
         raise NotImplementedError("Gptq loading with santacoder is not implemented")
 
 
@@ -166,7 +171,7 @@ def _load_multi_mqa(
         assert list(bias.shape) == [
             (num_heads + 2) * head_size
         ], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
 
 
 def load_col(config, prefix: str, weights, bias: bool):
@@ -181,23 +186,11 @@ def load_col(config, prefix: str, weights, bias: bool):
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
     else:
         bias = None
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize, device=weights.device))
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
 
 
 def load_row(config, prefix: str, weights, bias: bool):
     quantize = config.quantize
-    if quantize == "gptq-cuda" and weights.process_group.size() > 1:
-        g_idx = weights.get_tensor(f"{prefix}.g_idx")
-        groupsize = weights.get_tensor("gptq_groupsize").item()
-
-        act_order = True
-        if g_idx is not None:
-            if torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) or (g_idx == 0).all():
-                act_order = False
-            else:
-                # Exllama implementation does not support row tensor parallelism with act-order, as
-                # it would require to reorder input activations that are split unto several GPUs
-                quantize = "gptq"
 
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
@@ -209,12 +202,9 @@ def load_row(config, prefix: str, weights, bias: bool):
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
-
-    if quantize == "gptq-cuda" and not act_order:
-        weight[3] = None  # remove g_idx to indicate to exllama that act-order is not used
-
+
     return TensorParallelRowLinear(
-        get_linear(weight, bias, quantize, device=weights.device), process_group=weights.process_group
+        get_linear(weight, bias, quantize), process_group=weights.process_group
     )
 
 
@@ -480,7 +470,7 @@ class FlashSantacoderForCausalLM(nn.Module):
 
         # Buffers need to be persistent to avoid any bug.
         self.buffers = {}
-        if config.quantize == "gptq-cuda":
+        if config.quantize == "gptq":
             max_dq_buffer_size = 0
             for name, submodule in self.named_modules():
                 if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py
index f7c06e0f..aa831ea2 100644
--- a/server/text_generation_server/utils/gptq/quant_linear.py
+++ b/server/text_generation_server/utils/gptq/quant_linear.py
@@ -392,29 +392,29 @@ def ext_q4_matmul(x, q4, q4_width):
 class Ex4bitLinear:
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
 
-    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize, device):
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         assert bits == 4
 
-        self.device = device
-        self.qweight = qweight.to(device)
-        self.qzeros = qzeros.to(device)
-        self.scales = scales.to(device)
+        self.device = qweight.device
+        self.qweight = qweight
+        self.qzeros = qzeros
+        self.scales = scales
         self.g_idx = g_idx.cpu() if g_idx is not None else None
-        self.bias = bias.to(device) if bias is not None else None
-
-        if self.g_idx is not None and (self.g_idx == 0).all():
+        self.bias = bias if bias is not None else None
+
+        if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
             self.empty_g_idx = True
             self.g_idx = None
 
-        assert device.type == "cuda"
-        assert device.index is not None
+        assert self.device.type == "cuda"
+        assert self.device.index is not None
 
         self.q4 = ext_make_q4(
             self.qweight,
             self.qzeros,
             self.scales,
             self.g_idx,
-            device.index
+            self.device.index
         )
 
         self.height = qweight.shape[0] * 8
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 8f5b6763..122fb884 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -18,7 +18,7 @@ from accelerate import init_empty_weights
 from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
 from typing import Optional
 
-
+from loguru import logger
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -131,7 +131,7 @@ class Linear8bitLt(nn.Module):
         return out
 
 
-def get_linear(weight, bias, quantize, device = None):
+def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
     elif quantize == "bitsandbytes":
@@ -145,30 +145,24 @@
         linear.bias = nn.Parameter(bias)
     elif quantize == "gptq":
         try:
-            qweight, qzeros, scales, g_idx, bits, groupsize = weight
+            qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel = weight
         except Exception:
             raise NotImplementedError(
                 f"The passed weight is not `gptq` compatible, loader needs to be updated."
             )
 
-        linear = QuantLinear(
-            qweight,
-            qzeros,
-            scales,
-            g_idx,
-            bias,
-            bits,
-            groupsize,
-        )
-    elif quantize == "gptq-cuda":
-        try:
-            qweight, qzeros, scales, g_idx, bits, groupsize = weight
-        except Exception:
-            raise NotImplementedError(
-                f"The passed weight is not `gptq` compatible, loader needs to be updated."
+        if use_triton_kernel:
+            linear = QuantLinear(
+                qweight,
+                qzeros,
+                scales,
+                g_idx,
+                bias,
+                bits,
+                groupsize,
             )
-
-        linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize, device)
+        else:
+            linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
@@ -193,12 +187,12 @@ class TensorParallelHead(SuperLayer):
         weight = weights.get_sharded(f"{prefix}.weight", dim=0)
 
         # GPTQ doesn't quantize heads (nor embeddings)
-        if config.quantize in ["gptq", "gptq-cuda"]:
+        if config.quantize == "gptq":
            quantize = None
        else:
            quantize = config.quantize
        return TensorParallelHead(
-            get_linear(weight, bias=None, quantize=quantize, device=weights.device),
+            get_linear(weight, bias=None, quantize=quantize),
            process_group=weights.process_group,
        )
 
@@ -254,7 +248,7 @@ class TensorParallelColumnLinear(SuperLayer):
             bias = torch.cat(b, dim=dim)
         else:
             bias = None
-        linear = get_linear(weight, bias, config.quantize, device=weights.device)
+        linear = get_linear(weight, bias, config.quantize)
         return cls(linear)
 
 
@@ -273,7 +267,7 @@ class TensorParallelRowLinear(SuperLayer):
         else:
             bias = None
         return cls(
-            get_linear(weight, bias, config.quantize, device=weights.device),
+            get_linear(weight, bias, config.quantize),
             process_group=weights.process_group,
         )
 
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 9e4bde86..48a22785 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -2,7 +2,7 @@ from pathlib import Path
 from typing import List, Dict, Optional
 from safetensors import safe_open
 import torch
 
-
+from loguru import logger
 class Weights:
     def __init__(
@@ -99,7 +99,7 @@
         return tensor
 
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
-        if quantize in ["gptq", "gptq-cuda"]:
+        if quantize == "gptq":
            try:
                qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
            except RuntimeError:
@@ -114,20 +114,33 @@
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
 
-            weight = [qweight, qzeros, scales, g_idx, bits, groupsize]
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
         return weight
 
-    def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize in ["gptq", "gptq-cuda"]:
+    def get_multi_weights_row(self, prefix: str, quantize: str):
+        if quantize == "gptq":
+            use_triton_kernel = False
+            if self.process_group.size() > 1:
+                g_idx = self.get_tensor(f"{prefix}.g_idx")
+                groupsize = self.get_tensor("gptq_groupsize").item()
+
+                if g_idx is not None:
+                    if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
+                        # Exllama implementation does not support row tensor parallelism with act-order, as
+                        # it would require to reorder input activations that are split unto several GPUs
+                        use_triton_kernel = True
+
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
 
-            if quantize == "gptq":
+            if use_triton_kernel:
+                # The triton kernel reorders the scales/zero points instead of the weight/activation.
+                # Thus, each rank needs the full qzeros/scales.
                 qzeros = self.get_tensor(f"{prefix}.qzeros")
                 scales = self.get_tensor(f"{prefix}.scales")
                 g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
@@ -146,7 +159,7 @@
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
 
-            weight = [qweight, qzeros, scales, g_idx, bits, groupsize]
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
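Note on the `weights.py` hunk above: the choice between the Triton `QuantLinear` kernel and the exllama `Ex4bitLinear` kernel now reduces to a single check on `g_idx` when the weight is row-sharded across ranks. The sketch below restates that check as a standalone helper; `prefers_triton_kernel` is illustrative only and does not exist in this patch.

```python
import torch


def prefers_triton_kernel(g_idx, groupsize: int, world_size: int) -> bool:
    # Illustrative helper (not part of this patch): mirrors the check in
    # Weights.get_multi_weights_row. With row tensor parallelism
    # (world_size > 1), the exllama kernel cannot handle act-order
    # checkpoints, since the input activations split across GPUs would
    # have to be reordered; the Triton kernel is used instead.
    if world_size <= 1 or g_idx is None:
        return False
    # A "trivial" g_idx is either all zeros or the contiguous grouping
    # [0, ..., 0, 1, ..., 1, ...] produced without act-order (desc_act).
    trivial = torch.tensor(
        [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
    )
    is_trivial = bool((g_idx == 0).all()) or torch.equal(g_idx.cpu(), trivial)
    return not is_trivial
```

When this condition holds, `get_multi_weights_row` keeps the full `qzeros`/`scales` on every rank, because the Triton kernel reorders scales and zero points rather than the activations.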