diff --git a/server/exllamav2_kernels/setup.py b/server/exllamav2_kernels/setup.py
index a6255125..518db1df 100644
--- a/server/exllamav2_kernels/setup.py
+++ b/server/exllamav2_kernels/setup.py
@@ -7,9 +7,9 @@ setup(
         CUDAExtension(
             name="exllamav2_kernels",
             sources=[
-                "autogptq_extension/exllamav2/ext.cpp",
-                "autogptq_extension/exllamav2/cuda/q_matrix.cu",
-                "autogptq_extension/exllamav2/cuda/q_gemm.cu",
+                "exllamav2_kernels/ext.cpp",
+                "exllamav2_kernels/cuda/q_matrix.cu",
+                "exllamav2_kernels/cuda/q_gemm.cu",
             ],
         )
     ],
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 5b1b5715..bbcfea96 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -45,6 +45,15 @@ __all__ = [
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
 FLASH_ATTENTION = True
+from text_generation_server.models.flash_rw import FlashRWSharded
+from text_generation_server.models.flash_neox import FlashNeoXSharded
+from text_generation_server.models.flash_llama import (
+    FlashLlama,
+)
+from text_generation_server.models.flash_santacoder import (
+    FlashSantacoderSharded,
+)
+from text_generation_server.models.idefics import IDEFICSSharded
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py
index 2e3b8cd0..1e0be490 100644
--- a/server/text_generation_server/utils/gptq/exllamav2.py
+++ b/server/text_generation_server/utils/gptq/exllamav2.py
@@ -87,53 +87,63 @@ class QuantLinear(nn.Module):
 
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
 
-    def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
+    # def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
+    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         super().__init__()
         if bits != 4:
             raise ValueError(
                 f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
-        if trainable:
-            raise NotImplementedError("Exllamav2 kernel does not support training.")
+        # if trainable:
+        #     raise NotImplementedError("Exllamav2 kernel does not support training.")
 
         self.q_handle = None
         self.q_tensors = None
-        self.padding = - outfeatures % 32
-
-        self.infeatures = infeatures
-        self.outfeatures = outfeatures + self.padding
+        # self.padding = - outfeatures % 32
+        #
+        # self.infeatures = infeatures
+        # self.outfeatures = outfeatures + self.padding
         self.bits = bits
-        self.group_size = group_size if group_size != -1 else infeatures
-        self.trainable = trainable
+        # self.group_size = group_size if group_size != -1 else infeatures
+        # self.trainable = trainable
         self.maxq = 2 ** self.bits - 1
+        self.infeatures = qweight.shape[0] // self.bits * 32
+        self.outfeatures = qweight.shape[1]
 
-        assert infeatures % 32 == 0
-        assert infeatures % self.group_size == 0
-        assert outfeatures % 32 == 0
-
-        # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
-        self.register_buffer(
-            'qweight',
-            torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
-        )
-        self.register_buffer(
-            'qzeros',
-            torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
-        )
-        self.register_buffer(
-            'scales',
-            torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
-        )
-        self.register_buffer(
-            'g_idx',
-            torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
-        )
+        # assert infeatures % 32 == 0
+        # assert infeatures % self.group_size == 0
+        # assert outfeatures % 32 == 0
+        #
+        # # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
+        # self.register_buffer(
+        #     'qweight',
+        #     torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
+        # )
+        # self.register_buffer(
+        #     'qzeros',
+        #     torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
+        # )
+        # self.register_buffer(
+        #     'scales',
+        #     torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
+        # )
+        # self.register_buffer(
+        #     'g_idx',
+        #     torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
+        # )
+        self.device = qweight.device
+        self.qweight = qweight
+        self.qzeros = qzeros
+        self.scales = scales
+        self.g_idx = g_idx
+        self.bias = bias if bias is not None else None
 
-        if bias:
-            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
-        else:
-            self.bias = None
+        # if bias:
+        #     self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+        # else:
+        #     self.bias = None
 
-    def post_init(self, temp_dq):
+    # def post_init(self, temp_dq):
+        temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index, self.temp_dq_size())
         assert self.qweight.device.type == "cuda"
         assert self.qweight.device.index is not None
         self.q_tensors = {
@@ -185,4 +195,4 @@ class ExLlamaV2DeviceTensors:
         size_bytes = ((size_bytes + 127) // 128) * 128
         size_half = size_bytes // 2
         scratch_slice = self.scratch.narrow(0, 0, size_half)
-        return scratch_slice
\ No newline at end of file
+        return scratch_slice
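Note on the rewritten QuantLinear constructor: instead of allocating empty buffers from explicit infeatures/outfeatures arguments, it now receives the already-loaded qweight/qzeros/scales/g_idx tensors and derives the logical dimensions from the packed qweight shape. A minimal sketch of that shape arithmetic, with illustrative sizes (infeatures=4096, outfeatures=4096, groupsize=128 are assumptions, not values from the patch):

    import torch

    bits, groupsize = 4, 128
    infeatures, outfeatures = 4096, 4096  # illustrative sizes, not from the patch

    # GPTQ packs 32 // bits quantized values into each int32 along the input
    # dimension, so the packed weight has infeatures // 32 * bits rows -- the
    # same layout as the commented-out register_buffer calls above.
    qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
    qzeros = torch.zeros((infeatures // groupsize, outfeatures // 32 * bits), dtype=torch.int32)
    scales = torch.zeros((infeatures // groupsize, outfeatures), dtype=torch.float16)
    g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32)

    # The new __init__ recovers the dimensions from the packed tensor alone:
    assert qweight.shape[0] // bits * 32 == infeatures
    assert qweight.shape[1] == outfeatures

Constructing the layer itself still requires a CUDA device and the compiled exllamav2_kernels extension, since __init__ now also builds the temp_dq scratch tensors that post_init used to set up.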