mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Exllamav2 functional.
parent a61f432599
commit ff51589332
@@ -96,35 +96,6 @@ def create_exllama_buffers():
         layer.post_init(temp_dq)
 
-    # assert DEVICE is not None, "call set_device first"
-
-    # if ACT_ORDER:
-    #     # TODO: this should be set to rust side `max_total_tokens`, but TGI
-    #     # does not offer an API to expose this variable to python, as this variable
-    #     # is handled by the client but it appears the model is initialized by the server.
-    #     # An alternative could be to initialize the buffers during warmup.
-    #     # Dummy
-    #     max_total_tokens = 2048
-    # else:
-    #     max_total_tokens = 1
-
-    # # This temp_state buffer is required to reorder X in the act-order case.
-    # temp_state = torch.zeros(
-    #     (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
-    # )
-    # temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
-
-    # # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-    # prepare_buffers(DEVICE, temp_state, temp_dq)
-
-    # matmul_recons_thd = 8
-    # matmul_fused_remap = False
-    # matmul_no_half2 = False
-    # set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
-
-    # TEMP_STATE, TEMP_DQ = temp_state, temp_dq
-
 
 class QuantLinear(nn.Module):
     QUANT_TYPE = "exllamav2"
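The block removed above is the exllama (v1) style global-buffer setup: temp_state reorders activations in the act-order case and temp_dq holds dequantized weights for the cuBLAS path, typically during prefill. In the v2 path each QuantLinear instead reports its own scratch requirement and is finalized through post_init, as the hunks below show. A minimal sketch of how the pieces visible in this diff appear to fit together; representing the shared scratch area as a plain fp16 tensor (and its sizing) is an assumption, not something this commit shows:

import torch

FIXED_BYTES = 0  # running maximum of per-layer scratch requirements (updated in __init__ below)
LAYERS = []      # every QuantLinear registers itself here at construction time


def create_exllama_buffers():
    # Allocate one scratch region big enough for the most demanding layer,
    # then let every registered layer build its exllamav2 handle from it.
    device = LAYERS[0].qweight.device
    temp_dq = torch.zeros((FIXED_BYTES // 2,), dtype=torch.float16, device=device)  # assumed representation
    for layer in LAYERS:
        layer.post_init(temp_dq)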
@@ -136,43 +107,12 @@ class QuantLinear(nn.Module):
         if bits != 4:
             raise ValueError(
                 f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
-        # if trainable:
-        #     raise NotImplementedError("Exllamav2 kernel does not support training.")
-
         self.q_handle = None
         self.q_tensors = None
-        #
-        # self.infeatures = infeatures
-        # self.outfeatures = outfeatures + self.padding
         self.bits = bits
-        # self.group_size = group_size if group_size != -1 else infeatures
-        # self.trainable = trainable
         self.maxq = 2 ** self.bits - 1
         self.infeatures = qweight.shape[0] // self.bits * 32
         self.outfeatures = qweight.shape[1]
 
-        # assert infeatures % 32 == 0
-        # assert infeatures % self.group_size == 0
-        # assert outfeatures % 32 == 0
-        # self.padding = - outfeatures % 32
-
-        # # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
-        # self.register_buffer(
-        #     'qweight',
-        #     torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
-        # )
-        # self.register_buffer(
-        #     'qzeros',
-        #     torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
-        # )
-        # self.register_buffer(
-        #     'scales',
-        #     torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
-        # )
-        # self.register_buffer(
-        #     'g_idx',
-        #     torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
-        # )
         self.device = qweight.device
         self.qweight = qweight
         self.qzeros = qzeros
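Deriving infeatures from the packed weight shape works because GPTQ packs 32 // bits quantized values into each int32 element along the input dimension, so qweight has infeatures // 32 * bits rows. A quick sanity check of that arithmetic, with a hypothetical layer size (not taken from this commit):

# bits=4: eight 4-bit weights per int32, so qweight has infeatures // 8 rows
bits = 4
infeatures = 4096                       # hypothetical unpacked input dimension
packed_rows = infeatures // 32 * bits   # -> 512 int32 rows in qweight
assert packed_rows // bits * 32 == infeatures  # the constructor's inverse mapping above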
@@ -184,11 +124,6 @@ class QuantLinear(nn.Module):
         FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
         LAYERS.append(self)
 
-        # if bias:
-        #     self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
-        # else:
-        #     self.bias = None
-
     def post_init(self, temp_dq):
         assert self.qweight.device.type == "cuda"
         assert self.qweight.device.index is not None
@@ -216,7 +151,7 @@ class QuantLinear(nn.Module):
     def temp_fwd_size(self, max_input_len, max_batch_size):
         return self.outfeatures * max_input_len * max_batch_size * 4 + 128
 
-    def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
+    def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16):
         return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
 
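Raising the scratch_space_fixed defaults to max_input_len=4096 and max_batch_size=16 grows the per-layer forward scratch reservation accordingly. A back-of-the-envelope check of temp_fwd_size under the new defaults, using a hypothetical outfeatures value (not from this commit); scratch_space_fixed adds temp_dq_size() on top of this, and FIXED_BYTES keeps the maximum over all registered layers:

# temp_fwd_size for a hypothetical layer with outfeatures=4096 under the new defaults
outfeatures, max_input_len, max_batch_size = 4096, 4096, 16
temp_fwd_bytes = outfeatures * max_input_len * max_batch_size * 4 + 128
print(temp_fwd_bytes)  # 1073741952 bytes, roughly 1 GiB of scratch for this layer alone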