commit 0e31619893
parent 962ccfd5b7
Author: OlivierDehaene
Date: 2024-09-30 11:03:13 +02:00

@@ -226,6 +226,7 @@ class FlashCausalLMBatch(Batch):
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         sliding_window = get_sliding_windows()
+        speculate = get_speculate()
         position_ids = []
         cu_seqlen_prefill = [0]
         start_slots = []
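For readers outside the codebase: `get_speculate()` returns the globally configured number of speculative-decoding tokens, fetched here so the chunking asserts below can check it. A minimal sketch of the accessor pattern this hunk assumes; the module layout and the `set_speculate` counterpart are assumptions, not shown in this diff:

```python
# Sketch of a process-wide speculate setting, assuming TGI's pattern of a
# global value set once at startup. Names here are illustrative.
SPECULATE: int = 0

def set_speculate(speculate: int) -> None:
    # Assumed to be called once at model load with the number of
    # speculative tokens (0 disables speculative decoding).
    global SPECULATE
    SPECULATE = speculate

def get_speculate() -> int:
    # Read by FlashCausalLMBatch.from_tokenized above.
    return SPECULATE
```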
@@ -280,17 +281,21 @@ class FlashCausalLMBatch(Batch):
             prompt_lengths.append(prompt_length)
 
             prefix_length = r.prefix_len
+            postfix_length = prefix_length + 10
             assert (
                 prefix_length <= prompt_length
             ), f"Prefix {prefix_length} vs input {prompt_length}"
             if prefix_length == prompt_length:
                 assert prefix_length > 0
                 prefix_length -= 1
+            if prefix_length + postfix_length < prompt_length:
+                # FIXME: speculate is not supported for context chunking at the moment
+                assert speculate == 0
 
             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
             prefix_ids.append(tokenized_input[:prefix_length])
-            postfix_ids = tokenized_input[prefix_length : prefix_length + 10]
+            postfix_ids = tokenized_input[prefix_length : postfix_length]
             # postfix_ids = tokenized_input[prefix_length:]
             postfix_length = len(postfix_ids)
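This hunk caps each prefill pass at 10 tokens past the cached prefix and forbids speculation while a prompt is still being chunked. A standalone sketch of that slicing logic under those assumptions; the function name and the `CHUNK_SIZE` constant are mine, not from the diff:

```python
from typing import List, Tuple

CHUNK_SIZE = 10  # mirrors the hard-coded "prefix_length + 10" in the hunk above

def next_prefill_chunk(
    tokenized_input: List[int], prefix_length: int, speculate: int
) -> Tuple[List[int], bool]:
    # prefix_length counts prompt tokens whose KV-cache entries already exist.
    prompt_length = len(tokenized_input)
    assert prefix_length <= prompt_length
    if prefix_length == prompt_length:
        # Back off one token so the forward pass still has at least one
        # input position to produce logits from.
        assert prefix_length > 0
        prefix_length -= 1
    end = prefix_length + CHUNK_SIZE
    more_chunks = end < prompt_length
    if more_chunks:
        # Same restriction as the FIXME above: no speculative decoding
        # while a prompt is still being prefilled in chunks.
        assert speculate == 0
    return tokenized_input[prefix_length:end], more_chunks
```

With a 25-token prompt and an empty cache, three successive calls yield chunks of 10, 10 and 5 tokens; only the last returns `more_chunks=False`.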
@@ -1864,8 +1869,8 @@ class FlashCausalLM(Model):
                 )
             ):
                 continue_prefilling = prefix_length + postfix_length < prompt_length
-                skip_tokens[r.id] = True
                 if continue_prefilling:
+                    skip_tokens[r.id] = True
                     # Update prefix length
                     prefix_length = prefix_length + postfix_length
                     batch.prefix_lengths[i] = prefix_length
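The fix here: `skip_tokens[r.id]` was previously set unconditionally, so a request that had just finished its last prefill chunk would also have its first real token suppressed. Moving the assignment inside the `if` keeps it to requests that still have chunks pending. A minimal sketch of the corrected bookkeeping (function name assumed):

```python
from typing import Tuple

def advance_prefill(
    prefix_length: int, postfix_length: int, prompt_length: int
) -> Tuple[int, bool]:
    # Returns (updated prefix length, whether sampling must be skipped).
    continue_prefilling = prefix_length + postfix_length < prompt_length
    if continue_prefilling:
        # The chunk just processed becomes part of the cached prefix; no
        # token is sampled until the whole prompt has been prefilled.
        prefix_length = prefix_length + postfix_length
    return prefix_length, continue_prefilling
```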
@@ -1980,11 +1985,11 @@ class FlashCausalLM(Model):
             ADAPTER_TO_INDEX = get_adapter_to_index()
             adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
             adapter_indices_list.append(
-                torch.full((postfix_length,), adapter_index)
+                torch.full((next_chunk_length,), adapter_index)
             )
 
             # Update
-            cumulative_length += postfix_length
+            cumulative_length += next_chunk_length
             cumulative_slot_tokens += len(request_slots)
 
         device = batch.input_ids.device
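The rename from `postfix_length` to `next_chunk_length` only changes which counter sizes the per-request adapter index tensor; the expansion itself works as sketched below. Each request contributes its adapter index repeated once per token in its next prefill chunk, so the concatenated tensor lines up position-for-position with the batch's flattened input ids. The sample data is invented:

```python
import torch

# Hypothetical (adapter_index, next_chunk_length) pairs for three requests.
requests = [(0, 10), (1, 7), (0, 5)]

adapter_indices_list = [
    torch.full((next_chunk_length,), adapter_index)
    for adapter_index, next_chunk_length in requests
]
adapter_indices = torch.cat(adapter_indices_list)

# cumulative_length in the hunk above plays the same role as this sum.
cumulative_length = sum(n for _, n in requests)
assert adapter_indices.shape[0] == cumulative_length  # 22 positions total
```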