mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix
This commit is contained in:
parent
eb6a02a0f1
commit
258ace7cd3
@ -111,9 +111,9 @@ async fn block_allocator_task(
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * block_size)..((block_id + 1) * block_size) {
|
||||
slots.push(s);
|
||||
}
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some((blocks, slots))
|
||||
|
@ -180,7 +180,7 @@ impl State {
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
) -> Self {
|
||||
let block_allocator = requires_padding
|
||||
let block_allocator = (!requires_padding)
|
||||
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
|
||||
|
||||
Self {
|
||||
|
@ -183,7 +183,6 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
block_tables = []
|
||||
slots = []
|
||||
cumulative_blocks = 0
|
||||
|
||||
# Parse batch
|
||||
for i, (r, tokenized_input) in enumerate(
|
||||
@ -233,15 +232,13 @@ class FlashCausalLMBatch(Batch):
|
||||
if not r.blocks:
|
||||
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
|
||||
request_blocks = [
|
||||
b
|
||||
for b in range(cumulative_blocks, cumulative_blocks + needed_blocks)
|
||||
b for b in range(num_blocks, num_blocks + needed_blocks)
|
||||
]
|
||||
request_slots = [
|
||||
s
|
||||
for b in request_blocks
|
||||
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
|
||||
]
|
||||
cumulative_blocks += needed_blocks
|
||||
else:
|
||||
request_blocks = r.blocks
|
||||
request_slots = r.slots
|
||||
@ -358,7 +355,7 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
slots = torch.tensor(slots, dtype=torch.int64, device=device)
|
||||
block_tables_tensor = torch.zeros(
|
||||
(len(block_tables), max_blocks), dtype=torch.int64, device="cpu"
|
||||
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
|
||||
)
|
||||
for i, request_blocks in enumerate(block_tables):
|
||||
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
|
||||
|
Loading…
Reference in New Issue
Block a user