diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 8a9512c9..1e81e673 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -226,6 +226,7 @@ class FlashCausalLMBatch(Batch):
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         sliding_window = get_sliding_windows()
+        speculate = get_speculate()
         position_ids = []
         cu_seqlen_prefill = [0]
         start_slots = []
@@ -280,17 +281,21 @@ class FlashCausalLMBatch(Batch):
             prompt_lengths.append(prompt_length)

             prefix_length = r.prefix_len
+            postfix_length = prefix_length + 10
             assert (
                 prefix_length <= prompt_length
             ), f"Prefix {prefix_length} vs input {prompt_length}"
             if prefix_length == prompt_length:
                 assert prefix_length > 0
                 prefix_length -= 1
+            if prefix_length + postfix_length < prompt_length:
+                # FIXME: speculate is not supported for context chunking at the moment
+                assert speculate == 0

             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
             prefix_ids.append(tokenized_input[:prefix_length])
-            postfix_ids = tokenized_input[prefix_length : prefix_length + 10]
+            postfix_ids = tokenized_input[prefix_length : postfix_length]
             # postfix_ids = tokenized_input[prefix_length:]
             postfix_length = len(postfix_ids)

@@ -1864,8 +1869,8 @@ class FlashCausalLM(Model):
            )
        ):
            continue_prefilling = prefix_length + postfix_length < prompt_length
-            skip_tokens[r.id] = True
            if continue_prefilling:
+                skip_tokens[r.id] = True
                # Update prefix length
                prefix_length = prefix_length + postfix_length
                batch.prefix_lengths[i] = prefix_length
@@ -1980,11 +1985,11 @@
            ADAPTER_TO_INDEX = get_adapter_to_index()
            adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
            adapter_indices_list.append(
-                torch.full((postfix_length,), adapter_index)
+                torch.full((next_chunk_length,), adapter_index)
            )

            # Update
-            cumulative_length += postfix_length
+            cumulative_length += next_chunk_length
            cumulative_slot_tokens += len(request_slots)

        device = batch.input_ids.device
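
For context, a minimal sketch (not part of the patch) of the chunked-prefill loop this diff wires up: each step takes the next slice of the prompt after the current prefix, and prefilling continues until prefix plus chunk covers the whole prompt. The hardcoded chunk size of 10 mirrors the `prefix_length + 10` slice in the diff; `chunk_prompt` and `chunk_size` are illustrative names, not identifiers from the repository.

def chunk_prompt(tokenized_input, prefix_length=0, chunk_size=10):
    # Illustrative only: yields (chunk, continue_prefilling) pairs until the
    # whole prompt has been consumed.
    prompt_length = len(tokenized_input)
    while True:
        postfix_ids = tokenized_input[prefix_length : prefix_length + chunk_size]
        postfix_length = len(postfix_ids)
        # Same condition as `continue_prefilling` in the diff: keep prefilling
        # while prefix + current chunk does not yet cover the whole prompt.
        continue_prefilling = prefix_length + postfix_length < prompt_length
        yield postfix_ids, continue_prefilling
        if not continue_prefilling:
            break
        prefix_length = prefix_length + postfix_length

# Example: a 25-token prompt is prefilled in three chunks of 10, 10 and 5 tokens.
for chunk, more in chunk_prompt(list(range(25))):
    print(len(chunk), more)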