From 258ace7cd366d27f8750aaf6e283c15434257fcf Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 5 Jun 2024 11:28:33 +0200
Subject: [PATCH] fix

- block_allocator.rs: move the `slots.len() == tokens` early-exit check
  inside the inner loop so it also fires when `tokens` is not a multiple
  of `block_size`, instead of only at block boundaries.
- queue.rs: the condition gating the BlockAllocator was inverted; only
  create it when the model does not require padding.
- flash_causal_lm.py: drop the `cumulative_blocks` counter and index
  fresh blocks from the existing `num_blocks` instead; allocate
  `block_tables_tensor` as int32 rather than int64.
---
 router/src/infer/v3/block_allocator.rs                  | 6 +++---
 router/src/infer/v3/queue.rs                            | 2 +-
 server/text_generation_server/models/flash_causal_lm.py | 7 ++-----
 3 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index 4a12ec06..7467fd85 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -111,9 +111,9 @@ async fn block_allocator_task(
                     'slots: for block_id in blocks.repeat(repeats).iter() {
                         for s in (block_id * block_size)..((block_id + 1) * block_size) {
                             slots.push(s);
-                        }
-                        if slots.len() == tokens {
-                            break 'slots;
+                            if slots.len() == tokens {
+                                break 'slots;
+                            }
                         }
                     }
                     Some((blocks, slots))
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs
index 19e0a23e..0b66142a 100644
--- a/router/src/infer/v3/queue.rs
+++ b/router/src/infer/v3/queue.rs
@@ -180,7 +180,7 @@ impl State {
         speculate: u32,
         max_batch_total_tokens: u32,
     ) -> Self {
-        let block_allocator = requires_padding
+        let block_allocator = (!requires_padding)
             .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
 
         Self {
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 2f3214b4..d8c8838c 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -183,7 +183,6 @@ class FlashCausalLMBatch(Batch):
         block_tables = []
         slots = []
 
-        cumulative_blocks = 0
 
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -233,15 +232,13 @@ class FlashCausalLMBatch(Batch):
             if not r.blocks:
                 needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
                 request_blocks = [
-                    b
-                    for b in range(cumulative_blocks, cumulative_blocks + needed_blocks)
+                    b for b in range(num_blocks, num_blocks + needed_blocks)
                 ]
                 request_slots = [
                     s
                     for b in request_blocks
                     for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                 ]
-                cumulative_blocks += needed_blocks
             else:
                 request_blocks = r.blocks
                 request_slots = r.slots
@@ -358,7 +355,7 @@ class FlashCausalLMBatch(Batch):
         slots = torch.tensor(slots, dtype=torch.int64, device=device)
 
         block_tables_tensor = torch.zeros(
-            (len(block_tables), max_blocks), dtype=torch.int64, device="cpu"
+            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
        )
         for i, request_blocks in enumerate(block_tables):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)