Mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Added AWQ support for FlashLlama models
This commit is contained in:
parent 4cce84301b
commit 00dede8a63
@@ -25,6 +25,7 @@ enum Quantization {
     BitsandbytesNF4,
     BitsandbytesFP4,
     Gptq,
+    Awq,
 }
 
 impl std::fmt::Display for Quantization {
@@ -43,6 +44,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Gptq => {
                 write!(f, "gptq")
             }
+            Quantization::Awq => {
+                write!(f, "awq")
+            }
         }
     }
 }
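With the Awq variant added to the launcher's quantization enum and to its Display impl, the option can presumably be selected as `--quantize awq`; the lowercase string produced by `write!(f, "awq")` is what gets forwarded to the Python server and compared against in the loaders below.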
@@ -73,3 +73,5 @@ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
+# Custom 4-bit GEMM AWQ kernels
+git+https://github.com/mit-han-lab/llm-awq.git@f084f40bd996f3cf3a0633c1ad7d9d476c318aaa#subdirectory=awq/kernels
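The pinned llm-awq kernels provide the fused 4-bit GEMM used by the WQLinear layer further down. A minimal sanity-check sketch, assuming the built wheel exposes a module named `awq_inference_engine` (the module name is an assumption, not stated in this diff):

    # Check that the compiled AWQ GEMM extension from the pinned llm-awq commit
    # is importable; the module name here is an assumption.
    try:
        import awq_inference_engine  # hypothetical name of the kernel package
        HAS_AWQ_KERNELS = True
    except ImportError:
        HAS_AWQ_KERNELS = False

    print(f"AWQ GEMM kernels available: {HAS_AWQ_KERNELS}")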
@@ -17,6 +17,7 @@ class Quantization(str, Enum):
     bitsandbytes_nf4 = "bitsandbytes-nf4"
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
+    awq = "awq"
 
 
 class Dtype(str, Enum):
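Because the CLI enum subclasses `str`, the selected member compares equal to its plain string value, which is why the server code below can keep testing `quantize == "awq"` directly. A small illustration:

    # str-backed Enum members compare equal to their raw values, so downstream
    # checks such as `quantize == "awq"` work with either an Enum or a string.
    from enum import Enum

    class Quantization(str, Enum):
        gptq = "gptq"
        awq = "awq"

    q = Quantization.awq
    assert q == "awq"
    assert q.value == "awq"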
@@ -268,6 +268,10 @@ def get_model(
         raise ValueError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
+    if quantize == "awq":
+        raise ValueError(
+            "awq quantization is not supported for AutoModel"
+        )
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
         raise ValueError(
             "4bit quantization is not supported for AutoModel"
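The new guard mirrors the existing gptq one: when a model has no dedicated implementation and would fall back to the generic AutoModel path, requesting awq raises immediately rather than presumably loading the packed checkpoint tensors as if they were ordinary weights.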
@@ -62,7 +62,7 @@ class FlashLlama(FlashCausalLM):
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id)
 
         model = FlashLlamaForCausalLM(config, weights)
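AWQ checkpoints reuse the same metadata lookup as GPTQ ones: `_set_gptq_params` reads the bit width and group size shipped with the quantized model. A rough sketch of that lookup, assuming a `quantize_config.json` with `bits`/`group_size` keys (both the file name and keys are assumptions here):

    # Rough sketch (file name and keys are assumptions) of the quantization
    # metadata that Weights._set_gptq_params picks up for gptq and awq models.
    import json
    from pathlib import Path

    def read_quant_params(model_dir: str):
        cfg = json.loads(Path(model_dir, "quantize_config.json").read_text())
        return cfg["bits"], cfg["group_size"]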
@@ -17,6 +17,7 @@ except ImportError:
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
+from text_generation_server.utils.awq.quantize.qmodule import WQLinear
 
 try:
     major, _minor = torch.cuda.get_device_capability()
@@ -248,6 +249,19 @@ def get_linear(weight, bias, quantize):
             bits,
             groupsize,
         )
+    elif quantize == "awq":
+        try:
+            qweight, qzeros, scales, bits, groupsize = weight
+        except Exception:
+            raise NotImplementedError(
+                f"The passed weight is not `awq` compatible, loader needs to be updated."
+            )
+        in_features = qweight.shape[0]
+        out_features = qweight.shape[1] * 32 // bits
+        linear = WQLinear(w_bit=bits, group_size=groupsize, in_features=in_features, out_features=out_features, bias=bias is not None, dev=qweight.device)
+        linear.qweight = qweight
+        linear.qzeros = qzeros
+        linear.scales = scales
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
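The shape computation relies on AWQ's int32 packing: each 32-bit column of `qweight` holds `32 // bits` quantized values, so `out_features` is recovered as `qweight.shape[1] * 32 // bits` while `in_features` is simply the row count. A standalone check of that arithmetic (the sizes are illustrative):

    # AWQ packs `32 // bits` values into each int32 column, so the packed
    # qweight has shape (in_features, out_features * bits // 32).
    import torch

    bits = 4
    in_features, out_features = 4096, 11008
    qweight = torch.empty(in_features, out_features * bits // 32, dtype=torch.int32)

    assert qweight.shape[0] == in_features
    assert qweight.shape[1] * 32 // bits == out_features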
@@ -283,8 +297,8 @@ class TensorParallelHead(SuperLayer):
         weight = weights.get_tensor(f"{prefix}.weight")
         should_gather = False
 
-        # GPTQ doesn't quantize heads (nor embeddings)
-        if config.quantize == "gptq":
+        # GPTQ and AWQ don't quantize heads (nor embeddings)
+        if config.quantize in ["gptq", "awq"]:
             quantize = None
         else:
             quantize = config.quantize
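As the updated comment states, AWQ checkpoints, like GPTQ ones, leave the embedding and lm_head tensors in full precision, so the head loader drops the quantize flag here and lets `get_linear` build a plain linear layer for those weights.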
@@ -150,6 +150,19 @@ class Weights:
 
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        if quantize == "awq":
+            try:
+                qweight = self._get_qweight(f"{prefix}.qweight")
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+            qzeros = self._get_qweight(f"{prefix}.qzeros")
+            scales = self._get_qweight(f"{prefix}.scales")
+            scales = scales.to(dtype=self.dtype)
+
+            bits, groupsize = self._get_gptq_params()
+            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
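For reference, the tuple returned here encodes a grouped 4-bit scheme where, roughly, weight ≈ (q − zero) × scale per group of `groupsize` input rows; the scales are cast to `self.dtype` presumably because the kernel multiplies them against half-precision activations. A toy dequantization with unpacked integers (illustrative only, not the real packed layout):

    # Toy grouped dequantization, illustrating what (qweight, qzeros, scales,
    # bits, groupsize) jointly represent; the real tensors are int32-packed.
    import torch

    bits, groupsize = 4, 128
    in_features, out_features = 256, 8
    q = torch.randint(0, 2**bits, (in_features, out_features))
    zeros = torch.randint(0, 2**bits, (in_features // groupsize, out_features))
    scales = torch.rand(in_features // groupsize, out_features)

    groups = torch.arange(in_features) // groupsize
    w = (q - zeros[groups]).to(scales.dtype) * scales[groups]
    print(w.shape)  # torch.Size([256, 8])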
@@ -194,6 +207,25 @@ class Weights:
 
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        elif quantize == "awq":
+            try:
+                qweight = torch.cat(
+                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+                )
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+
+            qzeros = torch.cat(
+                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+            )
+            scales = torch.cat(
+                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+            )
+
+            bits, groupsize = self._get_gptq_params()
+            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
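Concatenation happens on dim=1 because, in the packed awq layout, the output dimension sits on axis 1 of `qweight`, `qzeros` and `scales`, so column-parallel shards gathered from several prefixes stack along that axis. A small shape check under that assumption:

    # Column-parallel shards carry slices of the output dimension, which sits
    # on axis 1 of the packed tensors, hence torch.cat(..., dim=1).
    import torch

    bits = 4
    in_features, out_per_shard, shards = 512, 256, 2
    parts = [
        torch.zeros(in_features, out_per_shard * bits // 32, dtype=torch.int32)
        for _ in range(shards)
    ]
    qweight = torch.cat(parts, dim=1)
    assert qweight.shape == (in_features, shards * out_per_shard * bits // 32)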
@@ -282,6 +314,20 @@ class Weights:
             g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+        elif quantize == "awq":
+            bits, groupsize = self._get_gptq_params()
+
+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+
+            qzeros = self.get_tensor(f"{prefix}.qzeros")
+            scales = self.get_tensor(f"{prefix}.scales")
+
+            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
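In the row-parallel case only `qweight` is sharded (along the input dimension, dim 0), while `qzeros` and `scales` are loaded in full on every rank; each rank then computes a partial matmul that the row-parallel layer sums across ranks. The reduction idea, illustrated without any distributed setup:

    # Row-parallel intuition: shard the shared (input) dimension, compute
    # partial products, then sum them -- the sum equals the unsharded matmul.
    import torch

    x = torch.randn(1, 512)
    w = torch.randn(512, 256)
    xs, ws = x.chunk(2, dim=1), w.chunk(2, dim=0)
    partials = [xi @ wi for xi, wi in zip(xs, ws)]
    assert torch.allclose(sum(partials), x @ w, atol=1e-4)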