add triton fallback to awq

IlyasMoutawwakil 2024-02-01 13:30:13 +00:00 committed by Nicolas Patry
parent aa2014fc79
commit 3963074ceb
2 changed files with 10 additions and 54 deletions

View File: text_generation_server/utils/awq/pack_utils.py

@@ -15,11 +15,11 @@ def pack(imatrix: torch.Tensor, direction: str = "column"):
     Returns:
         qmatrix (torch.Tensor): packed matrix of integers
     """
-    shifts = torch.arange(0, 32, 4, device=imatrix.device)
-
     imatrix = imatrix.to(torch.int8)
     imatrix = torch.bitwise_and(imatrix, 0x0F)  # eventually correct overflow
 
+    shifts = torch.arange(0, 32, 4, device=imatrix.device)
+
     if direction == "column":
         imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
         qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
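
For reference, a minimal sketch (not part of the commit) of the packing arithmetic used in pack above: each value is masked to its low 4 bits and eight of them are shifted into a single 32-bit word, which is why moving the shifts line below the mask does not change the result. The values are made up for illustration.

import torch

# one group of eight 4-bit values, int8 as in the function above
ints = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.int8)
nibbles = torch.bitwise_and(ints, 0x0F).to(torch.int32)  # keep only the low 4 bits
shifts = torch.arange(0, 32, 4)                          # 0, 4, 8, ..., 28
# value i ends up in nibble i of the packed word ("column" direction)
packed = torch.bitwise_left_shift(nibbles, shifts).sum(dim=-1)
print(hex(packed.item()))  # 0x87654321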
@@ -59,54 +59,6 @@ def unpack(qmatrix: torch.Tensor, direction: str = "column"):
     return imatrix
-
-
-def quantize(fmatrix, scales, zeros, group_size):
-    """
-    Quantizes a matrix of 16-bit floats into a matrix of 4-bit integers.
-
-    Args:
-        fmatrix (torch.Tensor): matrix of 16-bit floats
-        scales (torch.Tensor): matrix of 16-bit floats
-        zeros (torch.Tensor): matrix of 4-bit integers
-        group_size (int): group size
-
-    Returns:
-        imatrix (torch.Tensor): matrix of 4-bit integers
-    """
-    zeros = zeros.to(torch.int8) & 0x0F
-    imatrix = torch.round(
-        (
-            fmatrix / scales.repeat_interleave(group_size, dim=0)
-            + zeros.repeat_interleave(group_size, dim=0)
-        )
-    )
-
-    imatrix = imatrix.to(torch.int8) & 0x0F
-    return imatrix
-
-
-def dequantize(imatrix, scales, zeros, group_size):
-    """
-    Dequantizes a 4-bit integer matrix into a float matrix.
-
-    Args:
-        imatrix (torch.Tensor): matrix of 4-bit integers
-        scales (torch.Tensor): matrix of 16-bit floats
-        zeros (torch.Tensor): matrix of 4-bit integers
-        group_size (int): group size
-
-    Returns:
-        fmatrix (torch.Tensor): matrix of 16-bit floats
-    """
-    zeros = zeros.to(torch.int8) & 0x0F
-    imatrix = imatrix.to(torch.int8) & 0x0F
-
-    fmatrix = (
-        imatrix - zeros.repeat_interleave(group_size, dim=0)
-    ) * scales.repeat_interleave(group_size, dim=0)
-    fmatrix = fmatrix.to(torch.float16)
-    return fmatrix
 
 
 def apply_order(
     imatrix: torch.Tensor,
     direction: str = "column",
@@ -129,7 +81,7 @@ def apply_order(
     return imatrix
 
 
-def fast_awq_to_exllama(qweight, qzeros):
+def fast_awq_to_gptq(qweight, qzeros):
     # awq uses column packing for both weights and zeros
     izeros = unpack(qzeros, direction="column")
     iweights = unpack(qweight, direction="column")
@@ -137,7 +89,7 @@ def fast_awq_to_exllama(qweight, qzeros):
     # Reverse the order of the iweight and izeros tensors
     izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
     iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
-    # Subtract 1 from the izeros tensor (exllama adds 1 during inference)
+    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
     izeros = izeros - 1
     # exllama uses row packing for weights and column packing for zeros
     qzeros = pack(izeros, direction="column")
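
A rough usage sketch of the renamed helper, assuming the import path shown in the second changed file below and the layouts described in the comments above (AWQ packs both weights and zeros along columns, GPTQ-style kernels expect row-packed weights and column-packed zeros); the sizes and random contents are illustrative only.

import torch
from text_generation_server.utils.awq.pack_utils import fast_awq_to_gptq

in_features, out_features, bits, groupsize = 128, 64, 4, 128
pack_factor = 32 // bits  # eight 4-bit values per int32

# AWQ-style packed tensors (random contents, plausible shapes)
qweight = torch.randint(0, 2**31 - 1, (in_features, out_features // pack_factor), dtype=torch.int32)
qzeros = torch.randint(0, 2**31 - 1, (in_features // groupsize, out_features // pack_factor), dtype=torch.int32)

# unpack -> undo the AWQ pack order -> zeros - 1 -> repack in GPTQ layout
gptq_qweight, gptq_qzeros = fast_awq_to_gptq(qweight, qzeros)
print(gptq_qweight.shape, gptq_qzeros.shape)  # weights row-packed: (16, 64); zeros column-packed: (1, 8)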

View File

@@ -25,7 +25,8 @@ HAS_AWQ = True
 try:
     from text_generation_server.utils.awq.quantize.qmodule import WQLinear
 except ImportError:
-    from text_generation_server.utils.awq.pack_utils import fast_awq_to_exllama
+    from text_generation_server.utils.awq.pack_utils import fast_awq_to_gptq
+
     HAS_AWQ = False
 
 try:
@@ -360,10 +361,13 @@ def get_linear(weight, bias, quantize):
                 bias=bias is not None,
             )
         elif HAS_EXLLAMA:
-            qweight, qzeros = fast_awq_to_exllama(qweight, qzeros)
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
             linear = ExllamaQuantLinear(
                 qweight, qzeros, scales, None, bias, bits, groupsize
             )
+        else:
+            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+            linear = QuantLinear(qweight, qzeros, scales, None, bias, bits, groupsize)
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
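
Taken together, the AWQ branch of get_linear now has a three-tier dispatch. Below is a condensed sketch of that order, assuming the names already used in this file (WQLinear, ExllamaQuantLinear, the triton-based QuantLinear, the HAS_AWQ / HAS_EXLLAMA flags and fast_awq_to_gptq); the helper name and the WQLinear keyword list are illustrative, not a drop-in replacement for the code above.

def awq_linear(qweight, qzeros, scales, bias, bits, groupsize):
    # 1) native AWQ kernel when the awq package imports cleanly
    if HAS_AWQ:
        return WQLinear(
            w_bit=bits,
            group_size=groupsize,
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            bias=bias is not None,
        )
    # both fallbacks consume GPTQ-packed tensors, so convert once up front
    qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
    # 2) exllama GPTQ kernels when they are available
    if HAS_EXLLAMA:
        return ExllamaQuantLinear(qweight, qzeros, scales, None, bias, bits, groupsize)
    # 3) the triton QuantLinear fallback added by this commit
    return QuantLinear(qweight, qzeros, scales, None, bias, bits, groupsize)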