Fix exl2 scratch buffer sizing: pass max_batch_size=1 explicitly instead of using the default of 16.

This commit is contained in:
Nicolas Patry 2024-05-31 15:18:44 +00:00
parent 659bd67fec
commit 5b58262fea

View File

@@ -147,7 +147,8 @@ def create_exllama_buffers(max_total_tokens: int):
# Find the size of the scratch space.
scratch_bytes = max(
layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS
layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1)
for layer in LAYERS
)
temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
@@ -216,7 +217,7 @@ class QuantLinear(nn.Module):
def temp_fwd_size(self, max_input_len, max_batch_size):
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16):
def scratch_space_fixed(self, max_input_len, max_batch_size):
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)