Adding scratch space.

This commit is contained in:
Nicolas Patry 2023-10-30 16:33:58 +00:00
parent 024bdb0142
commit fb64ce1040

View File

@@ -143,7 +143,7 @@ class QuantLinear(nn.Module):
# self.bias = None
# def post_init(self, temp_dq):
-temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index , self.temp_dq_size())
+temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index , self.temp_dq_size() + self.temp_fwd_size(4096, 8))
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.q_tensors = {
@@ -152,7 +152,7 @@ class QuantLinear(nn.Module):
"scales":self.scales,
"g_idx":self.g_idx
}
-temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
+temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size() + self.temp_fwd_size(4096, 8))
self.q_handle = ext_make_q_matrix(
self.q_tensors, temp_dq
)