Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

commit 962ccfd5b7 (parent a85f5ebecd)

    wip, no filter, no concat
@@ -64,6 +64,8 @@ tracer = trace.get_tracer(__name__)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None

TOKEN_BUDGET = 8


def set_sliding_window(sliding_window: int):
global SLIDING_WINDOW
@@ -144,12 +146,14 @@ class FlashCausalLMBatch(Batch):
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor

max_seqlen: int
max_postfix_length: int
max_current_length: int

# Prefill metadata tensors to efficiently compute logprobs
prefill_head_indices: Optional[torch.Tensor]
prefill_next_token_indices: Optional[torch.tensor]
prefill_cu_outlens: Optional[List[int]]
prefill_tokens: List[Optional[Tokens]]

# Prefixes
prefix_ids: List[List[int]]
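This hunk replaces the single max_seqlen field with two counters: the longest postfix (the chunk of tokens fed in this forward pass) and the longest current length (prefix already in the KV cache plus that postfix). A minimal sketch of the bookkeeping under those names, using invented request data and nothing beyond plain Python:

# Minimal sketch of the prefix/postfix bookkeeping introduced in this hunk.
# Field names mirror the diff; the example requests are made up.

requests = [
    # (prefix_length, postfix_length, prompt_length)
    (0, 10, 35),   # fresh request: nothing cached yet, 10 prompt tokens in this chunk
    (20, 10, 30),  # this chunk finishes the prompt (20 cached + 10 new == 30)
    (30, 1, 30),   # already past prefill: decoding one token at a time
]

max_postfix_length = 0  # longest chunk fed to the model in this forward pass
max_current_length = 0  # longest "prefix + postfix" visible in the KV cache

for prefix_length, postfix_length, prompt_length in requests:
    max_postfix_length = max(max_postfix_length, postfix_length)
    max_current_length = max(max_current_length, prefix_length + postfix_length)

print(max_postfix_length)  # 10
print(max_current_length)  # 31 -> what the diff later passes to the kernel as max_s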
@@ -257,7 +261,8 @@ class FlashCausalLMBatch(Batch):
prefill_out_cumulative_length = 0

num_blocks = 0
max_seqlen = 0
max_postfix_length = 0
max_current_length = 0
max_length = 0
max_blocks = 0

@@ -285,20 +290,21 @@ class FlashCausalLMBatch(Batch):
# Commented as it's costly.
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
prefix_ids.append(tokenized_input[:prefix_length])
postfix_ids = tokenized_input[prefix_length:]
postfix_ids = tokenized_input[prefix_length : prefix_length + 10]
# postfix_ids = tokenized_input[prefix_length:]

postfix_length = len(postfix_ids)
postfix_lengths.append(postfix_length)

prefix_offsets.append(postfix_length - 5)
read_offsets.append(postfix_length)
prefix_offsets.append(prompt_length - 5)
read_offsets.append(prompt_length)

all_postfix_ids.append(postfix_ids)
all_input_ids.append(tokenized_input)

# Position ids
request_position_ids = torch.arange(
prefix_length, prompt_length, dtype=torch.int32
prefix_length, prefix_length + postfix_length, dtype=torch.int32
)
position_ids.append(request_position_ids)
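In this WIP hunk the postfix is hard-capped at 10 tokens (prefix_length : prefix_length + 10), so only the first chunk of each prompt is prefilled when the batch is built and the rest is picked up by later passes. A small sketch of that slicing, with an invented token list and the chunk size pulled out as a constant:

# Sketch of the chunked slicing done when a batch is first built.
# `tokenized_input` is invented; 10 is the hard-coded WIP chunk size.

CHUNK_SIZE = 10

tokenized_input = list(range(100, 127))  # 27 prompt token ids
prefix_length = 0                        # nothing is cached for a brand new request

prefix_ids = tokenized_input[:prefix_length]
postfix_ids = tokenized_input[prefix_length : prefix_length + CHUNK_SIZE]
postfix_length = len(postfix_ids)

# Position ids cover only the slice being prefilled in this pass
position_ids = list(range(prefix_length, prefix_length + postfix_length))

print(postfix_ids)   # the first 10 prompt ids
print(position_ids)  # [0, 1, ..., 9]
# The remaining 17 tokens stay in `tokenized_input` and are prefilled later.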
@@ -396,11 +402,12 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_length += postfix_length
cumulative_slot_tokens += slot_tokens
max_seqlen = max(max_seqlen, postfix_length)
max_blocks = max(max_blocks, len(request_blocks))
max_postfix_length = max(max_postfix_length, postfix_length)
max_current_length = max(max_current_length, prefix_length + postfix_length)
max_length = max(
max_length,
prefix_length + postfix_length + max_new_tokens + speculative_length,
prompt_length + max_new_tokens + speculative_length,
)

adapter_indices = torch.cat(adapter_indices_list).to(
@@ -502,10 +509,12 @@ class FlashCausalLMBatch(Batch):
slots=slots,
prefix_lengths=prefix_lengths,
prefix_lengths_tensor=prefix_lengths_tensor,
max_seqlen=max_seqlen,
max_postfix_length=max_postfix_length,
max_current_length=max_current_length,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens,
prefill_tokens=[None] * len(pb.requests),
postfix_lengths=postfix_lengths,
postfix_lengths_tensor=postfix_lengths_tensor,
prompt_lengths=prompt_lengths,
@@ -565,7 +574,8 @@ class FlashCausalLMBatch(Batch):

# Create on CPU to only move to GPU once instead of at every copy
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
max_seqlen = 0
max_postfix_length = 0
max_current_length = 0

requests = []
start_slots = []
@@ -579,6 +589,7 @@ class FlashCausalLMBatch(Batch):
prefix_offsets = []
read_offsets = []

prefill_tokens = []

stopping_criterias = []
top_n_tokens = []
@@ -598,15 +609,18 @@ class FlashCausalLMBatch(Batch):

# Get length
request_postfix_length = self.postfix_lengths[idx]
prefix_length = self.prefix_lengths[idx]
max_seqlen = max(max_seqlen, request_postfix_length)
request_prefix_length = self.prefix_lengths[idx]
max_postfix_length = max(max_postfix_length, request_postfix_length)
max_current_length = max(
max_current_length, request_prefix_length + request_postfix_length
)

all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx])

prompt_lengths.append(self.prompt_lengths[idx])
postfix_lengths.append(request_postfix_length)
prefix_lengths.append(prefix_length)
prefix_lengths.append(request_prefix_length)
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])

@@ -614,6 +628,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria)

top_n_tokens.append(self.top_n_tokens[idx])
prefill_tokens.append(self.prefill_tokens[idx])

ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
@@ -683,10 +698,12 @@ class FlashCausalLMBatch(Batch):
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
max_seqlen=max_seqlen,
max_postfix_length=max_postfix_length,
max_current_length=max_current_length,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
prefill_tokens=prefill_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
postfix_lengths=postfix_lengths,
@@ -725,7 +742,8 @@ class FlashCausalLMBatch(Batch):
total_slots = 0
max_blocks = 0
max_length = 0
max_seqlen = 0
max_postfix_length = 0
max_current_length = 0
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
@@ -734,7 +752,8 @@ class FlashCausalLMBatch(Batch):
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
)
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_postfix_length = max(max_postfix_length, b.max_postfix_length)
max_current_length = max(max_current_length, b.max_current_length)
max_length = max(
max_length,
max(
@@ -791,6 +810,8 @@ class FlashCausalLMBatch(Batch):
prefix_offsets = []
read_offsets = []

prefill_tokens = []

next_token_chooser_parameters = []
fsm_grammar_states = []
stopping_criterias = []
@@ -862,6 +883,8 @@ class FlashCausalLMBatch(Batch):
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)

prefill_tokens.extend(batch.prefill_tokens)

next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
stopping_criterias.extend(batch.stopping_criterias)
@@ -907,10 +930,12 @@ class FlashCausalLMBatch(Batch):
prefix_lengths=prefix_lengths,
prefix_lengths_tensor=prefix_lengths_tensor,
slots=slots,
max_seqlen=max_seqlen,
max_postfix_length=max_postfix_length,
max_current_length=max_current_length,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
prefill_tokens=prefill_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
postfix_lengths=postfix_lengths,
@@ -1416,7 +1441,7 @@ class FlashCausalLM(Model):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
postfix_lengths = batch.postfix_lengths_tensor
max_s = batch.max_seqlen
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices

speculative_ids = batch.speculative_ids
@@ -1459,7 +1484,7 @@ class FlashCausalLM(Model):
slots = batch.slots[batch.slot_indices]
postfix_lengths = batch.postfix_lengths_tensor
prefix_lengths_tensor = batch.prefix_lengths_tensor
max_s = batch.max_seqlen
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices

if cu_seqlen_prefill is None and self.max_past() is not None:
@@ -1608,15 +1633,47 @@ class FlashCausalLM(Model):
if prefill_logprobs
else speculative_logits
)
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
len(batch)
)

if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
else:
prefill_logprobs = None
next_token_logits = out
next_adapter_indices = batch.adapter_meta.adapter_indices

finished_prefilling = True
next_chunk_lengths = []
if prefill:
# Budget in tokens for the next batch
# We remove next input ids to always have enough space for at least a single decode
# for the remaining requests
batch_budget = TOKEN_BUDGET - len(batch)
for prefix_length, postfix_length, prompt_length in zip(
batch.prefix_lengths, batch.postfix_lengths, batch.prompt_lengths
):
remaining_prefill_tokens = max(
prompt_length - prefix_length - postfix_length, 0
)
if remaining_prefill_tokens > 0:
next_chunk_length = max(
min(remaining_prefill_tokens, batch_budget), 1
)
batch_budget -= next_chunk_length
finished_prefilling = False
else:
# Since speculation will be turned off, this is always true
next_chunk_length = 1
next_chunk_lengths.append(next_chunk_length)

# Turn off speculative if some requests are still prefilling
# It makes the logic easier to follow
if prefill and not finished_prefilling:
speculate = 0
speculative_logits = None
else:
speculate = get_speculate()

(
next_input_ids,
next_token_logprobs,
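The new next_chunk_lengths loop decides how many prompt tokens each request may prefill in the next forward pass: requests with prompt left share TOKEN_BUDGET - len(batch) tokens (reserving one slot per request for a decode), while finished requests get exactly one token. A standalone sketch of the same greedy, in-batch-order scheduling; the request tuples are invented:

# Standalone sketch of the chunk-length scheduling added in this hunk.
TOKEN_BUDGET = 8  # same WIP constant as in the diff

def plan_next_chunks(requests):
    """requests: list of (prefix_length, postfix_length, prompt_length)."""
    finished_prefilling = True
    next_chunk_lengths = []
    # Keep space for at least one decode token per request
    batch_budget = TOKEN_BUDGET - len(requests)
    for prefix_length, postfix_length, prompt_length in requests:
        remaining = max(prompt_length - prefix_length - postfix_length, 0)
        if remaining > 0:
            # Still prefilling: take as much budget as possible, but at least 1 token
            next_chunk_length = max(min(remaining, batch_budget), 1)
            batch_budget -= next_chunk_length
            finished_prefilling = False
        else:
            # Done prefilling: only the single sampled token is fed next
            next_chunk_length = 1
        next_chunk_lengths.append(next_chunk_length)
    return next_chunk_lengths, finished_prefilling

chunks, done = plan_next_chunks([(10, 10, 40), (0, 10, 12), (30, 1, 30)])
print(chunks, done)  # [5, 1, 1] False -> the first request gets the leftover budget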
@@ -1624,7 +1681,7 @@ class FlashCausalLM(Model):
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : max(batch.postfix_lengths)],
batch.all_input_ids_tensor[:, : batch.max_current_length],
next_token_logits,
speculate,
batch.speculative_ids,
@@ -1635,18 +1692,15 @@ class FlashCausalLM(Model):
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
)

if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

# Since we are done prefilling, all the tensors that were concatenating values for all the requests
# instantly become of shape [BATCH_SIZE]
if prefill and finished_prefilling:
next_position_ids = batch.position_ids.new_empty(len(batch))
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
# We do not need cu_seqlen_prefill anymore
batch.cu_seqlen_prefill = None
else:
prefill_logprobs = None
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
len(batch)
)
elif not prefill:
next_position_ids = batch.position_ids

# Cumulative length
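When every request finished its prefill in this pass, the per-token tensors collapse to per-request tensors: cu_seqlen_prefill[1:] - 1 points at the last token of each request in the concatenated layout, so indexing slot_indices with it keeps one slot per request. A small sketch of that gather with invented numbers; torch is the only dependency:

import torch

# Sketch of the "instantly become of shape [BATCH_SIZE]" collapse after prefill.
# Three requests with postfix lengths 4, 2 and 3 (invented numbers).
cu_seqlen_prefill = torch.tensor([0, 4, 6, 9], dtype=torch.int32)

# One slot index per input token, concatenated across the batch
slot_indices = torch.arange(100, 109, dtype=torch.int64)

# Index of the last token of each request in the concatenated layout
last_token_indices = (cu_seqlen_prefill[1:] - 1).long()  # tensor([3, 5, 8])

# Keep only the slot of each request's last token -> shape [BATCH_SIZE]
slot_indices = slot_indices[last_token_indices]
print(slot_indices)  # tensor([103, 105, 108])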
@@ -1658,6 +1712,7 @@ class FlashCausalLM(Model):

# Zipped iterator
iterator = zip(
batch.prompt_lengths,
batch.prefix_lengths,
batch.postfix_lengths,
batch.all_input_ids,
@@ -1671,6 +1726,7 @@ class FlashCausalLM(Model):
# For each member of the batch
index = 0
for i, (
prompt_length,
prefix_length,
postfix_length,
all_input_ids,
@@ -1686,6 +1742,7 @@ class FlashCausalLM(Model):
out_end_index = batch.prefill_cu_outlens[i + 1]
out_length = out_end_index - out_start_index

if finished_prefilling:
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
@@ -1709,15 +1766,23 @@ class FlashCausalLM(Model):
start_index + 1 : start_index + out_length
]

# Represent whether this request is still prefilling
# If it is, the tokens we decoded should be ignored
accept_tokens = prefix_length + postfix_length >= prompt_length

if accept_tokens:
# Only save tokens if we are done prefilling for this request
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, prefix_length + postfix_length + j] = (
next_input_ids[index]
)
batch.all_input_ids_tensor[
i, prefix_length + postfix_length + j
] = next_input_ids[index]
index += 1

cumulative_length += postfix_length

# Update values
# These values can be updated without a GPU -> CPU sync
if not prefill or (prefill and finished_prefilling):
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
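Sampled tokens are only written back into all_input_ids_tensor for requests whose whole prompt has been seen (prefix_length + postfix_length >= prompt_length); tokens "decoded" for a request still mid-prefill are ignored. A minimal sketch of that guard with invented sizes (the real loop also walks accepted speculative tokens):

import torch

# Sketch of the accept_tokens guard added in this hunk (invented sizes).
prompt_lengths = [30, 12]
prefix_lengths = [10, 8]
postfix_lengths = [10, 4]            # request 0 is mid-prefill, request 1 just finished
n_accepted_ids_per_request = [1, 1]  # speculation is off while chunking, so 1 each

all_input_ids_tensor = torch.zeros((2, 40), dtype=torch.int64)
next_input_ids = torch.tensor([111, 222])  # one sampled token per request

index = 0
for i in range(2):
    accept_tokens = prefix_lengths[i] + postfix_lengths[i] >= prompt_lengths[i]
    n_accepted_ids = n_accepted_ids_per_request[i]
    if accept_tokens:
        # Only persist tokens once this request is done prefilling
        for j in range(n_accepted_ids):
            all_input_ids_tensor[i, prefix_lengths[i] + postfix_lengths[i] + j] = (
                next_input_ids[index + j]
            )
    index += n_accepted_ids

print(all_input_ids_tensor[0].nonzero().numel())  # 0   -> request 0's token was dropped
print(all_input_ids_tensor[1, 12].item())         # 222 -> stored right after the prompt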
@@ -1725,15 +1790,6 @@ class FlashCausalLM(Model):
batch.slot_indices += accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices

if prefill:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)

if prefill and prefill_logprobs:
# Get prefill logprobs
prefill_logprobs_tensor = torch.log_softmax(out, -1)
@@ -1743,15 +1799,265 @@ class FlashCausalLM(Model):
# GPU <-> CPU sync
prefill_logprobs = prefill_logprobs.view(-1).tolist()

# Does a GPU <-> CPU sync internally
if prefill and finished_prefilling:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)

# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = next_input_ids.tolist()
accepted_ids = accepted_ids.tolist()

# Update values if we need to continue prefilling
# This represents the `else` case of the `Update values` if above
# but since this require the `next_token_ids` to be on CPU, it is better to do it here
skip_tokens = {}
if prefill and not finished_prefilling:
# Speculation must be ignored while we prefill even with chunking
# it simplifies everything
assert batch.speculative_ids is None

all_postfix_ids = []
sliding_window = get_sliding_windows()
position_ids = []
cu_seqlen_prefill = [0]
start_slots = []
slot_indices = []
prefill_cache_indices = []
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]

# Cumulative length
cumulative_length = 0
cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0

slots = []
adapter_indices_list = []

for i, (
r,
next_token_id,
all_input_ids,
prefix_length,
postfix_length,
prompt_length,
next_chunk_length,
) in enumerate(
zip(
batch.requests,
next_token_ids,
batch.all_input_ids,
batch.prefix_lengths,
batch.postfix_lengths,
batch.prompt_lengths,
next_chunk_lengths,
)
):
continue_prefilling = prefix_length + postfix_length < prompt_length
skip_tokens[r.id] = True
if continue_prefilling:
# Update prefix length
prefix_length = prefix_length + postfix_length
batch.prefix_lengths[i] = prefix_length

# Update postfix length
postfix_length = next_chunk_length
batch.max_postfix_length = max(
batch.max_postfix_length, postfix_length
)
batch.postfix_lengths[i] = postfix_length

# Potentially update max_current_length
current_length = prefix_length + postfix_length
batch.max_current_length = max(
batch.max_current_length, current_length
)

# Get new prompt IDs to prefill
postfix_ids = all_input_ids[
prefix_length : prefix_length + postfix_length
]

# Position ids
request_position_ids = torch.arange(
prefix_length, prefix_length + postfix_length, dtype=torch.int32
)
position_ids.append(request_position_ids)

# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + postfix_length)

request_slots = r.slots[prefix_length:]
request_slot_indices = torch.arange(
cumulative_slot_tokens,
cumulative_slot_tokens + postfix_length,
dtype=torch.int64,
)

# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, postfix_length - sliding_window),
cumulative_length + postfix_length,
dtype=torch.int64,
)

all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

if r.prefill_logprobs:
prefill_head_indices.append(
request_position_ids + cumulative_length
)
prefill_next_token_indices.append(
prefill_out_cumulative_length + postfix_length - 1
)
prefill_cu_outlens.append(
prefill_out_cumulative_length + postfix_length
)
prefill_out_cumulative_length += postfix_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + postfix_length - 1],
dtype=torch.int32,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1

else:
# This request is done prefilling, the new id is the one selected the sampling method
postfix_ids = [next_token_id]

# Position_ids
position_ids.append(
torch.tensor(
(prefix_length + postfix_length,), dtype=torch.int32
)
)

# Add this request token
cu_seqlen_prefill.append(cumulative_length + 1)

request_slots = r.slots[prefix_length:]
request_slot_indices = torch.tensor(
(cumulative_slot_tokens + postfix_length,), dtype=torch.int64
)

# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.tensor(
[cumulative_length], dtype=torch.int64
)

prefill_head_indices.append(
torch.tensor([cumulative_length], dtype=torch.int32)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1

all_postfix_ids.extend(postfix_ids)
start_slots.append(cumulative_slot_tokens)
slots.extend(request_slots)
slot_indices.append(request_slot_indices)

if sliding_window is not None:
prefill_cache_indices.append(request_prefill_cache_indices)

ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(
torch.full((postfix_length,), adapter_index)
)

# Update
cumulative_length += postfix_length
cumulative_slot_tokens += len(request_slots)

device = batch.input_ids.device
batch.start_slots = torch.tensor(start_slots, dtype=torch.int64)

if len(batch) > 1:
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
position_ids = position_ids[0]
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]

cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
batch.cu_seqlen_prefill = cu_seqlen_prefill
batch.position_ids = position_ids.to(device)
batch.slot_indices = slot_indices.to(device)
batch.prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
batch.input_ids = torch.tensor(
all_postfix_ids, dtype=torch.int64, device=device
)
batch.postfix_lengths_tensor = torch.tensor(
batch.postfix_lengths, dtype=torch.int32, device=device
)

if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)

batch.prefill_head_indices = prefill_head_indices
batch.prefill_next_token_indices = prefill_next_token_indices
batch.slots = torch.tensor(slots, dtype=torch.int64, device=device)
batch.prefix_lengths_tensor = torch.tensor(
batch.prefix_lengths, dtype=torch.int32, device=device
)
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
batch.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=batch.adapter_meta.adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)

start_decode = time.time_ns()

# Zipped iterator
iterator = zip(
batch.requests,
batch.prompt_lengths,
batch.prefix_lengths,
batch.postfix_lengths,
batch.prefix_offsets,
batch.read_offsets,
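For the requests that keep prefilling, the hunk above rebuilds the batch inputs on CPU lists and moves them to the GPU once: each request contributes either its next prompt chunk or just its sampled token, and cu_seqlen_prefill and position_ids are re-derived from the new postfix lengths. A condensed sketch of that rebuild for a two-request batch (one still prefilling, one decoding); the data is invented and the real code also rebuilds slots, adapter metadata and logprob indices:

import torch

# Condensed sketch of rebuilding the inputs for the next chunked-prefill pass.
# Request 0 still has prompt left (next chunk of 3 ids); request 1 only feeds
# back its sampled token. Everything here is invented example data.
requests = [
    {"all_input_ids": list(range(200, 220)), "prefix_length": 10,
     "next_chunk_length": 3, "continue_prefilling": True, "next_token_id": None},
    {"all_input_ids": list(range(300, 312)), "prefix_length": 12,
     "next_chunk_length": 1, "continue_prefilling": False, "next_token_id": 999},
]

all_postfix_ids, position_ids, cu_seqlen_prefill = [], [], [0]
cumulative_length = 0

for r in requests:
    prefix_length = r["prefix_length"]
    if r["continue_prefilling"]:
        postfix_length = r["next_chunk_length"]
        postfix_ids = r["all_input_ids"][prefix_length : prefix_length + postfix_length]
        position_ids.append(torch.arange(prefix_length, prefix_length + postfix_length))
    else:
        postfix_length = 1
        postfix_ids = [r["next_token_id"]]  # the freshly sampled token
        position_ids.append(torch.tensor([prefix_length]))
    all_postfix_ids.extend(postfix_ids)
    cu_seqlen_prefill.append(cumulative_length + postfix_length)
    cumulative_length += postfix_length

input_ids = torch.tensor(all_postfix_ids, dtype=torch.int64)
position_ids = torch.cat(position_ids)
cu_seqlen_prefill = torch.tensor(cu_seqlen_prefill, dtype=torch.int32)

print(input_ids)          # tensor([210, 211, 212, 999])
print(position_ids)       # tensor([10, 11, 12, 12])
print(cu_seqlen_prefill)  # tensor([0, 3, 4], dtype=torch.int32)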
@@ -1770,7 +2076,9 @@ class FlashCausalLM(Model):
index = 0
for i, (
request,
input_length,
prompt_length,
prefix_length,
postfix_length,
prefix_offset,
read_offset,
stopping_criteria,
@@ -1783,6 +2091,61 @@ class FlashCausalLM(Model):
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Compute logprobs first as, even though we might skip the token,
# it can still be required to compute the logprobs
# modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
# this state to be stable
if request.id % self.world_size == self.rank:
# Prefill
if prefill and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]

request_prefill_tokens = batch.prefill_tokens[i]

request_prefill_logprobs = prefill_logprobs[
out_start_index : out_end_index - 1
]
prefill_token_ids = all_input_ids[:-1]

if request_prefill_tokens is None:
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] * (
len(prefix_ids) + 1
) + request_prefill_logprobs
prefill_token_ids = prefix_ids + prefill_token_ids

prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)

prefill_tokens = Tokens(
prefix_ids + prefill_token_ids,
request_prefill_logprobs,
prefill_texts,
is_special=[],
)
if request_prefill_tokens is not None:
prefill_tokens = request_prefill_tokens + prefill_tokens

batch.prefill_tokens[i] = prefill_tokens
else:
batch.prefill_tokens[i] = None

# Represent whether this request is still prefilling
# If it is, the tokens we decoded should be ignored
skip_token = skip_tokens.get(request.id, False)

if skip_token:
# Make sure that we do not stop as even though this request did not create a token, it is still
# processing
stopped = False
# Skip the rest of the decoding
# Values were updated before this for loop
continue

# Append next token to all tokens
next_token_texts = []
left = 0
@@ -1823,7 +2186,7 @@ class FlashCausalLM(Model):

# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if request.id % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
@@ -1844,31 +2207,6 @@ class FlashCausalLM(Model):
else:
generated_text = None

# Prefill
if prefill and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]

# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = (
[float("nan")] * (len(prefix_ids) + 1)
) + prefill_logprobs[out_start_index : out_end_index - 1]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefix_ids + prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)

prefill_tokens = Tokens(
prefix_ids + prefill_token_ids,
request_prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None

if top_n_tokens > 0:
all_top_tokens = []
for top_token_ids, top_token_logprobs in zip(
@@ -1896,7 +2234,7 @@ class FlashCausalLM(Model):

generation = Generation(
request.id,
prefill_tokens,
batch.prefill_tokens[i],
Tokens(
_next_token_ids,
_next_token_logprobs,
@@ -1917,9 +2255,13 @@ class FlashCausalLM(Model):
)

# Update values
batch.postfix_lengths[i] = input_length + n_accepted_ids
if batch.postfix_lengths[i] > batch.max_seqlen:
batch.max_seqlen = batch.postfix_lengths[i]
current_postfix_length = postfix_length + n_accepted_ids
batch.max_postfix_length = max(
batch.max_postfix_length, current_postfix_length
)
batch.postfix_lengths[i] = current_postfix_length
current_length = prefix_length + current_postfix_length
batch.max_current_length = max(batch.max_current_length, current_length)
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
@@ -1930,6 +2272,10 @@ class FlashCausalLM(Model):
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)

if prefill and finished_prefilling:
# We do not need prefill tensors anymore
batch.cu_seqlen_prefill = None
batch.prefill_cache_indices = None
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None

@@ -74,6 +74,14 @@ class Tokens:
def __len__(self):
return len(self.token_ids)

def __add__(self, other: "Tokens") -> "Tokens":
return Tokens(
self.token_ids + other.token_ids,
self.logprobs + other.logprobs,
self.texts + other.texts,
self.is_special + other.is_special,
)


@dataclass
class Generation:
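
The second file in the diff gives Tokens an __add__ so partial prefill-token chunks accumulated across passes can be concatenated (request_prefill_tokens + prefill_tokens above). A self-contained sketch of a Tokens-like dataclass with the same operator; the real class in the repo has more fields and a protobuf conversion:

from dataclasses import dataclass
from typing import List

# Simplified stand-in for the Tokens class touched by the diff; only the
# fields used by __add__ are modeled here.
@dataclass
class Tokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def __len__(self):
        return len(self.token_ids)

    def __add__(self, other: "Tokens") -> "Tokens":
        # Concatenate field-wise, mirroring the method added in the diff
        return Tokens(
            self.token_ids + other.token_ids,
            self.logprobs + other.logprobs,
            self.texts + other.texts,
            self.is_special + other.is_special,
        )

# Prefill tokens gathered over two chunked-prefill passes (invented values)
first_chunk = Tokens([1, 2], [float("nan"), -0.5], ["<s>", "Hello"], [True, False])
second_chunk = Tokens([3], [-0.1], [" world"], [False])

merged = first_chunk + second_chunk
print(len(merged), merged.token_ids)  # 3 [1, 2, 3]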