Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00
commit 258ace7cd3
parent eb6a02a0f1

    fix
@@ -111,11 +111,11 @@ async fn block_allocator_task(
                     'slots: for block_id in blocks.repeat(repeats).iter() {
                         for s in (block_id * block_size)..((block_id + 1) * block_size) {
                             slots.push(s);
-                        }
-                        if slots.len() == tokens {
-                            break 'slots;
+                            if slots.len() == tokens {
+                                break 'slots;
+                            }
                         }
                     }
                     Some((blocks, slots))
                 };
                 response_sender.send(allocation).unwrap();
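The pre-fix loop only tested `slots.len()` after a whole block had been expanded, so an allocation could overshoot `tokens` by up to `block_size - 1` slots; moving the check inside the inner loop stops at exactly `tokens`. A minimal Python sketch of the fixed control flow (the names `blocks`, `block_size`, `repeats`, and `tokens` mirror the Rust locals; the function itself is illustrative, not the router's code):

def expand_blocks_to_slots(blocks, block_size, repeats, tokens):
    # Expand each block id into its contiguous slot range, stopping as soon
    # as exactly `tokens` slots have been collected (the fixed behaviour:
    # the length check now sits inside the inner loop).
    slots = []
    for block_id in blocks * repeats:  # list repetition stands in for Vec::repeat
        for s in range(block_id * block_size, (block_id + 1) * block_size):
            slots.append(s)
            if len(slots) == tokens:
                return slots  # plays the role of `break 'slots`
    return slots


# With the check placed after the inner loop (the pre-fix layout), a request
# for 5 tokens over blocks [7, 9] with block_size=4 would have collected all
# 8 slots of both blocks; the fixed version stops at exactly 5.
print(expand_blocks_to_slots([7, 9], block_size=4, repeats=1, tokens=5))
# [28, 29, 30, 31, 36]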
@@ -180,7 +180,7 @@ impl State {
         speculate: u32,
         max_batch_total_tokens: u32,
     ) -> Self {
-        let block_allocator = requires_padding
+        let block_allocator = (!requires_padding)
             .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));

         Self {
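Rust's `bool::then` yields `Some(value)` only when the bool is true, so the old expression built a `BlockAllocator` precisely for the padded models that do not use paged KV-cache blocks, and skipped it for the flash-attention models that do. Negating the flag restores the intended condition. A sketch of the equivalent conditional in Python, with a stand-in allocator whose constructor shape is assumed from the diff:

from typing import Optional


class BlockAllocator:
    # Stand-in with an assumed shape, mirroring
    # BlockAllocator::new(max_batch_total_tokens, block_size, window_size).
    def __init__(self, max_batch_total_tokens: int, block_size: int,
                 window_size: Optional[int]):
        self.free_blocks = list(range(max_batch_total_tokens // block_size))
        self.window_size = window_size


def make_block_allocator(
    requires_padding: bool,
    max_batch_total_tokens: int,
    block_size: int,
    window_size: Optional[int],
) -> Optional[BlockAllocator]:
    # Equivalent of `(!requires_padding).then(|| BlockAllocator::new(...))`:
    # padded (non-flash) models manage their KV cache themselves, so only
    # non-padded batches get a paged-block allocator.
    if requires_padding:
        return None
    return BlockAllocator(max_batch_total_tokens, block_size, window_size)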
@@ -183,7 +183,6 @@ class FlashCausalLMBatch(Batch):

         block_tables = []
         slots = []
-        cumulative_blocks = 0

         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -233,15 +232,13 @@ class FlashCausalLMBatch(Batch):
             if not r.blocks:
                 needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
                 request_blocks = [
-                    b
-                    for b in range(cumulative_blocks, cumulative_blocks + needed_blocks)
+                    b for b in range(num_blocks, num_blocks + needed_blocks)
                 ]
                 request_slots = [
                     s
                     for b in request_blocks
                     for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                 ]
-                cumulative_blocks += needed_blocks
             else:
                 request_blocks = r.blocks
                 request_slots = r.slots
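These two hunks are one change: the batch-local `cumulative_blocks` counter is dropped, and the fallback path (taken only when a request carries no router-assigned `r.blocks`) numbers fresh blocks from `num_blocks`, which presumably tracks blocks already in use and is maintained outside this hunk. A runnable sketch of the fallback arithmetic, with an assumed BLOCK_SIZE value:

import math

BLOCK_SIZE = 16  # assumed here; the real value comes from the serving config


def fallback_blocks_and_slots(total_tokens: int, num_blocks: int):
    # Mirrors the fixed fallback path: allocate ceil(total_tokens / BLOCK_SIZE)
    # fresh blocks, numbered right after the `num_blocks` already in use.
    needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
    request_blocks = [b for b in range(num_blocks, num_blocks + needed_blocks)]
    request_slots = [
        s
        for b in request_blocks
        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
    ]
    return request_blocks, request_slots


blocks, slots = fallback_blocks_and_slots(total_tokens=20, num_blocks=3)
print(blocks)      # [3, 4]  (two blocks cover 20 tokens at BLOCK_SIZE=16)
print(len(slots))  # 32     (slots always span whole blocks)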
@@ -358,7 +355,7 @@ class FlashCausalLMBatch(Batch):

         slots = torch.tensor(slots, dtype=torch.int64, device=device)
         block_tables_tensor = torch.zeros(
-            (len(block_tables), max_blocks), dtype=torch.int64, device="cpu"
+            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
         )
         for i, request_blocks in enumerate(block_tables):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
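Paged-attention kernels typically expect the block table as an int32 tensor, so building it as int64 meant a wider table and a cast further down; the fix narrows the dtype at construction. A small sketch of the padded block-table layout with example data:

import torch

# Example per-request block ids; in the batch code these come either from
# the router (r.blocks) or from the fallback path above.
block_tables = [[3, 4], [7], [1, 2, 5]]
max_blocks = max(len(b) for b in block_tables)

# Zero-padded (num_requests, max_blocks) table, built as int32 up front.
block_tables_tensor = torch.zeros(
    (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
)
for i, request_blocks in enumerate(block_tables):
    block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)

print(block_tables_tensor)
# tensor([[3, 4, 0],
#         [7, 0, 0],
#         [1, 2, 5]], dtype=torch.int32)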