mirror of https://github.com/huggingface/text-generation-inference.git

commit 3963074ceb (parent aa2014fc79)

    add triton fallback to awq
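In short, this commit adds a third code path for loading AWQ-quantized checkpoints: when neither the native AWQ kernels nor the exllama kernels are importable, the weights are repacked into GPTQ layout (the helper `fast_awq_to_exllama` is renamed to `fast_awq_to_gptq` accordingly) and served through the Triton-based GPTQ `QuantLinear`. The unused `quantize`/`dequantize` helpers are dropped from the pack utilities along the way.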
@@ -15,11 +15,11 @@ def pack(imatrix: torch.Tensor, direction: str = "column"):
     Returns:
         qmatrix (torch.Tensor): packed matrix of integers
     """
-    shifts = torch.arange(0, 32, 4, device=imatrix.device)
-
     imatrix = imatrix.to(torch.int8)
     imatrix = torch.bitwise_and(imatrix, 0x0F)  # eventually correct overflow
 
+    shifts = torch.arange(0, 32, 4, device=imatrix.device)
+
     if direction == "column":
         imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
         qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
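The hunk above only moves the `shifts` computation below the int8 masking; the packing scheme itself is unchanged. For readers new to it, here is a minimal standalone sketch (mine, not part of the diff) of how eight 4-bit values share one 32-bit word:

import torch

# Eight 4-bit values per 32-bit word: value i lives at bit offset 4 * i.
shifts = torch.arange(0, 32, 4)  # tensor([0, 4, 8, ..., 28])

nibbles = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])  # one row of 8 values in [0, 15]
packed = torch.bitwise_left_shift(nibbles, shifts[None, :]).sum(dim=-1)

# Unpacking shifts each nibble back down and masks off the upper bits,
# mirroring the `& 0x0F` seen in the diff.
unpacked = torch.bitwise_right_shift(packed[:, None], shifts[None, :]) & 0x0F
assert torch.equal(unpacked, nibbles)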
@@ -59,54 +59,6 @@ def unpack(qmatrix: torch.Tensor, direction: str = "column"):
     return imatrix
 
 
-def quantize(fmatrix, scales, zeros, group_size):
-    """
-    Quantizes a matrix of 16-bit floats into a matrix of 4-bit integers.
-    Args:
-        fmatrix (torch.Tensor): matrix of 16-bit floats
-        scales (torch.Tensor): matrix of 16-bit floats
-        zeros (torch.Tensor): matrix of 4-bit integers
-        group_size (int): group size
-    Returns:
-        imatrix (torch.Tensor): matrix of 4-bit integers
-    """
-    zeros = zeros.to(torch.int8) & 0x0F
-
-    imatrix = torch.round(
-        (
-            fmatrix / scales.repeat_interleave(group_size, dim=0)
-            + zeros.repeat_interleave(group_size, dim=0)
-        )
-    )
-
-    imatrix = imatrix.to(torch.int8) & 0x0F
-
-    return imatrix
-
-
-def dequantize(imatrix, scales, zeros, group_size):
-    """
-    Dequantizes a 4-bit integer matrix into a float matrix.
-    Args:
-        imatrix (torch.Tensor): matrix of 4-bit integers
-        scales (torch.Tensor): matrix of 16-bit floats
-        zeros (torch.Tensor): matrix of 4-bit integers
-        group_size (int): group size
-    Returns:
-        fmatrix (torch.Tensor): matrix of 16-bit floats
-    """
-    zeros = zeros.to(torch.int8) & 0x0F
-    imatrix = imatrix.to(torch.int8) & 0x0F
-
-    fmatrix = (
-        imatrix - zeros.repeat_interleave(group_size, dim=0)
-    ) * scales.repeat_interleave(group_size, dim=0)
-
-    fmatrix = fmatrix.to(torch.float16)
-
-    return fmatrix
-
-
 def apply_order(
     imatrix: torch.Tensor,
     direction: str = "column",
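The deleted `quantize`/`dequantize` helpers implement the usual group-wise affine scheme: each group of `group_size` consecutive rows shares one scale and one zero point, expanded to per-row values with `repeat_interleave`. A small round-trip sketch with toy numbers of my own, mirroring the removed code:

import torch

group_size = 2
fmatrix = torch.tensor([[0.5], [1.0], [3.0], [4.0]], dtype=torch.float16)
scales = torch.tensor([[0.5], [1.0]], dtype=torch.float16)  # one row per group
zeros = torch.tensor([[1], [2]], dtype=torch.int8)          # one row per group

# repeat_interleave(group_size, dim=0) expands each group's scale/zero over
# its rows: rows 0-1 use scales[0], rows 2-3 use scales[1].
imatrix = torch.round(
    fmatrix / scales.repeat_interleave(group_size, dim=0)
    + zeros.repeat_interleave(group_size, dim=0)
).to(torch.int8) & 0x0F

fmatrix2 = (
    imatrix - zeros.repeat_interleave(group_size, dim=0)
) * scales.repeat_interleave(group_size, dim=0)

assert torch.equal(fmatrix2, fmatrix)  # exact here: all values representable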
@@ -129,7 +81,7 @@ def apply_order(
     return imatrix
 
 
-def fast_awq_to_exllama(qweight, qzeros):
+def fast_awq_to_gptq(qweight, qzeros):
     # awq uses column packing for both weights and zeros
     izeros = unpack(qzeros, direction="column")
     iweights = unpack(qweight, direction="column")
@@ -137,7 +89,7 @@ def fast_awq_to_exllama(qweight, qzeros):
     # Reverse the order of the iweight and izeros tensors
     izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
     iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
-    # Subtract 1 from the izeros tensor (exllama adds 1 during inference)
+    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
     izeros = izeros - 1
     # exllama uses row packing for weights and column packing for zeros
     qzeros = pack(izeros, direction="column")
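`REVERSE_AWQ_PACK_ORDER` is defined elsewhere in the same pack_utils module and is not shown in this diff. To the best of my knowledge AWQ interleaves the eight nibbles of each 32-bit word as [0, 2, 4, 6, 1, 3, 5, 7]; treat the exact values below as an assumption:

# Assumed values (not shown in this diff): AWQ's interleaved nibble order
# and its inverse, which apply_order() uses to undo the interleaving.
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

# Sanity check: composing the two permutations yields the identity.
assert [AWQ_PACK_ORDER[i] for i in REVERSE_AWQ_PACK_ORDER] == list(range(8))

The `izeros - 1` step exists because GPTQ-style kernels add 1 back to the zero points at inference time, as the updated comment says. The remaining hunks below come from a different file: the module that defines `get_linear` and the `HAS_AWQ`/`HAS_EXLLAMA` flags.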
@@ -25,7 +25,8 @@ HAS_AWQ = True
 try:
     from text_generation_server.utils.awq.quantize.qmodule import WQLinear
 except ImportError:
-    from text_generation_server.utils.awq.pack_utils import fast_awq_to_exllama
+    from text_generation_server.utils.awq.pack_utils import fast_awq_to_gptq
+
     HAS_AWQ = False
 
 try:
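Note the placement: `fast_awq_to_gptq` is imported inside the `except` branch, so the conversion helper is only bound when the native AWQ kernels are missing, which is exactly when it is needed. This is the usual optional-dependency guard; a generic toy version with made-up names:

# Toy version of the optional-dependency guard (module names are stand-ins):
HAS_FAST_KERNELS = True
try:
    import definitely_not_a_real_kernel_lib  # stand-in for the optional dep
except ImportError:
    from math import prod as fallback_helper  # stand-in for the fallback import

    HAS_FAST_KERNELS = False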
@@ -360,10 +361,13 @@ def get_linear(weight, bias, quantize):
                 bias=bias is not None,
             )
         elif HAS_EXLLAMA:
-            qweight, qzeros = fast_awq_to_exllama(qweight, qzeros)
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
             linear = ExllamaQuantLinear(
                 qweight, qzeros, scales, None, bias, bits, groupsize
             )
+        else:
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            linear = QuantLinear(qweight, qzeros, scales, None, bias, bits, groupsize)
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
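After this hunk, the awq branch of `get_linear` has three tiers: native AWQ kernels via `WQLinear` when available, exllama kernels after repacking with `fast_awq_to_gptq`, and otherwise the Triton-based GPTQ `QuantLinear` after the same repacking, which is the new fallback named in the commit title. A toy re-creation of just the selection logic (all names here are stand-ins, not TGI's real API):

# Which backend serves an AWQ checkpoint, as a pure function of the flags:
def pick_awq_backend(has_awq: bool, has_exllama: bool) -> str:
    if has_awq:
        return "WQLinear (native AWQ kernels)"
    if has_exllama:
        return "ExllamaQuantLinear (after fast_awq_to_gptq)"
    return "QuantLinear (Triton GPTQ, after fast_awq_to_gptq)"

assert pick_awq_backend(False, False).startswith("QuantLinear")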