mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fixing exl2 scratch buffer.
This commit is contained in:
parent
659bd67fec
commit
5b58262fea
@ -147,7 +147,8 @@ def create_exllama_buffers(max_total_tokens: int):
|
||||
|
||||
# Find the size of the scratch space.
|
||||
scratch_bytes = max(
|
||||
layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS
|
||||
layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1)
|
||||
for layer in LAYERS
|
||||
)
|
||||
temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
|
||||
|
||||
@ -216,7 +217,7 @@ class QuantLinear(nn.Module):
|
||||
def temp_fwd_size(self, max_input_len, max_batch_size):
|
||||
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
|
||||
|
||||
def scratch_space_fixed(self, max_input_len=4096, max_batch_size=16):
|
||||
def scratch_space_fixed(self, max_input_len, max_batch_size):
|
||||
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user