Adding scratch space: include temp_fwd_size(4096, 8) in the temp_dq allocation.

This commit is contained in:
Nicolas Patry 2023-10-30 16:33:58 +00:00
parent 024bdb0142
commit fb64ce1040

View File

@@ -143,7 +143,7 @@ class QuantLinear(nn.Module):
# self.bias = None # self.bias = None
# def post_init(self, temp_dq): # def post_init(self, temp_dq):
temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index , self.temp_dq_size()) temp_dq = ExLlamaV2DeviceTensors(self.qweight.device.index , self.temp_dq_size() + self.temp_fwd_size(4096, 8))
assert self.qweight.device.type == "cuda" assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None assert self.qweight.device.index is not None
self.q_tensors = { self.q_tensors = {
@@ -152,7 +152,7 @@ class QuantLinear(nn.Module):
"scales":self.scales, "scales":self.scales,
"g_idx":self.g_idx "g_idx":self.g_idx
} }
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size() + self.temp_fwd_size(4096, 8))
self.q_handle = ext_make_q_matrix( self.q_handle = ext_make_q_matrix(
self.q_tensors, temp_dq self.q_tensors, temp_dq
) )