diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 8a5668a5d..ecedd4aad 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -615,6 +615,12 @@ class FlashCausalLMBatch(Batch):
             max_slots = max(max_slots, slot_length)
 
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
+        next_token_logits = self.next_token_logits[indices]
+        speculative_logits = (
+            self.speculative_logits[indices]
+            if self.speculative_logits is not None
+            else None
+        )
         block_tables_tensor = self.block_tables_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
@@ -696,8 +702,8 @@ class FlashCausalLMBatch(Batch):
             speculative_ids=speculative_ids,
             adapter_meta=adapter_meta,
             hpu_attn_meta=None,
-            next_token_logits=None,
-            speculative_logits=None,
+            next_token_logits=next_token_logits,
+            speculative_logits=speculative_logits,
         )
 
     @classmethod
@@ -825,8 +831,11 @@ class FlashCausalLMBatch(Batch):
             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + len(batch)
 
-            # Copy tensors (GPU)
-            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+            # Copy tensors (HPU)
+            index = torch.tensor(
+                list(range(start_index, end_index)), device=batch.input_ids.device
+            )
+            top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor[:, :max_length]
@@ -834,7 +843,7 @@ class FlashCausalLMBatch(Batch):
             block_tables_tensor[
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]
-            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
+            prompt_lengths_tensor.index_copy_(0, index, batch.prompt_lengths_tensor)
 
             slots_start_index = cumulative_slots
             slots_end_index = cumulative_slots + len(batch.slots)
@@ -844,9 +853,6 @@ class FlashCausalLMBatch(Batch):
             )
 
             if not prefilling:
-                index = torch.tensor(
-                    list(range(start_index, end_index)), device=batch.input_ids.device
-                )
                 input_ids.index_copy_(0, index, batch.input_ids)
                 position_ids.index_copy_(0, index, batch.position_ids)
                 slot_indices.index_copy_(
diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs
index 6da2b51da..1628a00b7 100644
--- a/backends/v3/src/block_allocator.rs
+++ b/backends/v3/src/block_allocator.rs
@@ -177,7 +177,7 @@ impl Allocator for SimpleAllocator {
             (required_blocks, repeats)
         };
 
-        let tokens = tokens as usize;
+        let mut tokens = tokens as usize;
         if required_blocks > self.free_blocks.len() as u32 {
             None
         } else {
@@ -189,6 +189,8 @@ impl Allocator for SimpleAllocator {
                 .split_off(self.free_blocks.len() - required_blocks as usize);
             if self.is_hpu_device {
                 blocks.sort();
+                // need 1 slot for ping-pong optimization
+                tokens += 1;
             }
             let mut slots =
                 Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
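
A note on the rewrite repeated across the Gaudi hunks: the patch replaces in-place slice assignment with an explicit index tensor plus `index_copy_`, and hoists that index tensor so every per-batch copy (`top_n_tokens_tensor`, `prompt_lengths_tensor`, `input_ids`, ...) reuses it. Below is a minimal standalone sketch of that equivalence; the tensor names and sizes are invented for illustration, only the `index_copy_` pattern itself comes from the diff.

```python
import torch

# Illustrative sizes only; the real tensors live in FlashCausalLMBatch.concatenate.
dst = torch.zeros(8, dtype=torch.int64)
src = torch.arange(3)
start_index, end_index = 2, 5

# Old GPU-style path: in-place slice assignment.
# dst[start_index:end_index] = src

# HPU-friendly path from the diff: build one index tensor and reuse it
# for every copy into the concatenated batch.
index = torch.tensor(list(range(start_index, end_index)), device=src.device)
dst.index_copy_(0, index, src)

assert torch.equal(dst[start_index:end_index], src)
```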
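And a sketch of what the `block_allocator.rs` change does: `tokens` caps how many slots the allocator hands back, so bumping it by one on HPU reserves one extra slot for the ping-pong optimization. The Python below is a hypothetical, simplified mirror of `SimpleAllocator`'s slot generation (no `repeats` handling, and `allocate_slots` is an invented name; the loop structure is assumed, not shown in the diff) just to make the arithmetic concrete.

```python
# Hypothetical, simplified mirror of the Rust allocator's slot generation.
def allocate_slots(blocks, block_size, tokens, is_hpu_device):
    if is_hpu_device:
        blocks = sorted(blocks)
        tokens += 1  # need 1 slot for ping-pong optimization
    slots = []
    for block_id in blocks:
        for s in range(block_id * block_size, (block_id + 1) * block_size):
            slots.append(s)
            if len(slots) == tokens:
                return slots
    return slots

print(allocate_slots([3, 1], block_size=4, tokens=5, is_hpu_device=True))
# -> [4, 5, 6, 7, 12, 13]: the 5 requested slots plus 1 reserved
```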