diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py
index 2ae9628a..16a3eb89 100644
--- a/server/text_generation_server/layers/gptq/exllamav2.py
+++ b/server/text_generation_server/layers/gptq/exllamav2.py
@@ -147,7 +147,8 @@ def create_exllama_buffers(max_total_tokens: int):
 
     # Find the size of the scratch space.
     scratch_bytes = max(
-        layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS
+        layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1)
+        for layer in LAYERS
     )
 
     temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
@@ -216,7 +217,7 @@ class QuantLinear(nn.Module):
 
     def temp_fwd_size(self, max_input_len, max_batch_size):
         return self.outfeatures * max_input_len * max_batch_size * 4 + 128
 
-    def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16):
+    def scratch_space_fixed(self, max_input_len, max_batch_size):
         return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
 