Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00

commit 0e31619893
parent 962ccfd5b7

    current
@@ -226,6 +226,7 @@ class FlashCausalLMBatch(Batch):
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         sliding_window = get_sliding_windows()
+        speculate = get_speculate()
         position_ids = []
         cu_seqlen_prefill = [0]
         start_slots = []
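
Sketch (illustration only, not part of the commit): the new `speculate = get_speculate()` read suggests the speculation depth lives in a module-level setting that is written once at startup and read during batch construction. A minimal sketch of that pattern, assuming it mirrors the helper names used above:

# Assumed pattern behind get_speculate(); names match the calls in the
# diff, but this standalone module is an illustration, not the real one.
SPECULATE = 0

def set_speculate(speculate: int) -> None:
    # Called once at model load time to record the speculation depth.
    global SPECULATE
    SPECULATE = speculate

def get_speculate() -> int:
    # Read during batch construction, e.g. to assert it is 0 when chunking.
    return SPECULATE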
@@ -280,17 +281,21 @@ class FlashCausalLMBatch(Batch):
             prompt_lengths.append(prompt_length)

             prefix_length = r.prefix_len
+            postfix_length = prefix_length + 10
             assert (
                 prefix_length <= prompt_length
             ), f"Prefix {prefix_length} vs input {prompt_length}"
             if prefix_length == prompt_length:
                 assert prefix_length > 0
                 prefix_length -= 1
+            if prefix_length + postfix_length < prompt_length:
+                # FIXME: speculate is not supported for context chunking at the moment
+                assert speculate == 0

             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
             prefix_ids.append(tokenized_input[:prefix_length])
-            postfix_ids = tokenized_input[prefix_length : prefix_length + 10]
+            postfix_ids = tokenized_input[prefix_length : postfix_length]
             # postfix_ids = tokenized_input[prefix_length:]

             postfix_length = len(postfix_ids)
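
Sketch (illustration only, not part of the commit): with the hard-coded chunk size of 10 in this hunk, `postfix_length = prefix_length + 10` is the end index of the next prompt chunk, and the slice clamps at the end of the prompt, so the final chunk can be shorter; `postfix_length = len(postfix_ids)` then records the actual chunk length. A runnable sketch of that arithmetic, with hypothetical names:

# Hypothetical standalone illustration of the chunking slice above;
# CHUNK_SIZE = 10 mirrors the hard-coded "+ 10" in this commit.
CHUNK_SIZE = 10

def next_chunk(tokenized_input: list, prefix_length: int) -> list:
    # End index of the chunk; slicing clamps at len(tokenized_input),
    # so the last chunk may hold fewer than CHUNK_SIZE tokens.
    end = prefix_length + CHUNK_SIZE
    return tokenized_input[prefix_length:end]

# Example: a 25-token prompt is prefilled in chunks of 10, 10, then 5.
prompt = list(range(25))
prefix = 0
while prefix < len(prompt):
    chunk = next_chunk(prompt, prefix)
    prefix += len(chunk)  # matches postfix_length = len(postfix_ids)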
@@ -1864,8 +1869,8 @@ class FlashCausalLM(Model):
                 )
             ):
                 continue_prefilling = prefix_length + postfix_length < prompt_length
-                skip_tokens[r.id] = True
                 if continue_prefilling:
+                    skip_tokens[r.id] = True
                     # Update prefix length
                     prefix_length = prefix_length + postfix_length
                     batch.prefix_lengths[i] = prefix_length
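
Sketch (illustration only, not part of the commit): moving `skip_tokens[r.id] = True` under `if continue_prefilling:` means a request is only flagged as producing no output token while prompt chunks remain to be prefilled. A hypothetical condensed view of that control flow, with assumed surrounding state:

# Hypothetical condensed view of the update after this hunk; the real
# method carries more batch state than shown here.
def update_after_chunk(skip_tokens: dict, request_id: int,
                       prefix_length: int, postfix_length: int,
                       prompt_length: int) -> int:
    continue_prefilling = prefix_length + postfix_length < prompt_length
    if continue_prefilling:
        skip_tokens[request_id] = True   # still mid-prefill: no token yet
        prefix_length += postfix_length  # advance past the chunk just run
    return prefix_length

Once the final chunk fits within the prompt, `continue_prefilling` is false, no skip flag is set, and the request emits its first token.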
@@ -1980,11 +1985,11 @@ class FlashCausalLM(Model):
             ADAPTER_TO_INDEX = get_adapter_to_index()
             adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
             adapter_indices_list.append(
-                torch.full((postfix_length,), adapter_index)
+                torch.full((next_chunk_length,), adapter_index)
             )

             # Update
-            cumulative_length += postfix_length
+            cumulative_length += next_chunk_length
             cumulative_slot_tokens += len(request_slots)

             device = batch.input_ids.device
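
Sketch (illustration only, not part of the commit): the renamed `next_chunk_length` feeds the same per-token adapter bookkeeping as before: each request contributes its adapter index repeated once per token in its next chunk, and the segments are later concatenated for the whole batch. A self-contained sketch with made-up adapter indices and chunk lengths:

# Made-up (adapter_index, next_chunk_length) pairs; only the torch.full /
# torch.cat pattern mirrors the code above.
import torch

adapter_indices_list = []
cumulative_length = 0
for adapter_index, next_chunk_length in [(0, 10), (2, 10), (0, 5)]:
    # One index per token in this request's next chunk.
    adapter_indices_list.append(
        torch.full((next_chunk_length,), adapter_index)
    )
    cumulative_length += next_chunk_length

adapter_indices = torch.cat(adapter_indices_list)  # shape: (25,)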