diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py
index f546f3af..25d90e97 100644
--- a/server/text_generation_server/utils/gptq/exllamav2.py
+++ b/server/text_generation_server/utils/gptq/exllamav2.py
@@ -96,35 +96,6 @@ def create_exllama_buffers():
 
         layer.post_init(temp_dq)
 
-    # assert DEVICE is not None, "call set_device first"
-
-    # if ACT_ORDER:
-    #     # TODO: this should be set to rust side `max_total_tokens`, but TGI
-    #     # does not offer an API to expose this variable to python, as this variable
-    #     # is handled by the client but it appears the model is initialized by the server.
-    #     # An alternative could be to initialize the buffers during warmup.
-    #     # Dummy
-    #     max_total_tokens = 2048
-    # else:
-    #     max_total_tokens = 1
-
-    # # This temp_state buffer is required to reorder X in the act-order case.
-    # temp_state = torch.zeros(
-    #     (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
-    # )
-    # temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
-
-    # # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-    # prepare_buffers(DEVICE, temp_state, temp_dq)
-
-    # matmul_recons_thd = 8
-    # matmul_fused_remap = False
-    # matmul_no_half2 = False
-    # set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
-
-    # TEMP_STATE, TEMP_DQ = temp_state, temp_dq
-
 
 class QuantLinear(nn.Module):
     QUANT_TYPE = "exllamav2"
@@ -136,43 +107,12 @@ class QuantLinear(nn.Module):
         if bits != 4:
             raise ValueError(
                 f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
-        # if trainable:
-        #     raise NotImplementedError("Exllamav2 kernel does not support training.")
-
         self.q_handle = None
         self.q_tensors = None
-        #
-        # self.infeatures = infeatures
-        # self.outfeatures = outfeatures + self.padding
         self.bits = bits
-        # self.group_size = group_size if group_size != -1 else infeatures
-        # self.trainable = trainable
         self.maxq = 2 ** self.bits - 1
         self.infeatures = qweight.shape[0] // self.bits * 32
         self.outfeatures = qweight.shape[1]
-
-        # assert infeatures % 32 == 0
-        # assert infeatures % self.group_size == 0
-        # assert outfeatures % 32 == 0
-        # self.padding = - outfeatures % 32
-
-        # # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
-        # self.register_buffer(
-        #     'qweight',
-        #     torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
-        # )
-        # self.register_buffer(
-        #     'qzeros',
-        #     torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
-        # )
-        # self.register_buffer(
-        #     'scales',
-        #     torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
-        # )
-        # self.register_buffer(
-        #     'g_idx',
-        #     torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
-        # )
         self.device = qweight.device
         self.qweight = qweight
         self.qzeros = qzeros
@@ -184,11 +124,6 @@ class QuantLinear(nn.Module):
         FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
         LAYERS.append(self)
 
-        # if bias:
-        #     self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
-        # else:
-        #     self.bias = None
-
     def post_init(self, temp_dq):
         assert self.qweight.device.type == "cuda"
         assert self.qweight.device.index is not None
@@ -216,7 +151,7 @@ class QuantLinear(nn.Module):
     def temp_fwd_size(self, max_input_len, max_batch_size):
         return self.outfeatures * max_input_len * max_batch_size * 4 + 128
 
-    def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
+    def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16):
         return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
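For scale, here is a back-of-the-envelope sketch of what the larger scratch_space_fixed defaults mean per layer. The temp_fwd_size formula below mirrors the method shown in the diff; the outfeatures = 4096 layer width is a hypothetical value chosen only for illustration and is not taken from the change itself.

# Sizing sketch for the scratch_space_fixed() default change.
# temp_fwd_size mirrors the method in the diff: 4 bytes per output element
# per token in the batch, plus a small 128-byte pad.

def temp_fwd_size(outfeatures: int, max_input_len: int, max_batch_size: int) -> int:
    return outfeatures * max_input_len * max_batch_size * 4 + 128

outfeatures = 4096  # hypothetical layer width, for illustration only

old_default = temp_fwd_size(outfeatures, max_input_len=2048, max_batch_size=8)
new_default = temp_fwd_size(outfeatures, max_input_len=4096, max_batch_size=16)

print(f"old defaults: {old_default / 2**20:.0f} MiB per layer")  # ~256 MiB
print(f"new defaults: {new_default / 2**20:.0f} MiB per layer")  # ~1024 MiB, i.e. 4x larger

So, under these assumptions, doubling both default parameters grows the fixed forward scratch reservation roughly fourfold per layer, on top of the temp_dq_size() dequantization buffer.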