diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index e9a320d9..ccec1ba6 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -699,7 +699,9 @@ class FlashCausalLMBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+    def concatenate(
+        cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0
+    ) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -748,7 +750,10 @@ class FlashCausalLMBatch(Batch):
             adapter_meta = None
             adapter_segment_builder = None
         else:
-            input_ids = batches[0].input_ids.new_empty(total_batch_size)
+            if padded_total_bs == batches[0].input_ids.shape[0]:
+                input_ids = batches[0].input_ids
+            else:
+                input_ids = batches[0].input_ids.new_empty(total_batch_size)
             if (
                 batches[0].position_ids is not None
                 and batches[0].position_ids.dim() == 2
@@ -827,7 +832,9 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
             if i > 0:
                 all_input_ids_tensor.index_copy_(
-                    0, index.to("hpu"), batch.all_input_ids_tensor[:valid_bsize, :]
+                    0,
+                    index.to(batch.all_input_ids_tensor.device),
+                    batch.all_input_ids_tensor[:valid_bsize, :],
                 )
 
             block_tables_tensor[
@@ -848,9 +855,10 @@ class FlashCausalLMBatch(Batch):
             )
 
             if not prefilling:
-                input_ids.index_copy_(
-                    0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
-                )
+                if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:
+                    input_ids.index_copy_(
+                        0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
+                    )
                 position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
                 slot_indices.index_copy_(
                     0, index, batch.slot_indices + cumulative_slots
@@ -2042,6 +2050,7 @@ class FlashCausalLM(Model):
                 accepted_ids,
             )
             if batch.valid_indices is not None:
+                # TODO speculative decoding handling missing
                 next_token_logprobs = next_token_logprobs.cpu()
                 accepted_ids = accepted_ids.cpu()
                 index = torch.arange(
@@ -2052,7 +2061,13 @@ class FlashCausalLM(Model):
                 batch.all_input_ids_tensor.index_copy_(
                     0, index, batch.all_input_ids_tensor[batch.valid_indices]
                 )
-                next_input_ids = next_input_ids[batch.valid_indices]
+                padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    len(batch.valid_indices)
+                )
+                next_input_ids.index_copy_(
+                    0, index, next_input_ids[batch.valid_indices]
+                )
+                next_input_ids = next_input_ids[:padded_total_bs]
                 next_token_logprobs = next_token_logprobs[batch.valid_indices]
                 accepted_ids = accepted_ids[batch.valid_indices]
                 if speculative_ids is not None:
@@ -2122,10 +2137,13 @@ class FlashCausalLM(Model):
                 batch.slot_indices += accepted_ids[: len(batch)]
             else:
                 index = batch.cache_lengths_tensor + batch.input_lengths_tensor
+                index = F.pad(
+                    index, (0, next_input_ids.shape[0] - index.shape[0]), value=0
+                )
                 index = index.to(batch.all_input_ids_tensor.device)
                 batch_idx = torch.arange(
                     0,
-                    batch.all_input_ids_tensor.shape[0],
+                    index.shape[0],
                     dtype=torch.long,
                     device=batch.all_input_ids_tensor.device,
                 )
@@ -2213,7 +2231,18 @@ class FlashCausalLM(Model):
             htorch.core.mark_step()
         # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
-            batch = self.batch_type.concatenate(batches)
+            if self.bucketing_ctx is not None:
+                total_batch_size = 0
+                for b in batches:
+                    total_batch_size += len(b)
+                padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
+                    total_batch_size
+                )
+                batch = self.batch_type.concatenate(
+                    batches, padded_total_bs=padded_total_bs
+                )
+            else:
+                batch = self.batch_type.concatenate(batches)
         else:
             batch = batches[0]
         prefill = batch.prefilling
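
A note on the fast path in `concatenate`: the patch relies on `bucketing_ctx.get_padded_decode_batch_size(n)` rounding the live batch size up to a pre-warmed decode bucket, so that when `batches[0].input_ids` is already padded to that bucket it can be reused as the destination buffer instead of allocating a new one, and the `padded_total_bs != batches[0].input_ids.shape[0] or i > 0` guard then skips copying batch 0 into itself. The sketch below only illustrates that reuse condition; `_round_up_to_bucket` and `pick_input_ids` are hypothetical stand-ins, not the actual `HPUBucketingContext` API or repo code.

```python
import torch


def _round_up_to_bucket(batch_size: int, min_bucket: int = 1) -> int:
    # Hypothetical stand-in for bucketing_ctx.get_padded_decode_batch_size:
    # round the live batch size up to the nearest power-of-two bucket.
    # The real HPU bucketing rule may differ.
    bucket = min_bucket
    while bucket < batch_size:
        bucket *= 2
    return bucket


def pick_input_ids(first_input_ids: torch.Tensor, total_batch_size: int) -> torch.Tensor:
    # Mirror of the concatenate() fast path: if the first batch's input_ids
    # is already padded to the target bucket, reuse it in place; otherwise
    # allocate a fresh buffer that the per-batch index_copy_ calls fill.
    padded_total_bs = _round_up_to_bucket(total_batch_size)
    if padded_total_bs == first_input_ids.shape[0]:
        return first_input_ids  # no allocation, no extra copy for batch 0
    return first_input_ids.new_empty(total_batch_size)


# Two decode batches of size 3 and 4 concatenate into 7 requests, which pads
# to an 8-slot bucket; a first batch already padded to 8 slots is reused as-is.
first = torch.zeros(8, dtype=torch.long)
assert pick_input_ids(first, total_batch_size=7) is first
# A first batch padded to a smaller bucket forces a fresh allocation.
small = torch.zeros(4, dtype=torch.long)
assert pick_input_ids(small, total_batch_size=7) is not small
```

Keeping the decode tensors at bucketed shapes is what lets the concatenated batch hit an already-compiled HPU graph instead of triggering a recompile for an arbitrary batch size.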