diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index bb35886c..4cc285bf 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -573,11 +573,13 @@ class FlashCausalLMBatch(Batch):
         all_input_ids = []
         prefix_ids = []
 
+        prompt_lengths = []
         postfix_lengths = []
         prefix_lengths = []
         prefix_offsets = []
         read_offsets = []
 
+        stopping_criterias = []
         top_n_tokens = []
         adapter_set = set()
 
@@ -595,14 +597,15 @@ class FlashCausalLMBatch(Batch):
             requests.append(self.requests[idx])
 
             # Get length
-            request_input_length = self.postfix_lengths[idx]
+            request_postfix_length = self.postfix_lengths[idx]
             prefix_length = self.prefix_lengths[idx]
-            max_seqlen = max(max_seqlen, request_input_length)
+            max_seqlen = max(max_seqlen, request_postfix_length)
 
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
 
-            postfix_lengths.append(request_input_length)
+            prompt_lengths.append(self.prompt_lengths[idx])
+            postfix_lengths.append(request_postfix_length)
             prefix_lengths.append(prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
@@ -626,17 +629,17 @@ class FlashCausalLMBatch(Batch):
             start_slots.append(cumulative_max_length)
 
             # Copy to tensor (CPU)
-            slot_indices[i] = cumulative_max_length + request_input_length - 1
+            slot_indices[i] = cumulative_max_length + request_postfix_length - 1
 
             # Set slice
             slot_filtering_indices[
                 self.start_slots[idx] : self.start_slots[idx]
-                + request_input_length
+                + request_postfix_length
                 + remaining_tokens
                 - 1
             ] = True
 
-            cumulative_max_length += request_input_length + remaining_tokens - 1
+            cumulative_max_length += request_postfix_length + remaining_tokens - 1
 
             max_blocks = max(max_blocks, len(request_block_table))
 
@@ -647,6 +650,7 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
         postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
+        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
         prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
@@ -683,6 +687,8 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
+            prompt_lengths=prompt_lengths,
+            prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,
             postfix_lengths_tensor=postfix_lengths_tensor,
             prefix_lengths=prefix_lengths,
@@ -732,12 +738,13 @@ class FlashCausalLMBatch(Batch):
             max_length = max(
                 max_length,
                 max(
-                    input_length
+                    prefix_length
+                    + postfix_length
                     + stopping_criteria.max_new_tokens
                     + speculative_length
                     - stopping_criteria.current_tokens
-                    for input_length, stopping_criteria in zip(
-                        b.postfix_lengths, b.stopping_criterias
+                    for prefix_length, postfix_length, stopping_criteria in zip(
+                        b.prefix_lengths, b.postfix_lengths, b.stopping_criterias
                     )
                 ),
             )
@@ -746,6 +753,9 @@ class FlashCausalLMBatch(Batch):
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
         slots = batches[0].slots.new_empty(total_slots)
         slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+        prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
+            total_batch_size
+        )
         postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
             total_batch_size
         )
@@ -776,6 +786,7 @@ class FlashCausalLMBatch(Batch):
         all_input_ids = []
         prefix_ids = []
 
+        prompt_lengths = []
         postfix_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -809,6 +820,7 @@ class FlashCausalLMBatch(Batch):
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
+            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
             postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots
@@ -845,6 +857,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.extend(batch.all_input_ids)
             prefix_ids.extend(batch.prefix_ids)
 
+            prompt_lengths.extend(batch.prompt_lengths)
             postfix_lengths.extend(batch.postfix_lengths)
             prefix_offsets.extend(batch.prefix_offsets)
             read_offsets.extend(batch.read_offsets)
@@ -898,6 +911,8 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
+            prompt_lengths=prompt_lengths,
+            prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,
             postfix_lengths_tensor=postfix_lengths_tensor,
             prefix_offsets=prefix_offsets,
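
For readers following the rename: the diff splits what used to be a single per-request `input_length` into three quantities — `prompt_lengths` (the full original prompt), `prefix_lengths` (tokens already materialized in the KV cache), and `postfix_lengths` (tokens still to be fed through the model) — and threads `prompt_lengths`/`prompt_lengths_tensor` through both `filter` and `concatenate`. The sketch below is a minimal illustration of that bookkeeping, not TGI code: `LengthBookkeeping` and `max_total_length` are hypothetical names, and the length bound omits the `speculative_length` and `current_tokens` terms of the real formula for brevity.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class LengthBookkeeping:
    # Hypothetical stand-in for the per-request length lists carried by
    # FlashCausalLMBatch; field names mirror the diff.
    prompt_lengths: List[int]   # total tokens in the original prompt
    prefix_lengths: List[int]   # tokens already present in the KV cache
    postfix_lengths: List[int]  # tokens still to be run through the model

    def filter(self, indices: List[int]) -> "LengthBookkeeping":
        # As in FlashCausalLMBatch.filter: every per-request list is sliced
        # with the same kept indices, so the three lengths stay aligned.
        return LengthBookkeeping(
            [self.prompt_lengths[i] for i in indices],
            [self.prefix_lengths[i] for i in indices],
            [self.postfix_lengths[i] for i in indices],
        )

    @staticmethod
    def concatenate(batches: List["LengthBookkeeping"]) -> "LengthBookkeeping":
        # As in FlashCausalLMBatch.concatenate: lists are extended batch by
        # batch, in the same order as the tensors they describe.
        merged = LengthBookkeeping([], [], [])
        for b in batches:
            merged.prompt_lengths.extend(b.prompt_lengths)
            merged.prefix_lengths.extend(b.prefix_lengths)
            merged.postfix_lengths.extend(b.postfix_lengths)
        return merged


def max_total_length(batch: LengthBookkeeping, max_new_tokens: int) -> int:
    # The corrected bound from the concatenate hunk: count the cached prefix
    # *and* the postfix, not just the old single input_length.
    return max(
        prefix + postfix + max_new_tokens
        for prefix, postfix in zip(batch.prefix_lengths, batch.postfix_lengths)
    )


if __name__ == "__main__":
    b = LengthBookkeeping(
        prompt_lengths=[10, 7],
        prefix_lengths=[4, 0],   # request 0 resumes from a cached prefix
        postfix_lengths=[6, 7],
    )
    # During prefill: prompt == prefix + postfix for every request.
    assert all(
        p == pre + post
        for p, pre, post in zip(
            b.prompt_lengths, b.prefix_lengths, b.postfix_lengths
        )
    )
    print(max_total_length(b, max_new_tokens=16))  # 26; the old
    # postfix-only formula would have given 23 and under-reserved
```

Note the one behavioral change hidden among the renames: the `max_length` bound in `concatenate` previously counted only the single `input_length`, whereas the new generator sums `prefix_length + postfix_length`, so a request resumed from a cached prefix reserves room for its full sequence rather than just the uncached tail.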