Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 04:14:52 +00:00
Update exllamav2 (illegal address issue)

commit 024bdb0142
parent f96d997494
@@ -7,9 +7,9 @@ setup(
         CUDAExtension(
             name="exllamav2_kernels",
             sources=[
-                "autogptq_extension/exllamav2/ext.cpp",
-                "autogptq_extension/exllamav2/cuda/q_matrix.cu",
-                "autogptq_extension/exllamav2/cuda/q_gemm.cu",
+                "exllamav2_kernels/ext.cpp",
+                "exllamav2_kernels/cuda/q_matrix.cu",
+                "exllamav2_kernels/cuda/q_gemm.cu",
             ],
         )
     ],

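For context, this hunk only renames the kernel source paths. Below is a minimal sketch of the build script it belongs to, assuming a standard torch CUDAExtension build; only the extension name and the renamed sources come from the diff, the remaining setup() arguments are assumed boilerplate.

# Sketch only: extension name and source paths are taken from the diff above;
# package metadata and cmdclass are assumptions.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="exllamav2_kernels",
    ext_modules=[
        CUDAExtension(
            name="exllamav2_kernels",
            sources=[
                "exllamav2_kernels/ext.cpp",
                "exllamav2_kernels/cuda/q_matrix.cu",
                "exllamav2_kernels/cuda/q_gemm.cu",
            ],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)

The old paths pointed into an autogptq_extension/exllamav2/ tree, so the change suggests the ext.cpp, q_matrix.cu and q_gemm.cu sources are now vendored under a local exllamav2_kernels/ directory next to the build script.
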
@@ -45,6 +45,15 @@ __all__ = [
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
 FLASH_ATTENTION = True
+from text_generation_server.models.flash_rw import FlashRWSharded
+from text_generation_server.models.flash_neox import FlashNeoXSharded
+from text_generation_server.models.flash_llama import (
+    FlashLlama,
+)
+from text_generation_server.models.flash_santacoder import (
+    FlashSantacoderSharded,
+)
+from text_generation_server.models.idefics import IDEFICSSharded
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded

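The nine imports added above sit directly in front of an existing try/except guard (the try: and the two imports that follow are context lines), so a failure to import any of these flash models now raises at module import time rather than being swallowed by the guard. For reference, a minimal sketch of that guard pattern, assuming the usual fallback of disabling flash attention when the import fails (the except body is not shown in this hunk and is an assumption):

FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
except ImportError:
    # Assumed fallback: the flash model classes are unavailable,
    # so the server reports FLASH_ATTENTION = False instead of crashing.
    FLASH_ATTENTION = False
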
@@ -87,53 +87,63 @@ class QuantLinear(nn.Module):
 
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
 
-    def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
+    # def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         super().__init__()
         if bits != 4:
             raise ValueError(
                 f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
-        if trainable:
-            raise NotImplementedError("Exllamav2 kernel does not support training.")
+        # if trainable:
+        #     raise NotImplementedError("Exllamav2 kernel does not support training.")
 
         self.q_handle = None
         self.q_tensors = None
-        self.padding = - outfeatures % 32
-
-        self.infeatures = infeatures
-        self.outfeatures = outfeatures + self.padding
+        # self.padding = - outfeatures % 32
+        #
+        # self.infeatures = infeatures
+        # self.outfeatures = outfeatures + self.padding
         self.bits = bits
-        self.group_size = group_size if group_size != -1 else infeatures
-        self.trainable = trainable
+        # self.group_size = group_size if group_size != -1 else infeatures
+        # self.trainable = trainable
         self.maxq = 2 ** self.bits - 1
+        self.infeatures = qweight.shape[0] // self.bits * 32
+        self.outfeatures = qweight.shape[1]
+
 
-        assert infeatures % 32 == 0
-        assert infeatures % self.group_size == 0
-        assert outfeatures % 32 == 0
-
-        # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
-        self.register_buffer(
-            'qweight',
-            torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
-        )
-        self.register_buffer(
-            'qzeros',
-            torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
-        )
-        self.register_buffer(
-            'scales',
-            torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
-        )
-        self.register_buffer(
-            'g_idx',
-            torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
-        )
-
-        if bias:
-            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
-        else:
-            self.bias = None
-
-    def post_init(self, temp_dq):
+        # assert infeatures % 32 == 0
+        # assert infeatures % self.group_size == 0
+        # assert outfeatures % 32 == 0
+        #
+        # # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
+        # self.register_buffer(
+        #     'qweight',
+        #     torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
+        # )
+        # self.register_buffer(
+        #     'qzeros',
+        #     torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
+        # )
+        # self.register_buffer(
+        #     'scales',
+        #     torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
+        # )
+        # self.register_buffer(
+        #     'g_idx',
+        #     torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
+        # )
+        self.device = qweight.device
+        self.qweight = qweight
+        self.qzeros = qzeros
+        self.scales = scales
+        self.g_idx = g_idx
+        self.bias = bias if bias is not None else None
+        # if bias:
+        #     self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+        # else:
+        #     self.bias = None
+
+        # def post_init(self, temp_dq):
+        temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index , self.temp_dq_size())
         assert self.qweight.device.type == "cuda"
         assert self.qweight.device.index is not None
         self.q_tensors = {

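With this change QuantLinear no longer allocates empty buffers from (infeatures, outfeatures); it receives the already-packed GPTQ tensors and recovers the shapes from them (4-bit values are packed eight per int32 along the input dimension, hence infeatures = qweight.shape[0] // self.bits * 32). A rough sketch of how the new constructor might be called with dummy tensors; the shapes mirror the commented-out register_buffer shapes above, while the import path, the groupsize of 128 and the 4096x4096 layer size are assumptions, and running it needs a CUDA device with the exllamav2_kernels extension built.

import math
import torch

# Assumed module path for the class patched above.
from text_generation_server.utils.gptq.exllamav2 import QuantLinear

bits, groupsize = 4, 128
infeatures, outfeatures = 4096, 4096

# Packed GPTQ tensors; shapes follow the commented-out register_buffer calls:
# 4-bit weights are packed 8 per int32 along the input dimension.
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32, device="cuda")
qzeros = torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 32 * bits), dtype=torch.int32, device="cuda")
scales = torch.zeros((math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16, device="cuda")
g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32, device="cuda")

layer = QuantLinear(qweight, qzeros, scales, g_idx, None, bits, groupsize)

Since the def post_init(self, temp_dq): line is commented out and the ExLlamaV2DeviceTensors scratch allocation now runs inside __init__, the rest of the former post_init body (the q_tensors dict and, presumably, the q_handle construction beyond the truncated end of this hunk) appears to execute eagerly at construction time.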