OlivierDehaene 2024-09-30 11:03:13 +02:00
parent 962ccfd5b7
commit 0e31619893

@@ -226,6 +226,7 @@ class FlashCausalLMBatch(Batch):
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         sliding_window = get_sliding_windows()
+        speculate = get_speculate()
         position_ids = []
         cu_seqlen_prefill = [0]
         start_slots = []
@@ -280,17 +281,21 @@ class FlashCausalLMBatch(Batch):
             prompt_lengths.append(prompt_length)
 
             prefix_length = r.prefix_len
+            postfix_length = prefix_length + 10
             assert (
                 prefix_length <= prompt_length
             ), f"Prefix {prefix_length} vs input {prompt_length}"
             if prefix_length == prompt_length:
                 assert prefix_length > 0
                 prefix_length -= 1
+            if prefix_length + postfix_length < prompt_length:
+                # FIXME: speculate is not supported for context chunking at the moment
+                assert speculate == 0
 
             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
             prefix_ids.append(tokenized_input[:prefix_length])
-            postfix_ids = tokenized_input[prefix_length : prefix_length + 10]
+            postfix_ids = tokenized_input[prefix_length : postfix_length]
             # postfix_ids = tokenized_input[prefix_length:]
             postfix_length = len(postfix_ids)
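
For reference, the slicing introduced in this hunk takes a fixed-size chunk of the prompt that has not been prefilled yet. Below is a minimal standalone sketch of that idea; the helper name next_prefill_chunk is hypothetical and the chunk size of 10 is simply the constant used in the hunk, not a configurable value confirmed by the diff.

# Minimal sketch, assuming a fixed prefill chunk size of 10 as in the hunk above.
# `next_prefill_chunk` is a hypothetical helper name, not a function from the repo.
def next_prefill_chunk(tokenized_input, prefix_length, chunk_size=10):
    # Tokens [0, prefix_length) are already cached; take at most `chunk_size` more.
    postfix_ids = tokenized_input[prefix_length : prefix_length + chunk_size]
    return postfix_ids, len(postfix_ids)

# Example: a 25-token prompt with 20 tokens already prefilled yields a 5-token chunk,
# so prefix_length + postfix_length reaches prompt_length and prefill ends this pass.
ids, length = next_prefill_chunk(list(range(25)), prefix_length=20)
assert ids == [20, 21, 22, 23, 24] and length == 5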
@@ -1864,8 +1869,8 @@ class FlashCausalLM(Model):
                 )
             ):
                 continue_prefilling = prefix_length + postfix_length < prompt_length
-                skip_tokens[r.id] = True
                 if continue_prefilling:
+                    skip_tokens[r.id] = True
                     # Update prefix length
                     prefix_length = prefix_length + postfix_length
                     batch.prefix_lengths[i] = prefix_length
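
Moving skip_tokens inside the conditional means a request is only flagged to skip token emission while its prompt is still being chunk-prefilled. A hedged sketch of that control flow follows; the ChunkState dataclass and its field names are stand-ins for the batch bookkeeping in the diff, not types from the repository.

# Minimal sketch, assuming simplified per-request state; `ChunkState` and
# `advance_prefill` are hypothetical names used for illustration only.
from dataclasses import dataclass

@dataclass
class ChunkState:
    id: int
    prefix_length: int
    postfix_length: int
    prompt_length: int

def advance_prefill(state: ChunkState, skip_tokens: dict) -> None:
    continue_prefilling = state.prefix_length + state.postfix_length < state.prompt_length
    if continue_prefilling:
        # Only requests that still have prompt left to prefill skip token emission.
        skip_tokens[state.id] = True
        state.prefix_length += state.postfix_length

skip = {}
advance_prefill(ChunkState(id=0, prefix_length=0, postfix_length=10, prompt_length=25), skip)
assert skip == {0: True}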
@@ -1980,11 +1985,11 @@ class FlashCausalLM(Model):
                 ADAPTER_TO_INDEX = get_adapter_to_index()
                 adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
                 adapter_indices_list.append(
-                    torch.full((postfix_length,), adapter_index)
+                    torch.full((next_chunk_length,), adapter_index)
                 )
 
                 # Update
-                cumulative_length += postfix_length
+                cumulative_length += next_chunk_length
                 cumulative_slot_tokens += len(request_slots)
 
         device = batch.input_ids.device
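
The rename from postfix_length to next_chunk_length indicates that the per-request adapter indices cover only the tokens of the upcoming prefill chunk. A small sketch of that tensor construction; the concrete values below are illustrative, not taken from the diff.

# Minimal sketch: one adapter index per token of the next prefill chunk,
# later concatenated across all requests in the batch. Values are illustrative.
import torch

adapter_index = 3
next_chunk_length = 10

adapter_indices = torch.full((next_chunk_length,), adapter_index)
assert adapter_indices.shape == (10,) and int(adapter_indices[0]) == 3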