mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 13:52:07 +00:00
pingpong optimization issue fix
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent 5ec7f15d0c
commit bf3987e25e
@@ -615,6 +615,12 @@ class FlashCausalLMBatch(Batch):
             max_slots = max(max_slots, slot_length)
 
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
+        next_token_logits = self.next_token_logits[indices]
+        speculative_logits = (
+            self.speculative_logits[indices]
+            if self.speculative_logits is not None
+            else None
+        )
         block_tables_tensor = self.block_tables_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
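Note: `filter` now slices the pending logits with the same `indices` used for every other per-request tensor, so a filtered batch keeps the logits of its surviving requests. A minimal standalone sketch of that indexing pattern (shapes and values are illustrative, not the repository's code):

    import torch

    # Pending logits for a batch of 4 requests (vocab size 8 assumed).
    next_token_logits = torch.randn(4, 8)
    speculative_logits = None  # absent when speculation is disabled

    # Keep requests 0 and 2, mirroring FlashCausalLMBatch.filter(indices).
    indices = [0, 2]
    filtered_logits = next_token_logits[indices]  # shape: (2, 8)
    filtered_speculative = (
        speculative_logits[indices] if speculative_logits is not None else None
    )
    assert filtered_logits.shape == (2, 8) and filtered_speculative is None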
@@ -696,8 +702,8 @@ class FlashCausalLMBatch(Batch):
             speculative_ids=speculative_ids,
             adapter_meta=adapter_meta,
             hpu_attn_meta=None,
-            next_token_logits=None,
-            speculative_logits=None,
+            next_token_logits=next_token_logits,
+            speculative_logits=speculative_logits,
         )
 
     @classmethod
@@ -825,8 +831,11 @@ class FlashCausalLMBatch(Batch):
             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + len(batch)
 
-            # Copy tensors (GPU)
-            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+            # Copy tensors (HPU)
+            index = torch.tensor(
+                list(range(start_index, end_index)), device=batch.input_ids.device
+            )
+            top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor[:, :max_length]
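On HPU the contiguous slice assignment is replaced by an explicit `index_copy_`, which takes an index tensor on the same device. For a contiguous range the two forms write the same values; a small sketch checking that equivalence (sizes are illustrative):

    import torch

    dst_slice = torch.zeros(6)
    dst_index = torch.zeros(6)
    src = torch.tensor([1.0, 2.0, 3.0])
    start_index, end_index = 2, 5

    # Slice assignment, as on the removed (GPU) side of the hunk.
    dst_slice[start_index:end_index] = src

    # Explicit index copy, as on the added (HPU) side.
    index = torch.tensor(list(range(start_index, end_index)), device=src.device)
    dst_index.index_copy_(0, index, src)

    assert torch.equal(dst_slice, dst_index)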
@@ -834,7 +843,7 @@ class FlashCausalLMBatch(Batch):
             block_tables_tensor[
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]
-            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
+            prompt_lengths_tensor.index_copy_(0, index, batch.prompt_lengths_tensor)
 
             slots_start_index = cumulative_slots
             slots_end_index = cumulative_slots + len(batch.slots)
@@ -844,9 +853,6 @@ class FlashCausalLMBatch(Batch):
             )
 
             if not prefilling:
-                index = torch.tensor(
-                    list(range(start_index, end_index)), device=batch.input_ids.device
-                )
                 input_ids.index_copy_(0, index, batch.input_ids)
                 position_ids.index_copy_(0, index, batch.position_ids)
                 slot_indices.index_copy_(
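With `index` now built once before the copies, the duplicate construction inside `if not prefilling:` goes away and every per-request copy reuses the same tensor. A short sketch of that reuse (hypothetical sizes; `input_ids` and `position_ids` stand in for the batch tensors):

    import torch

    start_index, end_index = 2, 5
    index = torch.tensor(list(range(start_index, end_index)))

    input_ids = torch.zeros(8, dtype=torch.long)
    position_ids = torch.zeros(8, dtype=torch.long)

    # One index tensor drives every copy for this sub-batch.
    input_ids.index_copy_(0, index, torch.tensor([11, 12, 13]))
    position_ids.index_copy_(0, index, torch.tensor([0, 1, 2]))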
@@ -177,7 +177,7 @@ impl Allocator for SimpleAllocator {
             (required_blocks, repeats)
         };
 
-        let tokens = tokens as usize;
+        let mut tokens = tokens as usize;
         if required_blocks > self.free_blocks.len() as u32 {
             None
         } else {
@@ -189,6 +189,8 @@ impl Allocator for SimpleAllocator {
                 .split_off(self.free_blocks.len() - required_blocks as usize);
             if self.is_hpu_device {
                 blocks.sort();
+                // need 1 slot for ping-pong optimization
+                tokens += 1;
             }
             let mut slots =
                 Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
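This is the allocator half of the fix: on HPU, `tokens` is bumped by one so a spare slot is kept for the ping-pong optimization (`tokens` is made `mut` in the earlier hunk for exactly this). A minimal Python sketch of that accounting, assuming the surrounding slot-fill loop stops once `tokens` slots have been collected, which is how the code around this hunk appears to use `tokens`:

    def allocate_slots(blocks, block_size, tokens, is_hpu_device):
        """Sketch of SimpleAllocator's slot fill (simplified, not the Rust code)."""
        if is_hpu_device:
            tokens += 1  # need 1 spare slot for ping-pong optimization
        slots = []
        for block_id in blocks:
            for s in range(block_id * block_size, (block_id + 1) * block_size):
                slots.append(s)
                if len(slots) == tokens:
                    return slots
        return slots

    # With two 4-slot blocks and 5 tokens, HPU keeps one extra slot:
    assert allocate_slots([0, 1], 4, 5, is_hpu_device=False) == [0, 1, 2, 3, 4]
    assert allocate_slots([0, 1], 4, 5, is_hpu_device=True) == [0, 1, 2, 3, 4, 5]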