Added AWQ support for FlashLlama models

Abhinav Kulkarni 2023-09-13 11:08:22 +00:00
parent 4cce84301b
commit 00dede8a63
7 changed files with 74 additions and 3 deletions

View File

@@ -25,6 +25,7 @@ enum Quantization {
     BitsandbytesNF4,
     BitsandbytesFP4,
     Gptq,
+    Awq,
 }
 
 impl std::fmt::Display for Quantization {
@@ -43,6 +44,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Gptq => {
                 write!(f, "gptq")
             }
+            Quantization::Awq => {
+                write!(f, "awq")
+            }
         }
     }
 }

View File

@@ -73,3 +73,5 @@ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
+# Custom 4-bit GEMM AWQ kernels
+git+https://github.com/mit-han-lab/llm-awq.git@f084f40bd996f3cf3a0633c1ad7d9d476c318aaa#subdirectory=awq/kernels
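As a side note (not part of the commit), the pinned package above builds the custom CUDA kernels that the new WQLinear layer calls into at runtime. A minimal smoke test could look like the sketch below; the extension module name `awq_inference_engine` is an assumption about what the llm-awq kernel package installs and is not stated in this diff.

    # Hypothetical smoke test; the extension name is an assumption, not taken from this diff.
    try:
        import awq_inference_engine  # CUDA 4-bit GEMM kernels backing WQLinear.forward
    except ImportError as err:
        raise RuntimeError(
            "AWQ kernels are not installed; install the pinned llm-awq package above"
        ) from err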

View File

@@ -17,6 +17,7 @@ class Quantization(str, Enum):
     bitsandbytes_nf4 = "bitsandbytes-nf4"
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
+    awq = "awq"
 
 
 class Dtype(str, Enum):
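For context, a minimal sketch (not the commit's code) of how the launcher's new `Awq` variant reaches this enum: the Rust `Display` impl earlier in the commit serializes it as "awq", and the Python CLI parses that string back by value. Only the members visible in the hunk above are reproduced here.

    from enum import Enum

    # Mirror of the server-side Quantization enum shown in the hunk above.
    class Quantization(str, Enum):
        bitsandbytes_nf4 = "bitsandbytes-nf4"
        bitsandbytes_fp4 = "bitsandbytes-fp4"
        gptq = "gptq"
        awq = "awq"

    # The string written by the Rust Display impl must round-trip into the Python enum.
    assert Quantization("awq") is Quantization.awq
    assert Quantization.awq.value == "awq"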

View File

@@ -268,6 +268,10 @@ def get_model(
         raise ValueError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
+    if quantize == "awq":
+        raise ValueError(
+            "awq quantization is not supported for AutoModel"
+        )
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
         raise ValueError(
             "4bit quantization is not supported for AutoModel"

View File

@@ -62,7 +62,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id)
 
         model = FlashLlamaForCausalLM(config, weights)
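Because `_set_gptq_params` is now called for AWQ checkpoints as well, the loader assumes the quantized repo ships the same kind of metadata GPTQ models do: a bit width and a group size. A hedged sketch of that assumption, not the commit's code; the `quantize_config.json` file name and its keys are hypothetical.

    import json

    # Hypothetical illustration: read the bit width and group size that the
    # Weights helpers later return from _get_gptq_params(). The file name and
    # keys are assumptions, not taken from this diff.
    def read_quant_metadata(path: str = "quantize_config.json") -> tuple[int, int]:
        with open(path) as f:
            cfg = json.load(f)
        # AWQ checkpoints are typically 4-bit with group size 128.
        return cfg["bits"], cfg["group_size"]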

View File

@@ -17,6 +17,7 @@ except ImportError:
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
+from text_generation_server.utils.awq.quantize.qmodule import WQLinear
 
 try:
     major, _minor = torch.cuda.get_device_capability()
@@ -248,6 +249,19 @@ def get_linear(weight, bias, quantize):
             bits,
             groupsize,
         )
+    elif quantize == "awq":
+        try:
+            qweight, qzeros, scales, bits, groupsize = weight
+        except Exception:
+            raise NotImplementedError(
+                f"The passed weight is not `awq` compatible, loader needs to be updated."
+            )
+        in_features = qweight.shape[0]
+        out_features = qweight.shape[1] * 32 // bits
+        linear = WQLinear(w_bit=bits, group_size=groupsize, in_features=in_features, out_features=out_features, bias=bias is not None, dev=qweight.device)
+        linear.qweight = qweight
+        linear.qzeros = qzeros
+        linear.scales = scales
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
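The shape arithmetic in the new `awq` branch relies on AWQ's GEMM packing, where 32 // bits quantized values share one int32 along the output-feature axis. A standalone sketch with hypothetical Llama-7B sizes (not the commit's code; the qzeros/scales shapes are the assumed AWQ layout, only the qweight arithmetic appears in the diff):

    import torch

    bits, groupsize = 4, 128
    in_features, out_features = 4096, 11008
    pack_factor = 32 // bits  # 8 four-bit weights per int32

    # Assumed AWQ GEMM layout: tensors packed along the output dimension.
    qweight = torch.zeros((in_features, out_features // pack_factor), dtype=torch.int32)
    qzeros = torch.zeros((in_features // groupsize, out_features // pack_factor), dtype=torch.int32)
    scales = torch.zeros((in_features // groupsize, out_features), dtype=torch.float16)

    # Same recovery of the logical dimensions as in get_linear's awq branch above.
    assert qweight.shape[0] == in_features
    assert qweight.shape[1] * 32 // bits == out_features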
@@ -283,8 +297,8 @@ class TensorParallelHead(SuperLayer):
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
 
-        # GPTQ doesn't quantize heads (nor embeddings)
-        if config.quantize == "gptq":
+        # GPTQ and AWQ don't quantize heads (nor embeddings)
+        if config.quantize in ["gptq", "awq"]:
             quantize = None
         else:
             quantize = config.quantize

View File

@@ -150,6 +150,19 @@ class Weights:
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        elif quantize == "awq":
+            try:
+                qweight = self._get_qweight(f"{prefix}.qweight")
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+            qzeros = self._get_qweight(f"{prefix}.qzeros")
+            scales = self._get_qweight(f"{prefix}.scales")
+            scales = scales.to(dtype=self.dtype)
+            bits, groupsize = self._get_gptq_params()
+            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
@@ -194,6 +207,25 @@ class Weights:
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        elif quantize == "awq":
+            try:
+                qweight = torch.cat(
+                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+                )
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+            qzeros = torch.cat(
+                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+            )
+            scales = torch.cat(
+                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+            )
+            bits, groupsize = self._get_gptq_params()
+            weight = (qweight, qzeros, scales, bits, groupsize)
     else:
         w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
         weight = torch.cat(w, dim=dim)
@@ -282,6 +314,20 @@ class Weights:
             g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+        elif quantize == "awq":
+            bits, groupsize = self._get_gptq_params()
+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+            qzeros = self.get_tensor(f"{prefix}.qzeros")
+            scales = self.get_tensor(f"{prefix}.scales")
+            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
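Taken together, the `Weights` changes give AWQ a simpler contract than GPTQ: a 5-tuple `(qweight, qzeros, scales, bits, groupsize)` with no `g_idx` and no exllama flag, which is exactly what the new branch of `get_linear` unpacks. Sharding follows the packing axis: the column-parallel loaders concatenate `qweight`, `qzeros`, and `scales` on dim=1, while the row-parallel path shards only `qweight` on dim=0 and loads `qzeros` and `scales` whole on every rank. A small sketch of the column-parallel case with hypothetical sizes (not the commit's code):

    import torch

    bits, world_size = 4, 2
    pack_factor = 32 // bits
    in_features, out_features = 4096, 11008

    # Each rank holds a slice of the packed output dimension, so the shards
    # concatenate on dim=1, mirroring get_multi_weights_col above.
    shards = [
        torch.zeros((in_features, out_features // pack_factor // world_size), dtype=torch.int32)
        for _ in range(world_size)
    ]
    qweight = torch.cat(shards, dim=1)
    assert qweight.shape == (in_features, out_features // pack_factor)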