OlivierDehaene 2024-06-05 11:28:33 +02:00
parent eb6a02a0f1
commit 258ace7cd3
3 changed files with 6 additions and 9 deletions

View File

@@ -111,9 +111,9 @@ async fn block_allocator_task(
 'slots: for block_id in blocks.repeat(repeats).iter() {
     for s in (block_id * block_size)..((block_id + 1) * block_size) {
         slots.push(s);
-    }
-    if slots.len() == tokens {
-        break 'slots;
+        if slots.len() == tokens {
+            break 'slots;
+        }
     }
 }
 Some((blocks, slots))
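Note: the Rust change above moves the length check from after the inner per-block loop to inside it, so slot collection stops as soon as exactly `tokens` slots have been gathered instead of always filling the current block to its boundary. A minimal Python sketch of the fixed control flow (the helper name is hypothetical; the real implementation is the Rust loop in `block_allocator_task`):

def allocate_slots(blocks, block_size, tokens, repeats=1):
    # Mirror of the fixed loop: stop pushing slots the moment we have
    # exactly `tokens`, rather than finishing the whole block first.
    slots = []
    for block_id in blocks * repeats:
        for s in range(block_id * block_size, (block_id + 1) * block_size):
            slots.append(s)
            if len(slots) == tokens:
                return slots
    return slots

# With block_size=16 and tokens=20, the old placement of the check
# returned 32 slots (two full blocks); the fix returns exactly 20.
assert len(allocate_slots([0, 1], block_size=16, tokens=20)) == 20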

View File

@@ -180,7 +180,7 @@ impl State {
     speculate: u32,
     max_batch_total_tokens: u32,
 ) -> Self {
-    let block_allocator = requires_padding
+    let block_allocator = (!requires_padding)
         .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
     Self {
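Note: the one-line `impl State` fix inverts the condition on `requires_padding`: a `BlockAllocator` only makes sense for models that do not pad (paged KV-cache batches), so the boolean must be negated before `.then(...)`. A rough, self-contained Python sketch of the corrected logic (the `BlockAllocator` stand-in and the function name are hypothetical):

from typing import Optional

class BlockAllocator:
    # Stand-in for the Rust type; fields mirror the constructor in the diff.
    def __init__(self, max_batch_total_tokens, block_size, window_size):
        self.max_batch_total_tokens = max_batch_total_tokens
        self.block_size = block_size
        self.window_size = window_size

def make_block_allocator(requires_padding: bool, max_batch_total_tokens: int,
                         block_size: int, window_size: Optional[int]):
    # Fixed condition: padded models get no allocator (None); non-padded
    # models get one, matching `(!requires_padding).then(...)`.
    if requires_padding:
        return None
    return BlockAllocator(max_batch_total_tokens, block_size, window_size)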

View File

@@ -183,7 +183,6 @@ class FlashCausalLMBatch(Batch):
     block_tables = []
     slots = []
-    cumulative_blocks = 0

     # Parse batch
     for i, (r, tokenized_input) in enumerate(
@@ -233,15 +232,13 @@ class FlashCausalLMBatch(Batch):
     if not r.blocks:
         needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
         request_blocks = [
-            b
-            for b in range(cumulative_blocks, cumulative_blocks + needed_blocks)
+            b for b in range(num_blocks, num_blocks + needed_blocks)
         ]
         request_slots = [
             s
             for b in request_blocks
             for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
         ]
-        cumulative_blocks += needed_blocks
     else:
         request_blocks = r.blocks
         request_slots = r.slots
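Note: together with the `cumulative_blocks = 0` removal above, this hunk switches the fallback block numbering to `num_blocks`, which (from context) appears to track how many blocks the batch has already consumed, including router-assigned ones, so self-assigned block ids cannot collide with them. A self-contained Python sketch of the fallback path (the helper name is hypothetical):

import math

BLOCK_SIZE = 16

def default_blocks_and_slots(num_blocks, total_tokens):
    # When the router sent no block table, carve fresh block ids starting
    # at `num_blocks`, the running count of blocks already in use.
    needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
    request_blocks = list(range(num_blocks, num_blocks + needed_blocks))
    request_slots = [
        s
        for b in request_blocks
        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
    ]
    return request_blocks, request_slots

# e.g. 20 tokens arriving after 3 blocks are already in use:
# blocks [3, 4] and 32 slots (two full blocks).
blocks, slots = default_blocks_and_slots(num_blocks=3, total_tokens=20)
assert blocks == [3, 4] and len(slots) == 32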
@@ -358,7 +355,7 @@ class FlashCausalLMBatch(Batch):
     slots = torch.tensor(slots, dtype=torch.int64, device=device)
     block_tables_tensor = torch.zeros(
-        (len(block_tables), max_blocks), dtype=torch.int64, device="cpu"
+        (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
     )
     for i, request_blocks in enumerate(block_tables):
         block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
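Note: the final hunk narrows the padded block-table tensor from int64 to int32. Block ids comfortably fit in 32 bits, and the smaller dtype halves the size of the tensor that is presumably moved to the device later (the `slots` tensor stays int64). A runnable Python sketch of the construction with toy data:

import torch

block_tables = [[0, 1, 2], [3, 4]]  # toy per-request block lists
max_blocks = max(len(b) for b in block_tables)

block_tables_tensor = torch.zeros(
    (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
)
for i, request_blocks in enumerate(block_tables):
    # Slice assignment casts the default int64 row tensor down to int32;
    # rows shorter than max_blocks stay zero-padded.
    block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)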