mirror of https://github.com/huggingface/text-generation-inference.git
commit 0e31619893 (parent 962ccfd5b7)
@@ -226,6 +226,7 @@ class FlashCausalLMBatch(Batch):
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         sliding_window = get_sliding_windows()
+        speculate = get_speculate()
         position_ids = []
         cu_seqlen_prefill = [0]
         start_slots = []
@@ -280,17 +281,21 @@ class FlashCausalLMBatch(Batch):
             prompt_lengths.append(prompt_length)

             prefix_length = r.prefix_len
+            postfix_length = prefix_length + 10
             assert (
                 prefix_length <= prompt_length
             ), f"Prefix {prefix_length} vs input {prompt_length}"
             if prefix_length == prompt_length:
                 assert prefix_length > 0
                 prefix_length -= 1
+            if prefix_length + postfix_length < prompt_length:
+                # FIXME: speculate is not supported for context chunking at the moment
+                assert speculate == 0

             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
             prefix_ids.append(tokenized_input[:prefix_length])
-            postfix_ids = tokenized_input[prefix_length : prefix_length + 10]
+            postfix_ids = tokenized_input[prefix_length : postfix_length]
             # postfix_ids = tokenized_input[prefix_length:]

             postfix_length = len(postfix_ids)
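A minimal, standalone reading of the slicing in this hunk (illustrative only; the helper name and the explicit chunk-size argument are assumptions, not part of the commit): keep the cached prefix, then prefill at most a fixed-size chunk of the remaining prompt.

from typing import List, Tuple

def split_prompt(
    tokenized_input: List[int], prefix_length: int, chunk_size: int = 10
) -> Tuple[List[int], List[int]]:
    # Mirrors the hunk above: the cached prefix must stay strictly shorter than
    # the prompt, and the next chunk to prefill is at most `chunk_size` tokens.
    prompt_length = len(tokenized_input)
    assert prefix_length <= prompt_length, f"Prefix {prefix_length} vs input {prompt_length}"
    if prefix_length == prompt_length:
        # Never cache the whole prompt: leave at least one token to prefill.
        assert prefix_length > 0
        prefix_length -= 1
    prefix_ids = tokenized_input[:prefix_length]
    postfix_ids = tokenized_input[prefix_length : prefix_length + chunk_size]
    return prefix_ids, postfix_ids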
@@ -1864,8 +1869,8 @@ class FlashCausalLM(Model):
             )
         ):
             continue_prefilling = prefix_length + postfix_length < prompt_length
-            skip_tokens[r.id] = True
+            if continue_prefilling:
+                skip_tokens[r.id] = True
             # Update prefix length
             prefix_length = prefix_length + postfix_length
             batch.prefix_lengths[i] = prefix_length
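A sketch of the bookkeeping this hunk changes (hypothetical helper, not the repository's code): while the cached prefix plus the chunk just prefilled is still shorter than the full prompt, the request keeps prefilling, its output for this pass is skipped, and the prefix pointer advances by the chunk length.

def advance_prefill(prefix_length: int, postfix_length: int, prompt_length: int) -> tuple[int, bool]:
    # Returns (new prefix length, whether another prefill chunk is still needed).
    continue_prefilling = prefix_length + postfix_length < prompt_length
    prefix_length = prefix_length + postfix_length
    return prefix_length, continue_prefilling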
@@ -1980,11 +1985,11 @@ class FlashCausalLM(Model):
             ADAPTER_TO_INDEX = get_adapter_to_index()
             adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
             adapter_indices_list.append(
-                torch.full((postfix_length,), adapter_index)
+                torch.full((next_chunk_length,), adapter_index)
             )

             # Update
-            cumulative_length += postfix_length
+            cumulative_length += next_chunk_length
             cumulative_slot_tokens += len(request_slots)

             device = batch.input_ids.device
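For the adapter-index bookkeeping above, an illustrative standalone sketch (function name assumed, not the repository's code path): each request contributes one adapter index per token of its next chunk, and the running token offset advances by that chunk length.

import torch

def append_adapter_indices(adapter_indices_list, adapter_index, next_chunk_length, cumulative_length):
    # One adapter index per token of the chunk being prefilled for this request.
    adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index))
    # Advance the running offset past this chunk.
    return cumulative_length + next_chunk_length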