refactor to use prefix/postfix naming + fix all_input_ids_tensor

OlivierDehaene 2024-09-25 14:40:47 +02:00
parent de043b53c4
commit 838756eb18
2 changed files with 173 additions and 136 deletions
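
The renaming below follows one convention throughout: what used to be called input_length becomes the postfix (the uncached tail of the prompt that still has to be prefilled), prefix_length counts the tokens already covered by the prefix cache, and prompt_length is the full tokenized request. A hedged sketch of how the three quantities relate (values invented for illustration):

```python
import torch

# prompt_length  : tokens in the original request
# prefix_length  : tokens already available from the prefix cache
# postfix_length : tokens that still need to be prefilled this step
prompt_lengths = torch.tensor([7, 12], dtype=torch.int32)
prefix_lengths = torch.tensor([3, 0], dtype=torch.int32)
postfix_lengths = prompt_lengths - prefix_lengths

assert torch.equal(prefix_lengths + postfix_lengths, prompt_lengths)
```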


@@ -9,7 +9,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
     @dataclass
     class Seqlen:
-        input_lengths: torch.Tensor
+        postfix_lengths: torch.Tensor
         prefix_lengths: torch.Tensor
         cu_seqlen_q: Optional[torch.Tensor]
         cu_seqlen_k: Optional[torch.Tensor]
@@ -18,16 +18,16 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
         def __init__(
             self,
-            input_lengths,
+            postfix_lengths,
             prefix_lengths,
             cu_seqlen_q=None,
             max_q=None,
             max_k=None,
         ):
-            self.input_lengths = input_lengths
+            self.postfix_lengths = postfix_lengths
             self.prefix_lengths = prefix_lengths
-            device = self.input_lengths.device
-            shape = self.input_lengths.shape
+            device = self.postfix_lengths.device
+            shape = self.postfix_lengths.shape
             if cu_seqlen_q is None:
                 cu_seqlen_q = torch.arange(
                     shape[0] + 1,
@@ -43,7 +43,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
             # cuda graphs don't like this and this is necessary to clamp within mistral
             # Although FA2 might not want the clamping
             # cu_seqlen_k[0] = 0
-            total = self.input_lengths + self.prefix_lengths
+            total = self.postfix_lengths + self.prefix_lengths
             torch.cumsum(total, -1, out=cu_seqlen_k[1:])
             self.cu_seqlen_q = cu_seqlen_q
@@ -59,7 +59,7 @@ else:
     @dataclass
     class Seqlen:
-        input_lengths: torch.Tensor
+        postfix_lengths: torch.Tensor
         prefix_lengths: torch.Tensor
         cu_seqlen_q: torch.Tensor
         max_q: int
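
As a sanity check on the renamed Seqlen fields, here is a hedged, standalone re-creation of the decode-path bookkeeping above: one query token per sequence, and key ranges spanning prefix plus postfix. The lengths are invented, and a plain assignment stands in for the in-place cumsum:

```python
import torch

postfix_lengths = torch.tensor([4, 2, 3], dtype=torch.int32)
prefix_lengths = torch.tensor([1, 0, 5], dtype=torch.int32)

bs = postfix_lengths.shape[0]
cu_seqlen_q = torch.arange(bs + 1, dtype=torch.int32)  # one query token per sequence
cu_seqlen_k = torch.zeros(bs + 1, dtype=torch.int32)

total = postfix_lengths + prefix_lengths  # keys cover cached prefix + new postfix
cu_seqlen_k[1:] = torch.cumsum(total, dim=-1)

print(cu_seqlen_q)  # tensor([0, 1, 2, 3], dtype=torch.int32)
print(cu_seqlen_k)  # tensor([ 0,  5,  7, 15], dtype=torch.int32)
```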


@@ -143,9 +143,6 @@ class FlashCausalLMBatch(Batch):
     block_tables_tensor: torch.Tensor
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
     slots: torch.Tensor
-    # size [b], containing the number of blocks that can be retrieved from the cache
-    prefix_lens: List[int]
-    prefix_lens_tensor: torch.Tensor
     max_seqlen: int
@@ -162,8 +159,14 @@ class FlashCausalLMBatch(Batch):
     all_input_ids_tensor: torch.Tensor
     # Lengths of all generations present in the batch
-    input_lengths: List[int]
-    input_lengths_tensor: torch.Tensor
+    postfix_lengths: List[int]
+    postfix_lengths_tensor: torch.Tensor
+    # size [b], containing the number of blocks that can be retrieved from the cache
+    prefix_lengths: List[int]
+    prefix_lengths_tensor: torch.Tensor
+    prompt_lengths: List[int]
+    prompt_lengths_tensor: torch.Tensor
     prefix_offsets: List[Optional[int]]
     read_offsets: List[Optional[int]]
@@ -225,10 +228,13 @@ class FlashCausalLMBatch(Batch):
         slot_indices = []
         prefill_cache_indices = []
-        input_lengths = []
+        prefix_lengths = []
+        postfix_lengths = []
+        prompt_lengths = []
         prefix_offsets = []
         read_offsets = []
         all_input_ids = []
+        all_postfix_ids = []
         prefix_ids = []
         requests_idx_mapping = {}
@@ -257,7 +263,6 @@ class FlashCausalLMBatch(Batch):
         block_tables = []
         slots = []
-        prefix_lens = []
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -266,37 +271,39 @@ class FlashCausalLMBatch(Batch):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
-            orig_input_length = len(tokenized_input)
-            prefix_len = r.prefix_len
+            prompt_length = len(tokenized_input)
+            prompt_lengths.append(prompt_length)
+            prefix_length = r.prefix_len
             assert (
-                prefix_len <= orig_input_length
-            ), f"Prefix {prefix_len} vs input {orig_input_length}"
-            if prefix_len == orig_input_length:
-                assert prefix_len > 0
-                prefix_len -= 1
+                prefix_length <= prompt_length
+            ), f"Prefix {prefix_length} vs input {prompt_length}"
+            if prefix_length == prompt_length:
+                assert prefix_length > 0
+                prefix_length -= 1
             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
-            prefix_ids.append(tokenized_input[:prefix_len])
-            tokenized_input = tokenized_input[prefix_len:]
-            input_length = len(tokenized_input)
-            input_lengths.append(input_length)
-            prefix_offsets.append(input_length - 5)
-            read_offsets.append(input_length)
+            prefix_ids.append(tokenized_input[:prefix_length])
+            postfix_ids = tokenized_input[prefix_length:]
+            postfix_length = len(postfix_ids)
+            postfix_lengths.append(postfix_length)
+            prefix_offsets.append(postfix_length - 5)
+            read_offsets.append(postfix_length)
+            all_postfix_ids.append(postfix_ids)
             all_input_ids.append(tokenized_input)
             # Position ids
             request_position_ids = torch.arange(
-                prefix_len, orig_input_length, dtype=torch.int32
+                prefix_length, prompt_length, dtype=torch.int32
             )
             position_ids.append(request_position_ids)
             # Add cumulative lengths of all previous inputs
-            cu_seqlen_prefill.append(cumulative_length + input_length)
+            cu_seqlen_prefill.append(cumulative_length + postfix_length)
             next_token_chooser_parameters.append(r.parameters)
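
The per-request loop above boils down to splitting the tokenized prompt at the cached-prefix boundary. A minimal, self-contained sketch of that split (the token ids are invented; r.prefix_len and the 5-token prefix_offset heuristic come from the hunk above):

```python
tokenized_input = [101, 7592, 2088, 2003, 1037, 3231, 102]  # 7 prompt tokens (invented)
prefix_len = 3  # r.prefix_len in the real code: tokens already in the prefix cache

prompt_length = len(tokenized_input)
if prefix_len == prompt_length:
    # Never treat the whole prompt as cached: at least one token must be prefilled.
    prefix_len -= 1

prefix_ids = tokenized_input[:prefix_len]   # [101, 7592, 2088]
postfix_ids = tokenized_input[prefix_len:]  # [2003, 1037, 3231, 102]
postfix_length = len(postfix_ids)

# Offsets used later by incremental detokenization.
prefix_offset = postfix_length - 5
read_offset = postfix_length

assert prefix_len + postfix_length == prompt_length
```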
@@ -309,7 +316,7 @@ class FlashCausalLMBatch(Batch):
             ADAPTER_TO_INDEX = get_adapter_to_index()
             adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
-            adapter_indices_list.append(torch.full((input_length,), adapter_index))
+            adapter_indices_list.append(torch.full((postfix_length,), adapter_index))
             adapter_set.add(adapter_index)
             # Paged attention
@@ -318,11 +325,11 @@ class FlashCausalLMBatch(Batch):
             speculative_length = 0 if speculative_length is None else speculative_length
             # Tokens that need to be mapped to blocks.
-            block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
+            block_tokens = prompt_length + max_new_tokens - 1 + speculative_length
             # Tokens that need to be mapped to slots. We don't need slots for the
             # cached prefix (if present).
-            slot_tokens = input_length + max_new_tokens - 1 + speculative_length
+            slot_tokens = postfix_length + max_new_tokens - 1 + speculative_length
             # blocks and slots can be empty (for example in warmup)
             if not r.blocks:
@@ -338,19 +345,19 @@ class FlashCausalLMBatch(Batch):
             else:
                 request_blocks = r.blocks
                 request_slots = r.slots[
-                    prefix_len: #: orig_input_length + max_new_tokens + speculative_length
+                    prefix_length: #: orig_input_length + max_new_tokens + speculative_length
                 ]
             block_tables.append(request_blocks)
             slots.extend(request_slots)
-            prefix_lens.append(prefix_len)
+            prefix_lengths.append(prefix_length)
             num_blocks += len(request_blocks)
             start_slots.append(cumulative_slot_tokens)
             request_slot_indices = torch.arange(
                 cumulative_slot_tokens,
-                cumulative_slot_tokens + input_length,
+                cumulative_slot_tokens + postfix_length,
                 dtype=torch.int64,
             )
             slot_indices.append(request_slot_indices)
@@ -358,8 +365,8 @@ class FlashCausalLMBatch(Batch):
             # Create tensor to slice into the kv tensor in prefill
             if sliding_window is not None:
                 request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - sliding_window),
-                    cumulative_length + input_length,
+                    cumulative_length + max(0, postfix_length - sliding_window),
+                    cumulative_length + postfix_length,
                     dtype=torch.int64,
                 )
                 prefill_cache_indices.append(request_prefill_cache_indices)
@@ -370,14 +377,16 @@ class FlashCausalLMBatch(Batch):
             if r.prefill_logprobs:
                 prefill_head_indices.append(request_position_ids + cumulative_length)
                 prefill_next_token_indices.append(
-                    prefill_out_cumulative_length + input_length - 1
+                    prefill_out_cumulative_length + postfix_length - 1
                 )
-                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
-                prefill_out_cumulative_length += input_length
+                prefill_cu_outlens.append(
+                    prefill_out_cumulative_length + postfix_length
+                )
+                prefill_out_cumulative_length += postfix_length
             else:
                 prefill_head_indices.append(
                     torch.tensor(
-                        [cumulative_length + input_length - 1], dtype=torch.int32
+                        [cumulative_length + postfix_length - 1], dtype=torch.int32
                     )
                 )
                 prefill_next_token_indices.append(prefill_out_cumulative_length)
@@ -385,12 +394,13 @@ class FlashCausalLMBatch(Batch):
                 prefill_out_cumulative_length += 1
             # Update
-            cumulative_length += input_length
+            cumulative_length += postfix_length
             cumulative_slot_tokens += slot_tokens
-            max_seqlen = max(max_seqlen, input_length)
+            max_seqlen = max(max_seqlen, postfix_length)
             max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(
-                max_length, input_length + max_new_tokens + speculative_length
+                max_length,
+                prefix_length + postfix_length + max_new_tokens + speculative_length,
             )
         adapter_indices = torch.cat(adapter_indices_list).to(
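
The max_length change above is likely the substance of the "fix all_input_ids_tensor" part of the commit title: the batch-wide token-id tensor is allocated with max_length columns, so once prefixes can be cached it has to leave room for prefix + postfix + new tokens rather than the postfix alone. A hedged sketch of that sizing (requests and numbers are invented):

```python
import torch

requests = [
    # (prefix_length, postfix_length, max_new_tokens)
    (3, 4, 10),
    (0, 12, 2),
]
speculative_length = 0

max_length = max(
    prefix_length + postfix_length + max_new_tokens + speculative_length
    for prefix_length, postfix_length, max_new_tokens in requests
)

# One row per request, wide enough for the whole sequence including the cached prefix.
all_input_ids_tensor = torch.zeros((len(requests), max_length), dtype=torch.int64)
print(all_input_ids_tensor.shape)  # torch.Size([2, 17])
```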
@@ -415,13 +425,13 @@ class FlashCausalLMBatch(Batch):
         )
         if len(pb.requests) > 1:
-            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
+            input_ids = np.concatenate(all_postfix_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
             if sliding_window is not None:
                 prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
-            input_ids = all_input_ids[0]
+            input_ids = all_postfix_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
             if sliding_window is not None:
@@ -436,8 +446,11 @@ class FlashCausalLMBatch(Batch):
             prefill_cache_indices.to(device) if sliding_window is not None else None
         )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        input_lengths_tensor = torch.tensor(
-            input_lengths, dtype=torch.int32, device=device
+        postfix_lengths_tensor = torch.tensor(
+            postfix_lengths, dtype=torch.int32, device=device
+        )
+        prompt_lengths_tensor = torch.tensor(
+            prompt_lengths, dtype=torch.int32, device=device
         )
         adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
@@ -470,7 +483,9 @@ class FlashCausalLMBatch(Batch):
         for i, request_blocks in enumerate(block_tables):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
         block_tables_tensor = block_tables_tensor.to(device)
-        prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
+        prefix_lengths_tensor = torch.tensor(
+            prefix_lengths, dtype=torch.int32, device=device
+        )
         return cls(
             batch_id=pb.id,
@@ -485,14 +500,16 @@ class FlashCausalLMBatch(Batch):
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            prefix_lens=prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
+            prefix_lengths=prefix_lengths,
+            prefix_lengths_tensor=prefix_lengths_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
             prefill_cu_outlens=prefill_cu_outlens,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
+            postfix_lengths=postfix_lengths,
+            postfix_lengths_tensor=postfix_lengths_tensor,
+            prompt_lengths=prompt_lengths,
+            prompt_lengths_tensor=prompt_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -556,8 +573,8 @@ class FlashCausalLMBatch(Batch):
         all_input_ids = []
         prefix_ids = []
-        input_lengths = []
-        prefix_lens = []
+        postfix_lengths = []
+        prefix_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -578,15 +595,15 @@ class FlashCausalLMBatch(Batch):
             requests.append(self.requests[idx])
             # Get length
-            request_input_length = self.input_lengths[idx]
-            prefix_len = self.prefix_lens[idx]
+            request_input_length = self.postfix_lengths[idx]
+            prefix_length = self.prefix_lengths[idx]
             max_seqlen = max(max_seqlen, request_input_length)
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
-            input_lengths.append(request_input_length)
-            prefix_lens.append(prefix_len)
+            postfix_lengths.append(request_input_length)
+            prefix_lengths.append(prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
@@ -629,9 +646,9 @@ class FlashCausalLMBatch(Batch):
         adapter_indices = self.adapter_meta.adapter_indices[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
-        input_lengths_tensor = self.input_lengths_tensor[indices]
+        postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
-        prefix_lens_tensor = self.prefix_lens_tensor[indices]
+        prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (
@@ -666,10 +683,10 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
-            prefix_lens=prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
+            postfix_lengths=postfix_lengths,
+            postfix_lengths_tensor=postfix_lengths_tensor,
+            prefix_lengths=prefix_lengths,
+            prefix_lengths_tensor=prefix_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -720,7 +737,7 @@ class FlashCausalLMBatch(Batch):
                 + speculative_length
                 - stopping_criteria.current_tokens
                 for input_length, stopping_criteria in zip(
-                    b.input_lengths, b.stopping_criterias
+                    b.postfix_lengths, b.stopping_criterias
                 )
             ),
         )
@@ -729,13 +746,15 @@ class FlashCausalLMBatch(Batch):
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
         slots = batches[0].slots.new_empty(total_slots)
         slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
-        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
+        postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
             total_batch_size
         )
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
             (total_batch_size, max_blocks)
         )
-        prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
+        prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
+            total_batch_size
+        )
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
@@ -753,11 +772,11 @@ class FlashCausalLMBatch(Batch):
         start_slots = []
         block_tables = []
-        prefix_lens = []
+        prefix_lengths = []
         all_input_ids = []
         prefix_ids = []
-        input_lengths = []
+        postfix_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -790,7 +809,7 @@ class FlashCausalLMBatch(Batch):
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
-            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots
@@ -817,16 +836,16 @@ class FlashCausalLMBatch(Batch):
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]
-            prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
+            prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
             start_slots.append(batch.start_slots + cumulative_slots)
             block_tables.extend(batch.block_tables)
-            prefix_lens.extend(batch.prefix_lens)
+            prefix_lengths.extend(batch.prefix_lengths)
             all_input_ids.extend(batch.all_input_ids)
             prefix_ids.extend(batch.prefix_ids)
-            input_lengths.extend(batch.input_lengths)
+            postfix_lengths.extend(batch.postfix_lengths)
             prefix_offsets.extend(batch.prefix_offsets)
             read_offsets.extend(batch.read_offsets)
@@ -872,15 +891,15 @@ class FlashCausalLMBatch(Batch):
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
-            prefix_lens=prefix_lens,
-            prefix_lens_tensor=prefix_lens_tensor,
+            prefix_lengths=prefix_lengths,
+            prefix_lengths_tensor=prefix_lengths_tensor,
             slots=slots,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
+            postfix_lengths=postfix_lengths,
+            postfix_lengths_tensor=postfix_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -1100,9 +1119,9 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
         slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-        input_lengths = [max_s] * bs
+        postfix_lengths = [max_s] * bs
         prefix_lengths = [0] * bs
-        input_lengths_tensor = (
+        postfix_lengths_tensor = (
             torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
         )
         prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
@@ -1114,8 +1133,8 @@ class FlashCausalLM(Model):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                input_lengths=input_lengths,
-                prefix_lens=prefix_lengths,
+                postfix_lengths=postfix_lengths,
+                prefix_lengths=prefix_lengths,
             )
             from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
@@ -1143,7 +1162,7 @@ class FlashCausalLM(Model):
             "kv_cache": self.kv_cache,
             "block_tables": block_tables,
             "slots": slots,
-            "input_lengths": input_lengths_tensor,
+            "postfix_lengths": postfix_lengths_tensor,
             "prefix_lengths": prefix_lengths_tensor,
             "state": state,
             "graph": graph,
@@ -1154,12 +1173,12 @@ class FlashCausalLM(Model):
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=None,
-            input_lengths_tensor=input_lengths_tensor,
+            postfix_lengths_tensor=postfix_lengths_tensor,
             state=state,
-            prefix_lens_tensor=prefix_lengths_tensor,
+            prefix_lengths_tensor=prefix_lengths_tensor,
         ):
             seqlen = Seqlen(
-                input_lengths=input_lengths_tensor,
+                postfix_lengths=postfix_lengths_tensor,
                 prefix_lengths=prefix_lengths_tensor,
                 cu_seqlen_q=None,
                 max_q=1,
@@ -1183,7 +1202,7 @@ class FlashCausalLM(Model):
             with torch.cuda.graph(graph, pool=MEM_POOL):
                 seqlen = Seqlen(
-                    input_lengths=input_lengths_tensor,
+                    postfix_lengths=postfix_lengths_tensor,
                     prefix_lengths=prefix_lengths_tensor,
                     cu_seqlen_q=None,
                     max_q=1,
@@ -1340,15 +1359,17 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         # Dummy value, some models (starcoder2) don't accept `None`.
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
+        postfix_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        prefix_lengths_tensor = torch.zeros(
+            seqlen, dtype=torch.int32, device=self.device
+        )
         cu_seqlen_prefill = torch.tensor(
             [0, seqlen], device=self.device, dtype=torch.int32
         )
         max_s = seqlen
         seqlen = Seqlen(
-            input_lengths=input_lengths,
-            prefix_lengths=prefix_lens_tensor,
+            postfix_lengths=postfix_lengths,
+            prefix_lengths=prefix_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
             max_q=1,
             max_k=seqlen,
@@ -1379,7 +1400,7 @@ class FlashCausalLM(Model):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
@@ -1396,11 +1417,11 @@ class FlashCausalLM(Model):
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            postfix_lengths = (
+                postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            prefix_lengths_tensor = (
+                batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)
             # Add Copy the block tables for all members
@@ -1421,8 +1442,8 @@ class FlashCausalLM(Model):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
+            prefix_lengths_tensor = batch.prefix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
@@ -1444,19 +1465,19 @@ class FlashCausalLM(Model):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                input_lengths=batch.input_lengths,
-                prefix_lens=batch.prefix_lens,
+                postfix_lengths=batch.postfix_lengths,
+                prefix_lengths=batch.prefix_lengths,
             )
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=cu_seqlen_prefill,
-            input_lengths_tensor=input_lengths,
-            prefix_lens_tensor=prefix_lens_tensor,
+            postfix_lengths_tensor=postfix_lengths,
+            prefix_lengths_tensor=prefix_lengths_tensor,
         ):
-            max_k = (input_lengths + prefix_lens_tensor).max().item()
+            max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
             seqlen = Seqlen(
-                input_lengths=input_lengths,
-                prefix_lengths=prefix_lens_tensor,
+                postfix_lengths=postfix_lengths,
+                prefix_lengths=prefix_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
                 max_q=max_s,
                 max_k=max_k,
@@ -1485,8 +1506,8 @@ class FlashCausalLM(Model):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                input_lengths=batch.input_lengths,
-                prefix_lens=batch.prefix_lens,
+                postfix_lengths=batch.postfix_lengths,
+                prefix_lengths=batch.prefix_lengths,
             )
             # assert block_tables.shape[0] >= slots.shape[0]
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
@@ -1499,16 +1520,18 @@ class FlashCausalLM(Model):
         # so it doesn't matter if we override it with bogus values.
         cuda_graph["slots"].fill_(0)
         cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["postfix_lengths"].zero_()
+        cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
         cuda_graph["prefix_lengths"].zero_()
-        cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor
+        cuda_graph["prefix_lengths"][
+            : prefix_lengths_tensor.shape[0]
+        ] = prefix_lengths_tensor
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
             cu_seqlen_prefill=None,
-            input_lengths_tensor=cuda_graph["input_lengths"],
-            prefix_lens_tensor=cuda_graph["prefix_lengths"],
+            postfix_lengths_tensor=cuda_graph["postfix_lengths"],
+            prefix_lengths_tensor=cuda_graph["prefix_lengths"],
             state=cuda_graph["state"],
         ):
             # Replay the graph
@@ -1586,7 +1609,7 @@ class FlashCausalLM(Model):
             accepted_ids,
             speculative_ids,
         ) = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen],
+            batch.all_input_ids_tensor[:, : max(batch.postfix_lengths)],
             next_token_logits,
             speculate,
             batch.speculative_ids,
@@ -1619,7 +1642,12 @@ class FlashCausalLM(Model):
         stopped = True
         # Zipped iterator
-        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
+        iterator = zip(
+            batch.prefix_lengths,
+            batch.postfix_lengths,
+            batch.all_input_ids,
+            accepted_ids,
+        )
         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
         # one, we need to first do a GPU <-> CPU sync
@@ -1627,10 +1655,15 @@ class FlashCausalLM(Model):
         # For each member of the batch
         index = 0
-        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
+        for i, (
+            prefix_length,
+            postfix_length,
+            all_input_ids,
+            n_accepted_ids,
+        ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
-            end_index = cumulative_length + input_length
+            end_index = cumulative_length + postfix_length
             if prefill:
                 # Indexing metadata
@@ -1662,16 +1695,18 @@ class FlashCausalLM(Model):
                 ]
             for j in range(n_accepted_ids):
-                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
+                batch.all_input_ids_tensor[i, prefix_length + postfix_length + j] = (
+                    next_input_ids[index]
+                )
                 index += 1
-            cumulative_length += input_length
+            cumulative_length += postfix_length
         # Update values
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
-        batch.input_lengths_tensor += accepted_ids
+        batch.postfix_lengths_tensor += accepted_ids
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices
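
The rewritten write above is the other half of the all_input_ids_tensor fix: with prefix caching, postfix_length is an offset into the uncached tail only, so accepted tokens have to land at the absolute position prefix_length + postfix_length within the full sequence. A toy reproduction (values invented, no speculation):

```python
import torch

prefix_length, postfix_length = 3, 4  # 7 prompt tokens total
all_input_ids_tensor = torch.zeros((1, 16), dtype=torch.int64)
all_input_ids_tensor[0, :7] = torch.arange(1, 8)  # the prompt itself

next_input_ids = torch.tensor([42, 43])  # two accepted tokens this step
n_accepted_ids = next_input_ids.shape[0]

index = 0
for j in range(n_accepted_ids):
    # Absolute position in the full sequence, cached prefix included.
    all_input_ids_tensor[0, prefix_length + postfix_length + j] = next_input_ids[index]
    index += 1

print(all_input_ids_tensor[0, :10])
# tensor([ 1,  2,  3,  4,  5,  6,  7, 42, 43,  0])
```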
@@ -1702,7 +1737,7 @@ class FlashCausalLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.requests,
-            batch.input_lengths,
+            batch.postfix_lengths,
             batch.prefix_offsets,
             batch.read_offsets,
             batch.stopping_criterias,
@@ -1867,9 +1902,9 @@ class FlashCausalLM(Model):
             )
             # Update values
-            batch.input_lengths[i] = input_length + n_accepted_ids
-            if batch.input_lengths[i] > batch.max_seqlen:
-                batch.max_seqlen = batch.input_lengths[i]
+            batch.postfix_lengths[i] = input_length + n_accepted_ids
+            if batch.postfix_lengths[i] > batch.max_seqlen:
+                batch.max_seqlen = batch.postfix_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
@@ -1893,8 +1928,8 @@ class FlashCausalLM(Model):
         *,
         block_tables: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
-        input_lengths_tensor: torch.Tensor,
-        prefix_lens_tensor: torch.Tensor,
+        postfix_lengths_tensor: torch.Tensor,
+        prefix_lengths_tensor: torch.Tensor,
         state: Optional[Any] = None,
     ) -> ContextManager:
if ATTENTION != "flashinfer": if ATTENTION != "flashinfer":
@ -1905,7 +1940,7 @@ class FlashCausalLM(Model):
use_prefill_with_paged_kv_state, use_prefill_with_paged_kv_state,
) )
# has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) # has_prefix_lengths = any(prefix_length > 0 for prefix_length in prefix_lengths)
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
return use_prefill_with_paged_kv_state( return use_prefill_with_paged_kv_state(
@ -1914,12 +1949,12 @@ class FlashCausalLM(Model):
), ),
# block_tables=block_tables_to_ragged( # block_tables=block_tables_to_ragged(
# block_tables=block_tables, # block_tables=block_tables,
# input_lengths=input_lengths, # postfix_lengths=postfix_lengths,
# prefix_lens=prefix_lens, # prefix_lengths=prefix_lengths,
# ), # ),
block_tables=block_tables, block_tables=block_tables,
cu_seqlens=cu_seqlen_prefill, cu_seqlens=cu_seqlen_prefill,
input_lengths=input_lengths_tensor + prefix_lens_tensor, input_lengths=postfix_lengths_tensor + prefix_lengths_tensor,
num_heads=self.num_heads, num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
@ -1928,10 +1963,10 @@ class FlashCausalLM(Model):
window_left=self.sliding_window, window_left=self.sliding_window,
) )
else: else:
assert input_lengths_tensor is not None assert postfix_lengths_tensor is not None
return use_decode_state( return use_decode_state(
state=state if state is not None else self.decode_state, state=state if state is not None else self.decode_state,
input_lengths=input_lengths_tensor + prefix_lens_tensor, input_lengths=postfix_lengths_tensor + prefix_lengths_tensor,
block_tables=block_tables, block_tables=block_tables,
num_heads=self.num_heads, num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
@@ -1943,19 +1978,21 @@ class FlashCausalLM(Model):
 def block_tables_to_ragged(
-    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
+    *, block_tables: torch.Tensor, postfix_lengths: List[int], prefix_lengths: List[int]
 ) -> torch.Tensor:
     """Convert block table to ragged format compatible with FlashInfer."""
-    assert len(input_lengths) == len(prefix_lens)
-    total_len = sum(input_lengths) + sum(prefix_lens)
+    assert len(postfix_lengths) == len(prefix_lengths)
+    total_len = sum(postfix_lengths) + sum(prefix_lengths)
     block_tables_ragged = torch.empty(
         total_len, dtype=torch.int32, device=block_tables.device
     )
     offset = 0
-    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
-        seq_len = prefix_len + input_length
+    for i, (input_length, prefix_length) in enumerate(
+        zip(postfix_lengths, prefix_lengths)
+    ):
+        seq_len = prefix_length + input_length
         block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
         offset += seq_len
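
For completeness, a hedged usage sketch of the renamed helper above on toy data. The helper body is copied from the hunk, with an explicit return added so the snippet is self-contained:

```python
from typing import List

import torch


def block_tables_to_ragged(
    *, block_tables: torch.Tensor, postfix_lengths: List[int], prefix_lengths: List[int]
) -> torch.Tensor:
    """Copy of the helper above, so the example below runs on its own."""
    assert len(postfix_lengths) == len(prefix_lengths)
    total_len = sum(postfix_lengths) + sum(prefix_lengths)
    block_tables_ragged = torch.empty(
        total_len, dtype=torch.int32, device=block_tables.device
    )
    offset = 0
    for i, (input_length, prefix_length) in enumerate(
        zip(postfix_lengths, prefix_lengths)
    ):
        seq_len = prefix_length + input_length
        block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
        offset += seq_len
    return block_tables_ragged


# Toy data: two sequences with padded block tables of width 4.
block_tables = torch.tensor([[10, 11, 12, 0], [20, 21, 0, 0]], dtype=torch.int32)
print(
    block_tables_to_ragged(
        block_tables=block_tables, postfix_lengths=[2, 1], prefix_lengths=[1, 1]
    )
)
# tensor([10, 11, 12, 20, 21], dtype=torch.int32)
```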