Mirror of https://github.com/huggingface/text-generation-inference.git
Added AWQ support for FlashLlama models
parent 4cce84301b
commit 00dede8a63
@@ -25,6 +25,7 @@ enum Quantization {
     BitsandbytesNF4,
     BitsandbytesFP4,
     Gptq,
+    Awq,
 }
 
 impl std::fmt::Display for Quantization {
@@ -43,6 +44,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Gptq => {
                 write!(f, "gptq")
             }
+            Quantization::Awq => {
+                write!(f, "awq")
+            }
         }
     }
 }
@@ -73,3 +73,5 @@ win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 xxhash==3.3.0 ; python_version >= "3.9" and python_version < "3.13"
 yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
+# Custom 4-bit GEMM AWQ kernels
+git+https://github.com/mit-han-lab/llm-awq.git@f084f40bd996f3cf3a0633c1ad7d9d476c318aaa#subdirectory=awq/kernels
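Note on the new dependency: the AWQ GEMM kernels are pulled from a pinned llm-awq commit and compiled at install time. A quick way to confirm the build succeeded is to probe for the compiled extension; the module name `awq_inference_engine` is an assumption about what the kernels package installs, not something stated in this diff.

import importlib.util

# Hypothetical sanity check for the custom AWQ kernels pulled in above.
# Assumption: the mit-han-lab kernels package installs a compiled extension
# named `awq_inference_engine`; adjust if it exposes a different name.
spec = importlib.util.find_spec("awq_inference_engine")
print("AWQ 4-bit GEMM kernels available:", spec is not None)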
@@ -17,6 +17,7 @@ class Quantization(str, Enum):
     bitsandbytes_nf4 = "bitsandbytes-nf4"
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
+    awq = "awq"
 
 
 class Dtype(str, Enum):
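For context, the server-side checks later in this commit compare `quantize` against plain strings ("awq", "gptq", ...). That works because the CLI enum above mixes in `str`, so the parsed option flows through as the literal string. A small illustration (only the members visible in the hunk are reproduced; the real enum has more):

from enum import Enum

# Partial mirror of the CLI Quantization enum above, for illustration only.
class Quantization(str, Enum):
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    awq = "awq"

# The str mixin means the parsed option compares equal to the raw string,
# so downstream `quantize == "awq"` style checks keep working.
assert Quantization("awq") == "awq"
assert Quantization.awq.value == "awq"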
@@ -268,6 +268,10 @@ def get_model(
         raise ValueError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
+    if quantize == "awq":
+        raise ValueError(
+            "awq quantization is not supported for AutoModel"
+        )
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
         raise ValueError(
             "4bit quantization is not supported for AutoModel"
@@ -62,7 +62,7 @@ class FlashLlama(FlashCausalLM):
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
-        if config.quantize == "gptq":
+        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id)
 
         model = FlashLlamaForCausalLM(config, weights)
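The FlashLlama change reuses `weights._set_gptq_params(model_id)` for AWQ checkpoints, i.e. the bit width and group size are read from the same quantization metadata that GPTQ repos ship. A hedged sketch of what that metadata might look like (the file name and field names follow the GPTQ convention and are assumptions here, not shown in this diff):

import json

# Hypothetical quantize_config.json as shipped with a quantized checkpoint;
# _set_gptq_params is assumed to read fields like these for both GPTQ and AWQ.
example = '{"bits": 4, "group_size": 128}'
cfg = json.loads(example)

bits, groupsize = cfg["bits"], cfg["group_size"]
print(f"loading a {bits}-bit checkpoint with group size {groupsize}")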
@@ -17,6 +17,7 @@ except ImportError:
 from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
+from text_generation_server.utils.awq.quantize.qmodule import WQLinear
 
 try:
     major, _minor = torch.cuda.get_device_capability()
@@ -248,6 +249,19 @@ def get_linear(weight, bias, quantize):
             bits,
             groupsize,
         )
+    elif quantize == "awq":
+        try:
+            qweight, qzeros, scales, bits, groupsize = weight
+        except Exception:
+            raise NotImplementedError(
+                f"The passed weight is not `awq` compatible, loader needs to be updated."
+            )
+        in_features = qweight.shape[0]
+        out_features = qweight.shape[1] * 32 // bits
+        linear = WQLinear(w_bit=bits, group_size=groupsize, in_features=in_features, out_features=out_features, bias=bias is not None, dev=qweight.device)
+        linear.qweight = qweight
+        linear.qzeros = qzeros
+        linear.scales = scales
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
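The geometry above (`in_features = qweight.shape[0]`, `out_features = qweight.shape[1] * 32 // bits`) reflects 4-bit packing: each int32 column of `qweight` holds `32 // bits` quantized values along the output dimension. A toy round trip of that packing, with the caveat that the real AWQ kernels use their own interleaved ordering:

bits = 4
pack_factor = 32 // bits                 # 8 four-bit values per int32

# Pack eight 4-bit values into a single 32-bit word (ordering is illustrative).
nibbles = [3, 7, 1, 15, 0, 9, 12, 5]
word = 0
for i, v in enumerate(nibbles):
    word |= (v & 0xF) << (bits * i)

# Recover them to show the round trip.
recovered = [(word >> (bits * i)) & 0xF for i in range(pack_factor)]
assert recovered == nibbles

# The same bookkeeping get_linear uses to recover the layer width:
packed_cols = 512                        # hypothetical qweight.shape[1]
out_features = packed_cols * 32 // bits  # 4096 logical output features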
@@ -283,8 +297,8 @@ class TensorParallelHead(SuperLayer):
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
 
-        # GPTQ doesn't quantize heads (nor embeddings)
-        if config.quantize == "gptq":
+        # GPTQ and AWQ don't quantize heads (nor embeddings)
+        if config.quantize in ["gptq", "awq"]:
             quantize = None
         else:
             quantize = config.quantize
@@ -150,6 +150,19 @@ class Weights:
 
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        if quantize == "awq":
+            try:
+                qweight = self._get_qweight(f"{prefix}.qweight")
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+            qzeros = self._get_qweight(f"{prefix}.qzeros")
+            scales = self._get_qweight(f"{prefix}.scales")
+            scales = scales.to(dtype=self.dtype)
+
+            bits, groupsize = self._get_gptq_params()
+            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
@@ -194,6 +207,25 @@ class Weights:
 
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        elif quantize == "awq":
+            try:
+                qweight = torch.cat(
+                    [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
+                )
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+
+            qzeros = torch.cat(
+                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+            )
+            scales = torch.cat(
+                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
+            )
+
+            bits, groupsize = self._get_gptq_params()
+            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
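Note the axis difference in this hunk: unquantized shards are concatenated on dim 0, while the AWQ `qweight`/`qzeros`/`scales` shards are concatenated on dim 1. That is consistent with the packed tensors being stored as [in_features, packed_out_features], i.e. transposed relative to the usual [out_features, in_features] nn.Linear layout; treat the exact layout as a reading of the code rather than something this diff states. A toy illustration:

import torch

# Two shards of a column-parallel projection (toy sizes).
# Unquantized weights are [out_features, in_features] -> concatenate on dim 0.
w_shards = [torch.randn(64, 32), torch.randn(64, 32)]
w = torch.cat(w_shards, dim=0)        # [128, 32]

# Packed AWQ qweights are assumed to be [in_features, packed_out_features]
# -> the output dimension is dim 1, so shards are concatenated on dim 1.
q_shards = [
    torch.zeros(32, 16, dtype=torch.int32),
    torch.zeros(32, 16, dtype=torch.int32),
]
qweight = torch.cat(q_shards, dim=1)  # [32, 32]

print(w.shape, qweight.shape)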
@@ -282,6 +314,20 @@ class Weights:
             g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
+        elif quantize == "awq":
+            bits, groupsize = self._get_gptq_params()
+
+            try:
+                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `awq` weight, make sure the model is already quantized"
+                )
+
+            qzeros = self.get_tensor(f"{prefix}.qzeros")
+            scales = self.get_tensor(f"{prefix}.scales")
+
+            weight = (qweight, qzeros, scales, bits, groupsize)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
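Tying the Weights changes back to get_linear: each AWQ branch above returns the same 5-tuple `(qweight, qzeros, scales, bits, groupsize)`, which is exactly what the new `awq` branch of get_linear unpacks before building a `WQLinear`. A hedged end-to-end sketch with made-up shapes (4 bits, group size 128, and the zeros/scales layout are assumptions for the example):

import torch

bits, groupsize = 4, 128
in_features, out_features = 4096, 4096
pack_factor = 32 // bits

# Toy stand-ins for what a Weights AWQ branch would return:
qweight = torch.zeros(in_features, out_features // pack_factor, dtype=torch.int32)
qzeros = torch.zeros(in_features // groupsize, out_features // pack_factor, dtype=torch.int32)
scales = torch.zeros(in_features // groupsize, out_features, dtype=torch.float16)
weight = (qweight, qzeros, scales, bits, groupsize)

# ...and the layer geometry get_linear recovers from that tuple:
qweight, qzeros, scales, bits, groupsize = weight
assert qweight.shape[0] == in_features
assert qweight.shape[1] * 32 // bits == out_features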