Exllamav2 functional.

This commit is contained in:
Nicolas Patry 2023-11-23 11:34:22 +00:00
parent a61f432599
commit ff51589332

View File

@ -96,35 +96,6 @@ def create_exllama_buffers():
layer.post_init(temp_dq)
# assert DEVICE is not None, "call set_device first"
# if ACT_ORDER:
# # TODO: this should be set to rust side `max_total_tokens`, but TGI
# # does not offer an API to expose this variable to python, as this variable
# # is handled by the client but it appears the model is initialized by the server.
# # An alternative could be to initialize the buffers during warmup.
# # Dummy
# max_total_tokens = 2048
# else:
# max_total_tokens = 1
# # This temp_state buffer is required to reorder X in the act-order case.
# temp_state = torch.zeros(
# (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
# )
# temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
# # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
# prepare_buffers(DEVICE, temp_state, temp_dq)
# matmul_recons_thd = 8
# matmul_fused_remap = False
# matmul_no_half2 = False
# set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
# TEMP_STATE, TEMP_DQ = temp_state, temp_dq
class QuantLinear(nn.Module):
QUANT_TYPE = "exllamav2"
@ -136,43 +107,12 @@ class QuantLinear(nn.Module):
if bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
# if trainable:
# raise NotImplementedError("Exllamav2 kernel does not support training.")
self.q_handle = None
self.q_tensors = None
#
# self.infeatures = infeatures
# self.outfeatures = outfeatures + self.padding
self.bits = bits
# self.group_size = group_size if group_size != -1 else infeatures
# self.trainable = trainable
self.maxq = 2 ** self.bits - 1
self.infeatures = qweight.shape[0] // self.bits * 32
self.outfeatures = qweight.shape[1]
# assert infeatures % 32 == 0
# assert infeatures % self.group_size == 0
# assert outfeatures % 32 == 0
# self.padding = - outfeatures % 32
# # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
# self.register_buffer(
# 'qweight',
# torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
# )
# self.register_buffer(
# 'qzeros',
# torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
# )
# self.register_buffer(
# 'scales',
# torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
# )
# self.register_buffer(
# 'g_idx',
# torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
# )
self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
@ -184,11 +124,6 @@ class QuantLinear(nn.Module):
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
LAYERS.append(self)
# if bias:
# self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
# else:
# self.bias = None
def post_init(self, temp_dq):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
@ -216,7 +151,7 @@ class QuantLinear(nn.Module):
def temp_fwd_size(self, max_input_len, max_batch_size):
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16):
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)