diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py
index 1e0be490..dc8353f3 100644
--- a/server/text_generation_server/utils/gptq/exllamav2.py
+++ b/server/text_generation_server/utils/gptq/exllamav2.py
@@ -143,7 +143,7 @@ class QuantLinear(nn.Module):
         # self.bias = None
         #
     def post_init(self, temp_dq):
-        temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index , self.temp_dq_size())
+        temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index , self.temp_dq_size() + self.temp_fwd_size(4096, 8))
         assert self.qweight.device.type == "cuda"
         assert self.qweight.device.index is not None
         self.q_tensors = {
@@ -152,7 +152,7 @@ class QuantLinear(nn.Module):
             "scales":self.scales,
             "g_idx":self.g_idx
         }
-        temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
+        temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size() + self.temp_fwd_size(4096, 8))
         self.q_handle = ext_make_q_matrix(
             self.q_tensors, temp_dq
         )
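
Note: the patch enlarges the device scratch buffer so it covers not only the dequantization workspace (temp_dq_size) but also the forward-pass workspace for up to 4096 input tokens and a batch of 8 (temp_fwd_size(4096, 8)). As a rough sketch of the arithmetic involved, assuming exllamav2-style size helpers (the function names, signatures, and constants below are assumptions for illustration, not the exact code in this file):

# Sketch only: assumed exllamav2-style scratch sizing, for illustration.
def temp_dq_size(infeatures: int, outfeatures: int) -> int:
    # fp16 workspace for dequantizing the full weight matrix
    # (infeatures * outfeatures elements, 2 bytes each) plus a small pad.
    return infeatures * outfeatures * 2 + 128

def temp_fwd_size(outfeatures: int, max_input_len: int, max_batch_size: int) -> int:
    # Workspace sized for the layer output over the largest expected batch:
    # outfeatures * tokens * batch * 4 bytes, plus a small pad.
    return outfeatures * max_input_len * max_batch_size * 4 + 128

if __name__ == "__main__":
    # Example: a hypothetical 4096 x 4096 layer with the (4096, 8) settings from the diff.
    total = temp_dq_size(4096, 4096) + temp_fwd_size(4096, 4096, 8)
    print(f"scratch bytes requested: {total}")  # roughly half a GiB for this one layer

Under these assumptions, requesting only temp_dq_size() (as on the removed lines) would leave no room for the forward-pass scratch, which is what the added temp_fwd_size(4096, 8) term accounts for.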