wip, no filter, no concat

OlivierDehaene 2024-09-26 17:10:00 +02:00
parent a85f5ebecd
commit 962ccfd5b7
2 changed files with 452 additions and 98 deletions

View File

@@ -64,6 +64,8 @@ tracer = trace.get_tracer(__name__)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
TOKEN_BUDGET = 8
def set_sliding_window(sliding_window: int):
global SLIDING_WINDOW
@@ -144,12 +146,14 @@ class FlashCausalLMBatch(Batch):
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor
max_seqlen: int
max_postfix_length: int
max_current_length: int
# Prefill metadata tensors to efficiently compute logprobs
prefill_head_indices: Optional[torch.Tensor]
prefill_next_token_indices: Optional[torch.tensor]
prefill_cu_outlens: Optional[List[int]]
prefill_tokens: List[Optional[Tokens]]
# Prefixes
prefix_ids: List[List[int]]
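Note: with chunked prefill the single `max_seqlen` is split in two: `max_postfix_length` is the largest number of new tokens any request feeds into this forward pass, while `max_current_length` is the largest `prefix_length + postfix_length`, i.e. how far any sequence currently extends into the KV cache. A minimal sketch of that bookkeeping with made-up request sizes (none of these numbers come from the commit):

# Illustrative only: three requests mid-way through chunked prefill.
# prompt_length  = total tokens in the prompt
# prefix_length  = tokens already written to the KV cache (previous chunks / prefix cache)
# postfix_length = tokens processed by the current forward pass
requests = [
    # (prompt_length, prefix_length, postfix_length)
    (30, 20, 8),
    (12, 0, 10),
    (5, 5, 1),  # done prefilling: a single decode slot
]
max_postfix_length = max(postfix for _, _, postfix in requests)                # 10
max_current_length = max(prefix + postfix for _, prefix, postfix in requests)  # 28
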
@@ -257,7 +261,8 @@ class FlashCausalLMBatch(Batch):
prefill_out_cumulative_length = 0
num_blocks = 0
max_seqlen = 0
max_postfix_length = 0
max_current_length = 0
max_length = 0
max_blocks = 0
@@ -285,20 +290,21 @@ class FlashCausalLMBatch(Batch):
# Commented as it's costly.
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
prefix_ids.append(tokenized_input[:prefix_length])
postfix_ids = tokenized_input[prefix_length:]
postfix_ids = tokenized_input[prefix_length : prefix_length + 10]
# postfix_ids = tokenized_input[prefix_length:]
postfix_length = len(postfix_ids)
postfix_lengths.append(postfix_length)
prefix_offsets.append(postfix_length - 5)
read_offsets.append(postfix_length)
prefix_offsets.append(prompt_length - 5)
read_offsets.append(prompt_length)
all_postfix_ids.append(postfix_ids)
all_input_ids.append(tokenized_input)
# Position ids
request_position_ids = torch.arange(
prefix_length, prompt_length, dtype=torch.int32
prefix_length, prefix_length + postfix_length, dtype=torch.int32
)
position_ids.append(request_position_ids)
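Note: only a slice of the tokenized prompt is scheduled as the first prefill chunk here (the hard-coded `10` is a WIP placeholder), and the position ids for that chunk are absolute, so they start at `prefix_length` instead of 0. A small sketch under assumed values, not taken from the commit:

import torch

tokenized_input = list(range(100, 125))  # hypothetical 25-token prompt
prefix_length = 4                        # tokens already cached
chunk_size = 10                          # hard-coded in this WIP commit

prefix_ids = tokenized_input[:prefix_length]
postfix_ids = tokenized_input[prefix_length : prefix_length + chunk_size]
postfix_length = len(postfix_ids)        # 10

# Positions continue after the cached prefix
request_position_ids = torch.arange(
    prefix_length, prefix_length + postfix_length, dtype=torch.int32
)
assert request_position_ids.tolist() == list(range(4, 14))
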
@@ -396,11 +402,12 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_length += postfix_length
cumulative_slot_tokens += slot_tokens
max_seqlen = max(max_seqlen, postfix_length)
max_blocks = max(max_blocks, len(request_blocks))
max_postfix_length = max(max_postfix_length, postfix_length)
max_current_length = max(max_current_length, prefix_length + postfix_length)
max_length = max(
max_length,
prefix_length + postfix_length + max_new_tokens + speculative_length,
prompt_length + max_new_tokens + speculative_length,
)
adapter_indices = torch.cat(adapter_indices_list).to(
@@ -502,10 +509,12 @@ class FlashCausalLMBatch(Batch):
slots=slots,
prefix_lengths=prefix_lengths,
prefix_lengths_tensor=prefix_lengths_tensor,
max_seqlen=max_seqlen,
max_postfix_length=max_postfix_length,
max_current_length=max_current_length,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens,
prefill_tokens=[None] * len(pb.requests),
postfix_lengths=postfix_lengths,
postfix_lengths_tensor=postfix_lengths_tensor,
prompt_lengths=prompt_lengths,
@@ -565,7 +574,8 @@ class FlashCausalLMBatch(Batch):
# Create on CPU to only move to GPU once instead of at every copy
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
max_seqlen = 0
max_postfix_length = 0
max_current_length = 0
requests = []
start_slots = []
@@ -579,6 +589,7 @@ class FlashCausalLMBatch(Batch):
prefix_offsets = []
read_offsets = []
prefill_tokens = []
stopping_criterias = []
top_n_tokens = []
@@ -598,15 +609,18 @@ class FlashCausalLMBatch(Batch):
# Get length
request_postfix_length = self.postfix_lengths[idx]
prefix_length = self.prefix_lengths[idx]
max_seqlen = max(max_seqlen, request_postfix_length)
request_prefix_length = self.prefix_lengths[idx]
max_postfix_length = max(max_postfix_length, request_postfix_length)
max_current_length = max(
max_current_length, request_prefix_length + request_postfix_length
)
all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx])
prompt_lengths.append(self.prompt_lengths[idx])
postfix_lengths.append(request_postfix_length)
prefix_lengths.append(prefix_length)
prefix_lengths.append(request_prefix_length)
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
@@ -614,6 +628,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
prefill_tokens.append(self.prefill_tokens[idx])
ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
@@ -683,10 +698,12 @@ class FlashCausalLMBatch(Batch):
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
max_seqlen=max_seqlen,
max_postfix_length=max_postfix_length,
max_current_length=max_current_length,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
prefill_tokens=prefill_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
postfix_lengths=postfix_lengths,
@@ -725,7 +742,8 @@ class FlashCausalLMBatch(Batch):
total_slots = 0
max_blocks = 0
max_length = 0
max_seqlen = 0
max_postfix_length = 0
max_current_length = 0
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
@@ -734,7 +752,8 @@ class FlashCausalLMBatch(Batch):
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
)
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_postfix_length = max(max_postfix_length, b.max_postfix_length)
max_current_length = max(max_current_length, b.max_current_length)
max_length = max(
max_length,
max(
@@ -791,6 +810,8 @@ class FlashCausalLMBatch(Batch):
prefix_offsets = []
read_offsets = []
prefill_tokens = []
next_token_chooser_parameters = []
fsm_grammar_states = []
stopping_criterias = []
@@ -862,6 +883,8 @@ class FlashCausalLMBatch(Batch):
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
prefill_tokens.extend(batch.prefill_tokens)
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
stopping_criterias.extend(batch.stopping_criterias)
@@ -907,10 +930,12 @@ class FlashCausalLMBatch(Batch):
prefix_lengths=prefix_lengths,
prefix_lengths_tensor=prefix_lengths_tensor,
slots=slots,
max_seqlen=max_seqlen,
max_postfix_length=max_postfix_length,
max_current_length=max_current_length,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
prefill_tokens=prefill_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
postfix_lengths=postfix_lengths,
@@ -1416,7 +1441,7 @@ class FlashCausalLM(Model):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
postfix_lengths = batch.postfix_lengths_tensor
max_s = batch.max_seqlen
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
@@ -1459,7 +1484,7 @@ class FlashCausalLM(Model):
slots = batch.slots[batch.slot_indices]
postfix_lengths = batch.postfix_lengths_tensor
prefix_lengths_tensor = batch.prefix_lengths_tensor
max_s = batch.max_seqlen
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
@@ -1608,15 +1633,47 @@ class FlashCausalLM(Model):
if prefill_logprobs
else speculative_logits
)
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
len(batch)
)
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
else:
prefill_logprobs = None
next_token_logits = out
next_adapter_indices = batch.adapter_meta.adapter_indices
speculate = get_speculate()
finished_prefilling = True
next_chunk_lengths = []
if prefill:
# Budget in tokens for the next batch
# We reserve one slot per request so that there is always enough space for at
# least a single decode step for the remaining requests
batch_budget = TOKEN_BUDGET - len(batch)
for prefix_length, postfix_length, prompt_length in zip(
batch.prefix_lengths, batch.postfix_lengths, batch.prompt_lengths
):
remaining_prefill_tokens = max(
prompt_length - prefix_length - postfix_length, 0
)
if remaining_prefill_tokens > 0:
next_chunk_length = max(
min(remaining_prefill_tokens, batch_budget), 1
)
batch_budget -= next_chunk_length
finished_prefilling = False
else:
# Since speculation will be turned off, this is always true
next_chunk_length = 1
next_chunk_lengths.append(next_chunk_length)
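A hand-traced run of the loop above with assumed sizes, to make the budgeting rule concrete (a sketch, not part of the commit): with three requests and TOKEN_BUDGET = 8, the per-step budget is 8 - 3 = 5, the first still-prefilling request takes all of it, and every other request is still guaranteed at least one token.

TOKEN_BUDGET = 8
prefix_lengths = [0, 0, 4]
postfix_lengths = [10, 6, 4]
prompt_lengths = [40, 20, 8]

batch_budget = TOKEN_BUDGET - len(prompt_lengths)  # 5
next_chunk_lengths = []
finished_prefilling = True
for prefix, postfix, prompt in zip(prefix_lengths, postfix_lengths, prompt_lengths):
    remaining = max(prompt - prefix - postfix, 0)
    if remaining > 0:
        chunk = max(min(remaining, batch_budget), 1)  # never below 1, even once the budget is spent
        batch_budget -= chunk
        finished_prefilling = False
    else:
        chunk = 1  # prompt fully ingested: plain decode slot
    next_chunk_lengths.append(chunk)

assert next_chunk_lengths == [5, 1, 1]
assert finished_prefilling is False  # so speculation is disabled for this step
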
# Turn off speculative if some requests are still prefilling
# It makes the logic easier to follow
if prefill and not finished_prefilling:
speculate = 0
speculative_logits = None
else:
speculate = get_speculate()
(
next_input_ids,
next_token_logprobs,
@@ -1624,7 +1681,7 @@ class FlashCausalLM(Model):
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : max(batch.postfix_lengths)],
batch.all_input_ids_tensor[:, : batch.max_current_length],
next_token_logits,
speculate,
batch.speculative_ids,
@@ -1635,18 +1692,15 @@ class FlashCausalLM(Model):
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
)
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
# Since we are done prefilling, all the tensors that concatenated values for all the requests
# instantly become of shape [BATCH_SIZE]
if prefill and finished_prefilling:
next_position_ids = batch.position_ids.new_empty(len(batch))
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
# We do not need cu_seqlen_prefill anymore
batch.cu_seqlen_prefill = None
else:
prefill_logprobs = None
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
len(batch)
)
elif not prefill:
next_position_ids = batch.position_ids
# Cumulative length
@@ -1658,6 +1712,7 @@ class FlashCausalLM(Model):
# Zipped iterator
iterator = zip(
batch.prompt_lengths,
batch.prefix_lengths,
batch.postfix_lengths,
batch.all_input_ids,
@@ -1671,6 +1726,7 @@ class FlashCausalLM(Model):
# For each member of the batch
index = 0
for i, (
prompt_length,
prefix_length,
postfix_length,
all_input_ids,
@@ -1686,15 +1742,16 @@ class FlashCausalLM(Model):
out_end_index = batch.prefill_cu_outlens[i + 1]
out_length = out_end_index - out_start_index
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
if finished_prefilling:
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
# Initialize adapter indices
# In decode, we only have one token per row in the batch, so grab last index
next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
end_index - 1
]
# Initialize adapter indices
# In decode, we only have one token per row in the batch, so grab last index
next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
end_index - 1
]
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
@@ -1709,30 +1766,29 @@ class FlashCausalLM(Model):
start_index + 1 : start_index + out_length
]
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, prefix_length + postfix_length + j] = (
next_input_ids[index]
)
index += 1
# Represent whether this request is still prefilling
# If it is, the tokens we decoded should be ignored
accept_tokens = prefix_length + postfix_length >= prompt_length
cumulative_length += postfix_length
if accept_tokens:
# Only save tokens if we are done prefilling for this request
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[
i, prefix_length + postfix_length + j
] = next_input_ids[index]
index += 1
cumulative_length += postfix_length
# Update values
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
batch.postfix_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices
if prefill:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
# These values can be updated without a GPU -> CPU sync
if not prefill or (prefill and finished_prefilling):
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
batch.postfix_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices
if prefill and prefill_logprobs:
# Get prefill logprobs
@@ -1743,15 +1799,265 @@ class FlashCausalLM(Model):
# GPU <-> CPU sync
prefill_logprobs = prefill_logprobs.view(-1).tolist()
# Does a GPU <-> CPU sync internally
if prefill and finished_prefilling:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = next_input_ids.tolist()
accepted_ids = accepted_ids.tolist()
# Update values if we need to continue prefilling
# This is the `else` case of the `Update values` block above, but since it
# requires `next_token_ids` to be on CPU, it is better to do it here
skip_tokens = {}
if prefill and not finished_prefilling:
# Speculation must be ignored while we prefill even with chunking
# it simplifies everything
assert batch.speculative_ids is None
all_postfix_ids = []
sliding_window = get_sliding_windows()
position_ids = []
cu_seqlen_prefill = [0]
start_slots = []
slot_indices = []
prefill_cache_indices = []
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
# Cumulative length
cumulative_length = 0
cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0
slots = []
adapter_indices_list = []
for i, (
r,
next_token_id,
all_input_ids,
prefix_length,
postfix_length,
prompt_length,
next_chunk_length,
) in enumerate(
zip(
batch.requests,
next_token_ids,
batch.all_input_ids,
batch.prefix_lengths,
batch.postfix_lengths,
batch.prompt_lengths,
next_chunk_lengths,
)
):
continue_prefilling = prefix_length + postfix_length < prompt_length
skip_tokens[r.id] = True
if continue_prefilling:
# Update prefix length
prefix_length = prefix_length + postfix_length
batch.prefix_lengths[i] = prefix_length
# Update postfix length
postfix_length = next_chunk_length
batch.max_postfix_length = max(
batch.max_postfix_length, postfix_length
)
batch.postfix_lengths[i] = postfix_length
# Potentially update max_current_length
current_length = prefix_length + postfix_length
batch.max_current_length = max(
batch.max_current_length, current_length
)
# Get new prompt IDs to prefill
postfix_ids = all_input_ids[
prefix_length : prefix_length + postfix_length
]
# Position ids
request_position_ids = torch.arange(
prefix_length, prefix_length + postfix_length, dtype=torch.int32
)
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + postfix_length)
request_slots = r.slots[prefix_length:]
request_slot_indices = torch.arange(
cumulative_slot_tokens,
cumulative_slot_tokens + postfix_length,
dtype=torch.int64,
)
# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, postfix_length - sliding_window),
cumulative_length + postfix_length,
dtype=torch.int64,
)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
if r.prefill_logprobs:
prefill_head_indices.append(
request_position_ids + cumulative_length
)
prefill_next_token_indices.append(
prefill_out_cumulative_length + postfix_length - 1
)
prefill_cu_outlens.append(
prefill_out_cumulative_length + postfix_length
)
prefill_out_cumulative_length += postfix_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + postfix_length - 1],
dtype=torch.int32,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
else:
# This request is done prefilling, the new id is the one selected by the sampling method
postfix_ids = [next_token_id]
# Position_ids
position_ids.append(
torch.tensor(
(prefix_length + postfix_length,), dtype=torch.int32
)
)
# Add this request token
cu_seqlen_prefill.append(cumulative_length + 1)
request_slots = r.slots[prefix_length:]
request_slot_indices = torch.tensor(
(cumulative_slot_tokens + postfix_length,), dtype=torch.int64
)
# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.tensor(
[cumulative_length], dtype=torch.int64
)
prefill_head_indices.append(
torch.tensor([cumulative_length], dtype=torch.int32)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
all_postfix_ids.extend(postfix_ids)
start_slots.append(cumulative_slot_tokens)
slots.extend(request_slots)
slot_indices.append(request_slot_indices)
if sliding_window is not None:
prefill_cache_indices.append(request_prefill_cache_indices)
ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(
torch.full((postfix_length,), adapter_index)
)
# Update
cumulative_length += postfix_length
cumulative_slot_tokens += len(request_slots)
device = batch.input_ids.device
batch.start_slots = torch.tensor(start_slots, dtype=torch.int64)
if len(batch) > 1:
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
position_ids = position_ids[0]
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
batch.cu_seqlen_prefill = cu_seqlen_prefill
batch.position_ids = position_ids.to(device)
batch.slot_indices = slot_indices.to(device)
batch.prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
batch.input_ids = torch.tensor(
all_postfix_ids, dtype=torch.int64, device=device
)
batch.postfix_lengths_tensor = torch.tensor(
batch.postfix_lengths, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
batch.prefill_head_indices = prefill_head_indices
batch.prefill_next_token_indices = prefill_next_token_indices
batch.slots = torch.tensor(slots, dtype=torch.int64, device=device)
batch.prefix_lengths_tensor = torch.tensor(
batch.prefix_lengths, dtype=torch.int32, device=device
)
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
batch.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=batch.adapter_meta.adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
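Once the per-request loop has decided which requests keep prefilling and which start decoding, the flat tensors for the next forward pass are rebuilt from the per-request chunks. A toy sketch of how the concatenated ids, positions and `cu_seqlen_prefill` line up for one continuing and one finished request (assumed numbers; sliding windows, logprob bookkeeping and adapters are left out):

import torch

# Request A: still prefilling, next chunk of 5 prompt tokens, 12 tokens already cached
# Request B: done prefilling (20-token prompt fully ingested), feeds the token it just sampled
chunk_a = [101, 102, 103, 104, 105]   # next slice of A's prompt (assumed ids)
chunk_b = [9]                         # B's freshly sampled token id

input_ids = torch.tensor(chunk_a + chunk_b, dtype=torch.int64)
cu_seqlen_prefill = torch.tensor(
    [0, len(chunk_a), len(chunk_a) + len(chunk_b)], dtype=torch.int32
)  # [0, 5, 6]
position_ids = torch.cat([
    torch.arange(12, 12 + len(chunk_a), dtype=torch.int32),  # A continues after its cached prefix
    torch.tensor([20], dtype=torch.int32),                   # B's next absolute position
])

# When no request asked for prefill logprobs, one logit per request is enough:
# the last token of each chunk
prefill_head_indices = cu_seqlen_prefill[1:] - 1             # tensor([4, 5])
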
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
batch.requests,
batch.prompt_lengths,
batch.prefix_lengths,
batch.postfix_lengths,
batch.prefix_offsets,
batch.read_offsets,
@@ -1770,7 +2076,9 @@ class FlashCausalLM(Model):
index = 0
for i, (
request,
input_length,
prompt_length,
prefix_length,
postfix_length,
prefix_offset,
read_offset,
stopping_criteria,
@@ -1783,6 +2091,61 @@ class FlashCausalLM(Model):
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Compute logprobs first: even though we might skip the token, its logprobs
# can still be required
# Key on request.id modulo world_size rather than the batch index: request.id is robust to
# batch.filter, whereas the index in the batch is not, and we need this state to be stable
if request.id % self.world_size == self.rank:
# Prefill
if prefill and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
request_prefill_tokens = batch.prefill_tokens[i]
request_prefill_logprobs = prefill_logprobs[
out_start_index : out_end_index - 1
]
prefill_token_ids = all_input_ids[:-1]
if request_prefill_tokens is None:
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] * (
len(prefix_ids) + 1
) + request_prefill_logprobs
prefill_token_ids = prefix_ids + prefill_token_ids
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefill_token_ids,
request_prefill_logprobs,
prefill_texts,
is_special=[],
)
if request_prefill_tokens is not None:
prefill_tokens = request_prefill_tokens + prefill_tokens
batch.prefill_tokens[i] = prefill_tokens
else:
batch.prefill_tokens[i] = None
# Represent whether this request is still prefilling
# If it is, the tokens we decoded should be ignored
skip_token = skip_tokens.get(request.id, False)
if skip_token:
# Make sure that we do not stop: even though this request did not produce a token,
# it is still being processed
stopped = False
# Skip the rest of the decoding
# Values were updated before this for loop
continue
# Append next token to all tokens
next_token_texts = []
left = 0
@@ -1823,7 +2186,7 @@ class FlashCausalLM(Model):
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if request.id % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
@@ -1844,31 +2207,6 @@ class FlashCausalLM(Model):
else:
generated_text = None
# Prefill
if prefill and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = (
[float("nan")] * (len(prefix_ids) + 1)
) + prefill_logprobs[out_start_index : out_end_index - 1]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefix_ids + prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefix_ids + prefill_token_ids,
request_prefill_logprobs,
prefill_texts,
is_special=[],
)
else:
prefill_tokens = None
if top_n_tokens > 0:
all_top_tokens = []
for top_token_ids, top_token_logprobs in zip(
@@ -1896,7 +2234,7 @@ class FlashCausalLM(Model):
generation = Generation(
request.id,
prefill_tokens,
batch.prefill_tokens[i],
Tokens(
_next_token_ids,
_next_token_logprobs,
@@ -1917,9 +2255,13 @@ class FlashCausalLM(Model):
)
# Update values
batch.postfix_lengths[i] = input_length + n_accepted_ids
if batch.postfix_lengths[i] > batch.max_seqlen:
batch.max_seqlen = batch.postfix_lengths[i]
current_postfix_length = postfix_length + n_accepted_ids
batch.max_postfix_length = max(
batch.max_postfix_length, current_postfix_length
)
batch.postfix_lengths[i] = current_postfix_length
current_length = prefix_length + current_postfix_length
batch.max_current_length = max(batch.max_current_length, current_length)
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
@@ -1930,9 +2272,13 @@ class FlashCausalLM(Model):
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
if prefill and finished_prefilling:
# We do not need prefill tensors anymore
batch.cu_seqlen_prefill = None
batch.prefill_cache_indices = None
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode

View File

@@ -74,6 +74,14 @@ class Tokens:
def __len__(self):
return len(self.token_ids)
def __add__(self, other: "Tokens") -> "Tokens":
return Tokens(
self.token_ids + other.token_ids,
self.logprobs + other.logprobs,
self.texts + other.texts,
self.is_special + other.is_special,
)
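The new `__add__` is what lets the prefill `Tokens` produced by successive chunks of the same request be stitched together (the `request_prefill_tokens + prefill_tokens` line in the first file of this diff). A hedged usage sketch with made-up values:

# Illustrative only: merging the prefill tokens of two successive chunks
chunk_one = Tokens(
    [5, 17, 42],
    [float("nan"), -1.2, -0.7],
    ["Hello", ",", " world"],
    is_special=[False, False, False],
)
chunk_two = Tokens(
    [8, 3],
    [-0.4, -2.1],
    ["!", " How"],
    is_special=[False, False],
)
merged = chunk_one + chunk_two
assert merged.token_ids == [5, 17, 42, 8, 3]
assert len(merged) == 5
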
@dataclass
class Generation: