Fix ping-pong optimization issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2025-04-15 21:56:51 -07:00
parent 5ec7f15d0c
commit bf3987e25e
2 changed files with 17 additions and 9 deletions

@@ -615,6 +615,12 @@ class FlashCausalLMBatch(Batch):
             max_slots = max(max_slots, slot_length)
 
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
+        next_token_logits = self.next_token_logits[indices]
+        speculative_logits = (
+            self.speculative_logits[indices]
+            if self.speculative_logits is not None
+            else None
+        )
         block_tables_tensor = self.block_tables_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
@@ -696,8 +702,8 @@ class FlashCausalLMBatch(Batch):
             speculative_ids=speculative_ids,
             adapter_meta=adapter_meta,
             hpu_attn_meta=None,
-            next_token_logits=None,
-            speculative_logits=None,
+            next_token_logits=next_token_logits,
+            speculative_logits=speculative_logits,
         )
 
     @classmethod
@@ -825,8 +831,11 @@ class FlashCausalLMBatch(Batch):
             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + len(batch)
 
-            # Copy tensors (GPU)
-            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+            # Copy tensors (HPU)
+            index = torch.tensor(
+                list(range(start_index, end_index)), device=batch.input_ids.device
+            )
+            top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor[:, :max_length]
@@ -834,7 +843,7 @@ class FlashCausalLMBatch(Batch):
             block_tables_tensor[
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]
-            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
+            prompt_lengths_tensor.index_copy_(0, index, batch.prompt_lengths_tensor)
 
             slots_start_index = cumulative_slots
             slots_end_index = cumulative_slots + len(batch.slots)
@@ -844,9 +853,6 @@ class FlashCausalLMBatch(Batch):
             )
 
             if not prefilling:
-                index = torch.tensor(
-                    list(range(start_index, end_index)), device=batch.input_ids.device
-                )
                 input_ids.index_copy_(0, index, batch.input_ids)
                 position_ids.index_copy_(0, index, batch.position_ids)
                 slot_indices.index_copy_(
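
For context on the concatenate-side hunks above: instead of writing each per-request tensor into a slice, the HPU path now builds the index tensor once per batch and reuses it with index_copy_ for top_n_tokens_tensor, prompt_lengths_tensor, and the decode-time copies. A minimal sketch of the two equivalent patterns, using plain PyTorch with made-up tensor names (not code from this repository):

import torch

# Illustrative sketch: both patterns write src into positions
# [start_index, end_index) of a destination tensor.
start_index, end_index = 2, 5
src = torch.tensor([7, 8, 9])

# Previous pattern: slice assignment
dst_slice = torch.zeros(8, dtype=torch.int64)
dst_slice[start_index:end_index] = src

# New pattern: explicit index tensor + index_copy_
index = torch.tensor(list(range(start_index, end_index)), device=src.device)
dst_index = torch.zeros(8, dtype=torch.int64)
dst_index.index_copy_(0, index, src)

assert torch.equal(dst_slice, dst_index)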

@@ -177,7 +177,7 @@ impl Allocator for SimpleAllocator {
            (required_blocks, repeats)
        };
 
-        let tokens = tokens as usize;
+        let mut tokens = tokens as usize;
        if required_blocks > self.free_blocks.len() as u32 {
            None
        } else {
@@ -189,6 +189,8 @@ impl Allocator for SimpleAllocator {
                .split_off(self.free_blocks.len() - required_blocks as usize);
            if self.is_hpu_device {
                blocks.sort();
+                // need 1 slot for ping-pong optimization
+                tokens += 1;
            }
            let mut slots =
                Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
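
The allocator hunk reads oddly on its own: tokens is what caps how many slots the allocation keeps, so bumping it by one on HPU leaves a single spare slot for the ping-pong optimization. A hedged sketch of that accounting in Python (illustrative only; the cap-at-tokens behaviour is implied by the hunk rather than shown in it, and the names are made up):

def emitted_slots(blocks, block_size, tokens, is_hpu_device):
    # Illustrative sketch, not the allocator's actual code: each block
    # contributes block_size consecutive slot ids, and `tokens` caps how
    # many slots are emitted. On HPU, one extra slot is reserved for the
    # ping-pong optimization, which is what `tokens += 1` buys.
    if is_hpu_device:
        tokens += 1  # need 1 spare slot for ping-pong optimization
    slots = []
    for block_id in blocks:
        for s in range(block_id * block_size, (block_id + 1) * block_size):
            slots.append(s)
            if len(slots) == tokens:
                return slots
    return slots

# Example: two blocks of 16 slots each, 20 token slots requested
print(len(emitted_slots([3, 4], 16, 20, is_hpu_device=False)))  # 20
print(len(emitted_slots([3, 4], 16, 20, is_hpu_device=True)))   # 21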