rename to cache and input lengths

OlivierDehaene 2024-10-07 15:14:03 +02:00
parent 8188deac22
commit 3924b87a04
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
4 changed files with 210 additions and 222 deletions

View File

@@ -9,8 +9,8 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
 @dataclass
 class Seqlen:
-postfix_lengths: torch.Tensor
-prefix_lengths: torch.Tensor
+input_lengths: torch.Tensor
+cache_lengths: torch.Tensor
 cu_seqlen_q: Optional[torch.Tensor]
 cu_seqlen_k: Optional[torch.Tensor]
 max_q: int
@@ -18,16 +18,16 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
 def __init__(
 self,
-postfix_lengths,
-prefix_lengths,
+input_lengths,
+cache_lengths,
 cu_seqlen_q=None,
 max_q=None,
 max_k=None,
 ):
-self.postfix_lengths = postfix_lengths
-self.prefix_lengths = prefix_lengths
-device = self.postfix_lengths.device
-shape = self.postfix_lengths.shape
+self.input_lengths = input_lengths
+self.cache_lengths = cache_lengths
+device = self.input_lengths.device
+shape = self.input_lengths.shape
 if cu_seqlen_q is None:
 cu_seqlen_q = torch.arange(
 shape[0] + 1,
@@ -43,7 +43,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
 # cuda graphs don't like this and this is necessary to clamp within mistral
 # Although FA2 might not want the clamping
 # cu_seqlen_k[0] = 0
-total = self.postfix_lengths + self.prefix_lengths
+total = self.input_lengths + self.cache_lengths
 torch.cumsum(total, -1, out=cu_seqlen_k[1:])
 self.cu_seqlen_q = cu_seqlen_q
@@ -59,8 +59,8 @@ else:
 @dataclass
 class Seqlen:
-postfix_lengths: torch.Tensor
-prefix_lengths: torch.Tensor
+input_lengths: torch.Tensor
+cache_lengths: torch.Tensor
 cu_seqlen_q: torch.Tensor
 max_q: int
 max_k: int
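For orientation, here is a minimal sketch (not part of the commit) of how the renamed Seqlen from the hunks above is constructed; the tensors and values are invented for illustration, and Seqlen refers to the flashinfer/flashdecoding dataclass defined above:

    import torch

    # Invented per-request lengths for a batch of two requests.
    input_lengths = torch.tensor([3, 5], dtype=torch.int32)   # tokens fed to the model this step
    cache_lengths = torch.tensor([10, 0], dtype=torch.int32)  # tokens already held in the KV cache

    # The total key length per request is cache_lengths + input_lengths
    # (previously prefix_lengths + postfix_lengths).
    seqlen = Seqlen(
        input_lengths=input_lengths,
        cache_lengths=cache_lengths,
        cu_seqlen_q=None,  # built with torch.arange in __init__ when omitted
        max_q=int(input_lengths.max()),
        max_k=int((input_lengths + cache_lengths).max()),
    )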

View File

@@ -150,7 +150,7 @@ class FlashCausalLMBatch(Batch):
 # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
 slots: Optional[torch.Tensor]
-max_postfix_length: int
+max_input_length: int
 max_current_length: int
 # Whether this batch contains at least one request that is prefilling
@@ -181,13 +181,13 @@ class FlashCausalLMBatch(Batch):
 all_input_ids_tensor: torch.Tensor
 # Lengths of all generations present in the batch
-postfix_lengths: List[int]
+input_lengths: List[int]
 # size [b], containing the number of blocks that can be retrieved from the cache
-prefix_lengths: List[int]
+cache_lengths: List[int]
 prompt_lengths: List[int]
 # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
-postfix_lengths_tensor: Optional[torch.Tensor]
-prefix_lengths_tensor: Optional[torch.Tensor]
+input_lengths_tensor: Optional[torch.Tensor]
+cache_lengths_tensor: Optional[torch.Tensor]
 prompt_lengths_tensor: torch.Tensor
 prefix_offsets: List[Optional[int]]
@@ -252,8 +252,8 @@ class FlashCausalLMBatch(Batch):
 ) -> "FlashCausalLMBatch":
 speculate = get_speculate()
-prefix_lengths = []
-postfix_lengths = []
+cache_lengths = []
+input_lengths = []
 prompt_lengths = []
 prefix_offsets = []
 read_offsets = []
@@ -267,7 +267,7 @@ class FlashCausalLMBatch(Batch):
 top_n_tokens = []
 num_blocks = 0
-max_postfix_length = 0
+max_input_length = 0
 max_current_length = 0
 max_length = 0
 max_blocks = 0
@@ -284,28 +284,26 @@ class FlashCausalLMBatch(Batch):
 prompt_length = len(tokenized_input)
 prompt_lengths.append(prompt_length)
-prefix_length = r.prefix_len
-postfix_length = r.postfix_len
+cache_length = r.prefix_len
+input_length = r.postfix_len
 assert (
-prefix_length <= prompt_length
-), f"Prefix {prefix_length} vs input {prompt_length}"
-if prefix_length == prompt_length:
+cache_length <= prompt_length
+), f"Prefix {cache_length} vs input {prompt_length}"
+if cache_length == prompt_length:
 assert False, "unreachable"
-if prefix_length + postfix_length < prompt_length:
+if cache_length + input_length < prompt_length:
 # FIXME: speculate is not supported for context chunking at the moment
 assert speculate == 0
 assert get_support_chunking()
-assert postfix_length > 0
-prefix_ids.append(tokenized_input[:prefix_length])
-postfix_ids = tokenized_input[
-prefix_length : prefix_length + postfix_length
-]
+assert input_length > 0
+prefix_ids.append(tokenized_input[:cache_length])
+postfix_ids = tokenized_input[cache_length : cache_length + input_length]
 assert (
-len(postfix_ids) == postfix_length
+len(postfix_ids) == input_length
 ), "Rust and Python tokenizers are not aligned"
-postfix_lengths.append(postfix_length)
+input_lengths.append(input_length)
 prefix_offsets.append(prompt_length - 5)
 read_offsets.append(prompt_length)
@@ -341,13 +339,13 @@ class FlashCausalLMBatch(Batch):
 block_tables.append(request_blocks)
-prefix_lengths.append(prefix_length)
+cache_lengths.append(cache_length)
 num_blocks += len(request_blocks)
 # Update
 max_blocks = max(max_blocks, len(request_blocks))
-max_postfix_length = max(max_postfix_length, postfix_length)
-max_current_length = max(max_current_length, prefix_length + postfix_length)
+max_input_length = max(max_input_length, input_length)
+max_current_length = max(max_current_length, cache_length + input_length)
 max_length = max(
 max_length,
 prompt_length + max_new_tokens + speculative_length,
@@ -390,13 +388,13 @@ class FlashCausalLMBatch(Batch):
 input_ids=all_postfix_ids,
 block_tables=block_tables,
 block_tables_tensor=block_tables_tensor,
-prefix_lengths=prefix_lengths,
-max_postfix_length=max_postfix_length,
+cache_lengths=cache_lengths,
+max_input_length=max_input_length,
 max_current_length=max_current_length,
 prefilling=True,
 prefilling_mask=[True] * len(pb.requests),
 prefill_logprob_tokens=[None] * len(pb.requests),
-postfix_lengths=postfix_lengths,
+input_lengths=input_lengths,
 prompt_lengths=prompt_lengths,
 prefix_offsets=prefix_offsets,
 read_offsets=read_offsets,
@@ -420,8 +418,8 @@ class FlashCausalLMBatch(Batch):
 prefill_head_indices=None,
 prefill_next_token_indices=None,
 prefill_cu_outlens=None,
-prefix_lengths_tensor=None,
-postfix_lengths_tensor=None,
+cache_lengths_tensor=None,
+input_lengths_tensor=None,
 adapter_meta=None,
 )
@@ -460,7 +458,7 @@ class FlashCausalLMBatch(Batch):
 # Create on CPU to only move to GPU once instead of at every copy
 slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
-max_postfix_length = 0
+max_input_length = 0
 max_current_length = 0
 requests = []
@@ -470,8 +468,8 @@ class FlashCausalLMBatch(Batch):
 input_ids = []
 prompt_lengths = []
-postfix_lengths = []
-prefix_lengths = []
+input_lengths = []
+cache_lengths = []
 prefix_offsets = []
 read_offsets = []
@@ -499,19 +497,19 @@ class FlashCausalLMBatch(Batch):
 prefilling_mask.append(request_prefilling)
 # Get length
-request_postfix_length = self.postfix_lengths[idx]
-request_prefix_length = self.prefix_lengths[idx]
-max_postfix_length = max(max_postfix_length, request_postfix_length)
+request_input_length = self.input_lengths[idx]
+request_cache_length = self.cache_lengths[idx]
+max_input_length = max(max_input_length, request_input_length)
 max_current_length = max(
-max_current_length, request_prefix_length + request_postfix_length
+max_current_length, request_cache_length + request_input_length
 )
 all_input_ids.append(self.all_input_ids[idx])
 prefix_ids.append(self.prefix_ids[idx])
 prompt_lengths.append(self.prompt_lengths[idx])
-postfix_lengths.append(request_postfix_length)
-prefix_lengths.append(request_prefix_length)
+input_lengths.append(request_input_length)
+cache_lengths.append(request_cache_length)
 prefix_offsets.append(self.prefix_offsets[idx])
 read_offsets.append(self.read_offsets[idx])
@@ -544,12 +542,12 @@ class FlashCausalLMBatch(Batch):
 # Set slice
 slot_filtering_indices[
 self.slot_indices[idx] : self.slot_indices[idx]
-+ request_postfix_length
++ request_input_length
 + remaining_tokens
 - 1
 ] = True
-cumulative_max_length += request_postfix_length + remaining_tokens - 1
+cumulative_max_length += request_input_length + remaining_tokens - 1
 max_blocks = max(max_blocks, len(request_block_table))
@@ -567,17 +565,17 @@ class FlashCausalLMBatch(Batch):
 position_ids = None
 slot_indices = None
 slots = None
-prefix_lengths_tensor = None
-postfix_lengths_tensor = None
+cache_lengths_tensor = None
+input_lengths_tensor = None
 adapter_meta = None
 else:
 # Index into tensors
 input_ids = self.input_ids[indices]
 position_ids = self.position_ids[indices]
 adapter_indices = self.adapter_meta.adapter_indices[indices]
-postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
+input_lengths_tensor = self.input_lengths_tensor[indices]
 slots = self.slots[slot_filtering_indices]
-prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
+cache_lengths_tensor = self.cache_lengths_tensor[indices]
 # Move to GPU now that we have the whole tensor
 slot_indices = slot_indices.to(device)
@@ -605,7 +603,7 @@ class FlashCausalLMBatch(Batch):
 block_tables=block_tables,
 block_tables_tensor=block_tables_tensor,
 slots=slots,
-max_postfix_length=max_postfix_length,
+max_input_length=max_input_length,
 max_current_length=max_current_length,
 prefilling=self.prefilling,
 prefilling_mask=prefilling_mask,
@@ -615,10 +613,10 @@ class FlashCausalLMBatch(Batch):
 prefill_logprob_tokens=prefill_logprob_tokens,
 prompt_lengths=prompt_lengths,
 prompt_lengths_tensor=prompt_lengths_tensor,
-postfix_lengths=postfix_lengths,
-postfix_lengths_tensor=postfix_lengths_tensor,
-prefix_lengths=prefix_lengths,
-prefix_lengths_tensor=prefix_lengths_tensor,
+input_lengths=input_lengths,
+input_lengths_tensor=input_lengths_tensor,
+cache_lengths=cache_lengths,
+cache_lengths_tensor=cache_lengths_tensor,
 prefix_offsets=prefix_offsets,
 read_offsets=read_offsets,
 all_input_ids=all_input_ids,
@@ -647,7 +645,7 @@ class FlashCausalLMBatch(Batch):
 total_slots = 0
 max_blocks = 0
 max_length = 0
-max_postfix_length = 0
+max_input_length = 0
 max_current_length = 0
 for b in batches:
 total_batch_size += len(b)
@@ -659,7 +657,7 @@ class FlashCausalLMBatch(Batch):
 speculative_length = (
 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
 )
-max_postfix_length = max(max_postfix_length, b.max_postfix_length)
+max_input_length = max(max_input_length, b.max_input_length)
 max_current_length = max(max_current_length, b.max_current_length)
 max_length = max(
 max_length,
@@ -680,8 +678,8 @@ class FlashCausalLMBatch(Batch):
 position_ids = None
 slots = None
 slot_indices = None
-prefix_lengths_tensor = None
-postfix_lengths_tensor = None
+cache_lengths_tensor = None
+input_lengths_tensor = None
 adapter_meta = None
 adapter_segment_builder = None
 else:
@@ -689,10 +687,10 @@ class FlashCausalLMBatch(Batch):
 position_ids = batches[0].position_ids.new_empty(total_batch_size)
 slots = batches[0].slots.new_empty(total_slots)
 slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
-postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
+input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
 total_batch_size
 )
-prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
+cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
 total_batch_size
 )
 total_indices_size = sum(
@@ -718,12 +716,12 @@ class FlashCausalLMBatch(Batch):
 )
 block_tables = []
-prefix_lengths = []
+cache_lengths = []
 all_input_ids = []
 prefix_ids = []
 prompt_lengths = []
-postfix_lengths = []
+input_lengths = []
 prefix_offsets = []
 read_offsets = []
@@ -773,9 +771,7 @@ class FlashCausalLMBatch(Batch):
 slot_indices[start_index:end_index] = (
 batch.slot_indices + cumulative_slots
 )
-postfix_lengths_tensor[start_index:end_index] = (
-batch.postfix_lengths_tensor
-)
+input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
 slots[slots_start_index:slots_end_index] = batch.slots
 # Copy over adapter indices
@@ -793,9 +789,7 @@ class FlashCausalLMBatch(Batch):
 batch.adapter_meta.adapter_segments,
 batch.adapter_meta.segment_indices,
 )
-prefix_lengths_tensor[start_index:end_index] = (
-batch.prefix_lengths_tensor
-)
+cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor
 # Update
 cumulative_slots += len(batch.slots)
@@ -806,12 +800,12 @@ class FlashCausalLMBatch(Batch):
 prefilling_mask.extend(batch.prefilling_mask)
 block_tables.extend(batch.block_tables)
-prefix_lengths.extend(batch.prefix_lengths)
+cache_lengths.extend(batch.cache_lengths)
 all_input_ids.extend(batch.all_input_ids)
 prefix_ids.extend(batch.prefix_ids)
 prompt_lengths.extend(batch.prompt_lengths)
-postfix_lengths.extend(batch.postfix_lengths)
+input_lengths.extend(batch.input_lengths)
 prefix_offsets.extend(batch.prefix_offsets)
 read_offsets.extend(batch.read_offsets)
@@ -860,10 +854,10 @@ class FlashCausalLMBatch(Batch):
 slot_indices=slot_indices,
 block_tables=block_tables,
 block_tables_tensor=block_tables_tensor,
-prefix_lengths=prefix_lengths,
-prefix_lengths_tensor=prefix_lengths_tensor,
+cache_lengths=cache_lengths,
+cache_lengths_tensor=cache_lengths_tensor,
 slots=slots,
-max_postfix_length=max_postfix_length,
+max_input_length=max_input_length,
 max_current_length=max_current_length,
 prefilling=prefilling,
 prefilling_mask=prefilling_mask,
@@ -873,8 +867,8 @@ class FlashCausalLMBatch(Batch):
 prefill_logprob_tokens=prefill_logprob_tokens,
 prompt_lengths=prompt_lengths,
 prompt_lengths_tensor=prompt_lengths_tensor,
-postfix_lengths=postfix_lengths,
-postfix_lengths_tensor=postfix_lengths_tensor,
+input_lengths=input_lengths,
+input_lengths_tensor=input_lengths_tensor,
 prefix_offsets=prefix_offsets,
 read_offsets=read_offsets,
 all_input_ids=all_input_ids,
@@ -918,30 +912,30 @@ class FlashCausalLMBatch(Batch):
 for i, (
 r,
-prefix_length,
-postfix_length,
+cache_length,
+input_length,
 prompt_length,
 request_prefilling,
 blocks,
 ) in enumerate(
 zip(
 self.requests,
-self.prefix_lengths,
-self.postfix_lengths,
+self.cache_lengths,
+self.input_lengths,
 self.prompt_lengths,
 self.prefilling_mask,
 self.block_tables,
 )
 ):
-next_chunk_length = postfix_length
+next_chunk_length = input_length
 # Position ids
 request_position_ids = torch.arange(
-prefix_length, prefix_length + postfix_length, dtype=torch.int32
+cache_length, cache_length + input_length, dtype=torch.int32
 )
 position_ids.append(request_position_ids)
 # Add cumulative lengths of all previous inputs
-cu_seqlen_prefill.append(cumulative_length + postfix_length)
+cu_seqlen_prefill.append(cumulative_length + input_length)
 if not r.slots:
 request_slots = [
@@ -952,18 +946,18 @@ class FlashCausalLMBatch(Batch):
 else:
 request_slots = r.slots
-request_slots = request_slots[prefix_length:]
+request_slots = request_slots[cache_length:]
 request_slot_indices = torch.arange(
 cumulative_slot_tokens,
-cumulative_slot_tokens + postfix_length,
+cumulative_slot_tokens + input_length,
 dtype=torch.int64,
 )
 # Create tensor to slice into the kv tensor in prefill
 if sliding_window is not None:
 request_prefill_cache_indices = torch.arange(
-cumulative_length + max(0, postfix_length - sliding_window),
-cumulative_length + postfix_length,
+cumulative_length + max(0, input_length - sliding_window),
+cumulative_length + input_length,
 dtype=torch.int64,
 )
@@ -976,16 +970,14 @@ class FlashCausalLMBatch(Batch):
 if prefill_logprobs:
 prefill_head_indices.append(request_position_ids + cumulative_length)
 prefill_next_token_indices.append(
-prefill_out_cumulative_length + postfix_length - 1
+prefill_out_cumulative_length + input_length - 1
 )
-prefill_cu_outlens.append(
-prefill_out_cumulative_length + postfix_length
-)
-prefill_out_cumulative_length += postfix_length
+prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
+prefill_out_cumulative_length += input_length
 else:
 prefill_head_indices.append(
 torch.tensor(
-[cumulative_length + postfix_length - 1],
+[cumulative_length + input_length - 1],
 dtype=torch.int32,
 )
 )
@@ -1038,8 +1030,8 @@ class FlashCausalLMBatch(Batch):
 self.prefill_cache_indices = (
 prefill_cache_indices.to(device) if sliding_window is not None else None
 )
-self.postfix_lengths_tensor = torch.tensor(
-self.postfix_lengths, dtype=torch.int32, device=device
+self.input_lengths_tensor = torch.tensor(
+self.input_lengths, dtype=torch.int32, device=device
 )
 if all_prefill_logprobs:
@@ -1059,8 +1051,8 @@ class FlashCausalLMBatch(Batch):
 self.prefill_head_indices = prefill_head_indices
 self.prefill_next_token_indices = prefill_next_token_indices
 self.slots = torch.tensor(slots, dtype=torch.int64, device=device)
-self.prefix_lengths_tensor = torch.tensor(
-self.prefix_lengths, dtype=torch.int32, device=device
+self.cache_lengths_tensor = torch.tensor(
+self.cache_lengths, dtype=torch.int32, device=device
 )
 adapter_indices = torch.cat(adapter_indices_list).to(
 dtype=torch.int64, device=device
@@ -1276,12 +1268,12 @@ class FlashCausalLM(Model):
 input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
 position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
 slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-postfix_lengths = [max_s] * bs
-prefix_lengths = [0] * bs
-postfix_lengths_tensor = (
+input_lengths = [max_s] * bs
+cache_lengths = [0] * bs
+input_lengths_tensor = (
 torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
 )
-prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
 block_tables = torch.arange(
 max_bt, dtype=torch.int32, device=self.device
 ).repeat(bs)
@@ -1290,8 +1282,8 @@ class FlashCausalLM(Model):
 if ATTENTION == "flashinfer":
 block_tables = block_tables_to_ragged(
 block_tables=block_tables,
-postfix_lengths=postfix_lengths,
-prefix_lengths=prefix_lengths,
+input_lengths=input_lengths,
+cache_lengths=cache_lengths,
 )
 from text_generation_server.layers.attention.flashinfer import (
 create_decode_state_cuda_graphs,
@@ -1319,8 +1311,8 @@ class FlashCausalLM(Model):
 "kv_cache": self.kv_cache,
 "block_tables": block_tables,
 "slots": slots,
-"postfix_lengths": postfix_lengths_tensor,
-"prefix_lengths": prefix_lengths_tensor,
+"input_lengths": input_lengths_tensor,
+"cache_lengths": cache_lengths_tensor,
 "state": state,
 "graph": graph,
 }
@@ -1330,13 +1322,13 @@ class FlashCausalLM(Model):
 with self._forward_context(
 block_tables=block_tables,
 cu_seqlen_prefill=None,
-postfix_lengths_tensor=postfix_lengths_tensor,
+input_lengths_tensor=input_lengths_tensor,
 state=state,
-prefix_lengths_tensor=prefix_lengths_tensor,
+cache_lengths_tensor=cache_lengths_tensor,
 ):
 seqlen = Seqlen(
-postfix_lengths=postfix_lengths_tensor,
-prefix_lengths=prefix_lengths_tensor,
+input_lengths=input_lengths_tensor,
+cache_lengths=cache_lengths_tensor,
 cu_seqlen_q=None,
 max_q=1,
 max_k=max_s,
@@ -1359,8 +1351,8 @@ class FlashCausalLM(Model):
 with torch.cuda.graph(graph, pool=MEM_POOL):
 seqlen = Seqlen(
-postfix_lengths=postfix_lengths_tensor,
-prefix_lengths=prefix_lengths_tensor,
+input_lengths=input_lengths_tensor,
+cache_lengths=cache_lengths_tensor,
 cu_seqlen_q=None,
 max_q=1,
 max_k=max_s,
@@ -1517,8 +1509,8 @@ class FlashCausalLM(Model):
 slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
 # Dummy value, some models (starcoder2) don't accept `None`.
-postfix_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-prefix_lengths_tensor = torch.zeros(
+input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+cache_lengths_tensor = torch.zeros(
 seqlen, dtype=torch.int32, device=self.device
 )
 cu_seqlen_prefill = torch.tensor(
@@ -1526,8 +1518,8 @@ class FlashCausalLM(Model):
 )
 max_s = seqlen
 seqlen = Seqlen(
-postfix_lengths=postfix_lengths,
-prefix_lengths=prefix_lengths_tensor,
+input_lengths=input_lengths,
+cache_lengths=cache_lengths_tensor,
 cu_seqlen_q=cu_seqlen_prefill,
 max_q=1,
 max_k=seqlen,
@@ -1558,7 +1550,7 @@ class FlashCausalLM(Model):
 kv_cache = self.kv_cache
 block_tables = batch.block_tables_tensor
 slots = batch.slots[batch.slot_indices]
-postfix_lengths = batch.postfix_lengths_tensor
+input_lengths = batch.input_lengths_tensor
 max_s = batch.max_current_length
 lm_head_indices = batch.prefill_head_indices
@@ -1575,11 +1567,11 @@ class FlashCausalLM(Model):
 position_ids.unsqueeze(-1).expand(B, new_length) + arange
 ).view(-1)
 slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-postfix_lengths = (
-postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+input_lengths = (
+input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
 ).view(-1)
-prefix_lengths_tensor = (
-batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length)
+cache_lengths_tensor = (
+batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
 ).reshape(-1)
 # Add Copy the block tables for all members
@@ -1600,8 +1592,8 @@ class FlashCausalLM(Model):
 kv_cache = self.kv_cache
 block_tables = batch.block_tables_tensor
 slots = batch.slots[batch.slot_indices]
-postfix_lengths = batch.postfix_lengths_tensor
-prefix_lengths_tensor = batch.prefix_lengths_tensor
+input_lengths = batch.input_lengths_tensor
+cache_lengths_tensor = batch.cache_lengths_tensor
 max_s = batch.max_current_length
 lm_head_indices = batch.prefill_head_indices
@@ -1623,19 +1615,19 @@ class FlashCausalLM(Model):
 if ATTENTION == "flashinfer":
 block_tables = block_tables_to_ragged(
 block_tables=block_tables,
-postfix_lengths=batch.postfix_lengths,
-prefix_lengths=batch.prefix_lengths,
+input_lengths=batch.input_lengths,
+cache_lengths=batch.cache_lengths,
 )
 with self._forward_context(
 block_tables=block_tables,
 cu_seqlen_prefill=cu_seqlen_prefill,
-postfix_lengths_tensor=postfix_lengths,
-prefix_lengths_tensor=prefix_lengths_tensor,
+input_lengths_tensor=input_lengths,
+cache_lengths_tensor=cache_lengths_tensor,
 ):
-max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
+max_k = (input_lengths + cache_lengths_tensor).max().item()
 seqlen = Seqlen(
-postfix_lengths=postfix_lengths,
-prefix_lengths=prefix_lengths_tensor,
+input_lengths=input_lengths,
+cache_lengths=cache_lengths_tensor,
 cu_seqlen_q=cu_seqlen_prefill,
 max_q=max_s,
 max_k=max_k,
@@ -1664,8 +1656,8 @@ class FlashCausalLM(Model):
 if ATTENTION == "flashinfer":
 block_tables = block_tables_to_ragged(
 block_tables=block_tables,
-postfix_lengths=batch.postfix_lengths,
-prefix_lengths=batch.prefix_lengths,
+input_lengths=batch.input_lengths,
+cache_lengths=batch.cache_lengths,
 )
 # assert block_tables.shape[0] >= slots.shape[0]
 cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
@@ -1678,18 +1670,18 @@ class FlashCausalLM(Model):
 # so it doesn't matter if we override it with bogus values.
 cuda_graph["slots"].fill_(0)
 cuda_graph["slots"][: slots.shape[0]] = slots
-cuda_graph["postfix_lengths"].zero_()
-cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
-cuda_graph["prefix_lengths"].zero_()
-cuda_graph["prefix_lengths"][
-: prefix_lengths_tensor.shape[0]
-] = prefix_lengths_tensor
+cuda_graph["input_lengths"].zero_()
+cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+cuda_graph["cache_lengths"].zero_()
+cuda_graph["cache_lengths"][
+: cache_lengths_tensor.shape[0]
+] = cache_lengths_tensor
 with self._forward_context(
 block_tables=cuda_graph["block_tables"],
 cu_seqlen_prefill=None,
-postfix_lengths_tensor=cuda_graph["postfix_lengths"],
-prefix_lengths_tensor=cuda_graph["prefix_lengths"],
+input_lengths_tensor=cuda_graph["input_lengths"],
+cache_lengths_tensor=cuda_graph["cache_lengths"],
 state=cuda_graph["state"],
 ):
 # Replay the graph
@@ -1775,13 +1767,13 @@ class FlashCausalLM(Model):
 batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
 # We reverse to prioritize older requests
 # zip() is not reversible so reverse the underlying lists instead
-for prefix_length, postfix_length, prompt_length in zip(
-reversed(batch.prefix_lengths),
-reversed(batch.postfix_lengths),
+for cache_length, input_length, prompt_length in zip(
+reversed(batch.cache_lengths),
+reversed(batch.input_lengths),
 reversed(batch.prompt_lengths),
 ):
 remaining_prefill_tokens = max(
-prompt_length - prefix_length - postfix_length, 0
+prompt_length - cache_length - input_length, 0
 )
 if remaining_prefill_tokens > 0:
 next_chunk_length = max(
@@ -1842,8 +1834,8 @@ class FlashCausalLM(Model):
 # Zipped iterator
 iterator = zip(
 batch.prompt_lengths,
-batch.prefix_lengths,
-batch.postfix_lengths,
+batch.cache_lengths,
+batch.input_lengths,
 batch.all_input_ids,
 accepted_ids,
 )
@@ -1858,14 +1850,14 @@ class FlashCausalLM(Model):
 cumulative_length = 0
 for i, (
 prompt_length,
-prefix_length,
-postfix_length,
+cache_length,
+input_length,
 all_input_ids,
 n_accepted_ids,
 ) in enumerate(iterator):
 # Indexing metadata
 start_index = cumulative_length
-end_index = cumulative_length + postfix_length
+end_index = cumulative_length + input_length
 if prefill:
 # Indexing metadata
@@ -1899,17 +1891,17 @@ class FlashCausalLM(Model):
 # Represent whether this request is still prefilling
 # If it is, the tokens we decoded should be ignored
-accept_tokens = prefix_length + postfix_length >= prompt_length
+accept_tokens = cache_length + input_length >= prompt_length
 if accept_tokens:
 # Only save tokens if we are done prefilling for this request
 for j in range(n_accepted_ids):
-batch.all_input_ids_tensor[
-i, prefix_length + postfix_length + j
-] = next_input_ids[index]
+batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
+next_input_ids[index]
+)
 index += 1
-cumulative_length += postfix_length
+cumulative_length += input_length
 # Update values
 # These values can be updated without a GPU -> CPU sync
@@ -1917,8 +1909,8 @@ class FlashCausalLM(Model):
 batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
 batch.speculative_ids = speculative_ids
 batch.position_ids = next_position_ids + accepted_ids
-batch.prefix_lengths_tensor += batch.postfix_lengths_tensor
-batch.postfix_lengths_tensor = accepted_ids
+batch.cache_lengths_tensor += batch.input_lengths_tensor
+batch.input_lengths_tensor = accepted_ids
 batch.slot_indices += accepted_ids
 batch.adapter_meta.adapter_indices = next_adapter_indices
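As a side note, a small illustration (not part of the commit) of what the two renamed bookkeeping lines above do after a decode step, with invented numbers: the tokens that were just processed as input are folded into the cache, and the new input length becomes the number of accepted tokens.

    import torch

    # Invented state for a batch of two requests before the update.
    cache_lengths_tensor = torch.tensor([10, 0])  # tokens already in the KV cache
    input_lengths_tensor = torch.tensor([3, 5])   # tokens processed in this forward pass
    accepted_ids = torch.tensor([1, 2])           # tokens kept after sampling / speculation

    # Mirrors `batch.cache_lengths_tensor += batch.input_lengths_tensor`
    # and `batch.input_lengths_tensor = accepted_ids` from the hunk above.
    cache_lengths_tensor += input_lengths_tensor  # tensor([13, 5])
    input_lengths_tensor = accepted_ids           # tensor([1, 2])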
@@ -1959,24 +1951,24 @@ class FlashCausalLM(Model):
 request_prefilling,
 next_token_id,
 all_input_ids,
-prefix_length,
-postfix_length,
+cache_length,
+input_length,
 next_chunk_length,
 ) in enumerate(
 zip(
 batch.prefilling_mask,
 next_token_ids,
 batch.all_input_ids,
-batch.prefix_lengths,
-batch.postfix_lengths,
+batch.cache_lengths,
+batch.input_lengths,
 next_chunk_lengths,
 )
 ):
 if request_prefilling:
-next_prefix_length = prefix_length + postfix_length
+next_cache_length = cache_length + input_length
 # Get new prompt IDs to prefill
 postfix_ids = all_input_ids[
-next_prefix_length : next_prefix_length + next_chunk_length
+next_cache_length : next_cache_length + next_chunk_length
 ]
 else:
 # This request is done prefilling, the new id is the one selected the sampling method
@@ -1996,8 +1988,8 @@ class FlashCausalLM(Model):
 iterator = zip(
 batch.requests,
 batch.prompt_lengths,
-batch.prefix_lengths,
-batch.postfix_lengths,
+batch.cache_lengths,
+batch.input_lengths,
 batch.prefix_offsets,
 batch.read_offsets,
 batch.stopping_criterias,
@@ -2012,15 +2004,15 @@ class FlashCausalLM(Model):
 batch_top_token_logprobs,
 )
-# Reset max_postfix_length
-batch.max_postfix_length = 0
+# Reset max_input_length
+batch.max_input_length = 0
 # For each member of the batch
 index = 0
 for i, (
 request,
 prompt_length,
-prefix_length,
-postfix_length,
+cache_length,
+input_length,
 prefix_offset,
 read_offset,
 stopping_criteria,
@@ -2084,9 +2076,9 @@ class FlashCausalLM(Model):
 # Make sure that we do not stop as even though this request did not create a token, it is still
 # processing
 stopped = False
-new_postfix_length = next_chunk_lengths[i]
+new_input_length = next_chunk_lengths[i]
 else:
-new_postfix_length = n_accepted_ids
+new_input_length = n_accepted_ids
 # Append next token to all tokens
 next_token_texts = []
 left = 0
@@ -2198,14 +2190,12 @@ class FlashCausalLM(Model):
 )
 # Update values
-current_prefix_length = prefix_length + postfix_length
-batch.prefix_lengths[i] = current_prefix_length
-current_postfix_length = new_postfix_length
-batch.max_postfix_length = max(
-batch.max_postfix_length, current_postfix_length
-)
-batch.postfix_lengths[i] = current_postfix_length
-current_length = current_prefix_length + current_postfix_length
+current_cache_length = cache_length + input_length
+batch.cache_lengths[i] = current_cache_length
+current_input_length = new_input_length
+batch.max_input_length = max(batch.max_input_length, current_input_length)
+batch.input_lengths[i] = current_input_length
+current_length = current_cache_length + current_input_length
 batch.max_current_length = max(batch.max_current_length, current_length)
 batch.prefix_offsets[i] = prefix_offset
@@ -2235,8 +2225,8 @@ class FlashCausalLM(Model):
 *,
 block_tables: torch.Tensor,
 cu_seqlen_prefill: Optional[torch.Tensor],
-postfix_lengths_tensor: torch.Tensor,
-prefix_lengths_tensor: torch.Tensor,
+input_lengths_tensor: torch.Tensor,
+cache_lengths_tensor: torch.Tensor,
 state: Optional[Any] = None,
 ) -> ContextManager:
 if ATTENTION != "flashinfer":
@@ -2247,7 +2237,7 @@ class FlashCausalLM(Model):
 use_prefill_with_paged_kv_state,
 )
-# has_prefix_lengths = any(prefix_length > 0 for prefix_length in prefix_lengths)
+# has_cache_lengths = any(cache_length > 0 for cache_length in cache_lengths)
 if cu_seqlen_prefill is not None:
 return use_prefill_with_paged_kv_state(
@@ -2256,12 +2246,12 @@ class FlashCausalLM(Model):
 ),
 # block_tables=block_tables_to_ragged(
 # block_tables=block_tables,
-# postfix_lengths=postfix_lengths,
-# prefix_lengths=prefix_lengths,
+# input_lengths=input_lengths,
+# cache_lengths=cache_lengths,
 # ),
 block_tables=block_tables,
 cu_seqlens=cu_seqlen_prefill,
-input_lengths=postfix_lengths_tensor + prefix_lengths_tensor,
+input_lengths=input_lengths_tensor + cache_lengths_tensor,
 num_heads=self.num_heads,
 num_kv_heads=self.num_kv_heads,
 head_size=self.head_size,
@@ -2270,10 +2260,10 @@ class FlashCausalLM(Model):
 window_left=self.sliding_window,
 )
 else:
-assert postfix_lengths_tensor is not None
+assert input_lengths_tensor is not None
 return use_decode_state(
 state=state if state is not None else self.decode_state,
-input_lengths=postfix_lengths_tensor + prefix_lengths_tensor,
+input_lengths=input_lengths_tensor + cache_lengths_tensor,
 block_tables=block_tables,
 num_heads=self.num_heads,
 num_kv_heads=self.num_kv_heads,
@@ -2285,21 +2275,19 @@ class FlashCausalLM(Model):
 def block_tables_to_ragged(
-*, block_tables: torch.Tensor, postfix_lengths: List[int], prefix_lengths: List[int]
+*, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int]
 ) -> torch.Tensor:
 """Convert block table to ragged format compatible with FlashInfer."""
-assert len(postfix_lengths) == len(prefix_lengths)
-total_len = sum(postfix_lengths) + sum(prefix_lengths)
+assert len(input_lengths) == len(cache_lengths)
+total_len = sum(input_lengths) + sum(cache_lengths)
 block_tables_ragged = torch.empty(
 total_len, dtype=torch.int32, device=block_tables.device
 )
 offset = 0
-for i, (input_length, prefix_length) in enumerate(
-zip(postfix_lengths, prefix_lengths)
-):
-seq_len = prefix_length + input_length
+for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)):
+seq_len = cache_length + input_length
 block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
 offset += seq_len
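For context, a small usage sketch (not part of the commit) of block_tables_to_ragged with the renamed keyword arguments, matching the signature in the hunk above; the block table contents are invented:

    import torch

    # Two requests with a padded block table (values invented).
    block_tables = torch.tensor(
        [[1, 2, 3, 4],
         [5, 6, 0, 0]],
        dtype=torch.int32,
    )

    # seq_len per request is cache_length + input_length: 3 and 2 here,
    # so the ragged tensor keeps the first 3 entries of row 0 and the first 2 of row 1.
    ragged = block_tables_to_ragged(
        block_tables=block_tables,
        input_lengths=[1, 1],
        cache_lengths=[2, 1],
    )
    # ragged -> tensor([1, 2, 3, 5, 6], dtype=torch.int32)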

View File

@@ -285,7 +285,7 @@ class MllamaCausalLM(VlmCausalLM):
 max_k = (input_lengths + prefix_lens_tensor).max().item()
 seqlen = Seqlen(
 input_lengths=input_lengths,
-prefix_lengths=prefix_lens_tensor,
+cache_lengths=prefix_lens_tensor,
 cu_seqlen_q=cu_seqlen_prefill,
 max_q=max_s,
 max_k=max_k,

View File

@@ -294,7 +294,7 @@ class VlmCausalLM(FlashCausalLM):
 kv_cache = self.kv_cache
 block_tables = batch.block_tables_tensor
 slots = batch.slots[batch.slot_indices]
-postfix_lengths = batch.postfix_lengths_tensor
+input_lengths = batch.input_lengths_tensor
 max_s = batch.max_current_length
 lm_head_indices = batch.prefill_head_indices
@@ -311,11 +311,11 @@ class VlmCausalLM(FlashCausalLM):
 position_ids.unsqueeze(-1).expand(B, new_length) + arange
 ).view(-1)
 slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-postfix_lengths = (
-postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+input_lengths = (
+input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
 ).view(-1)
-prefix_lengths_tensor = (
-batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length)
+cache_lengths_tensor = (
+batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
 ).reshape(-1)
 # Add Copy the block tables for all members
@@ -336,8 +336,8 @@ class VlmCausalLM(FlashCausalLM):
 kv_cache = self.kv_cache
 block_tables = batch.block_tables_tensor
 slots = batch.slots[batch.slot_indices]
-postfix_lengths = batch.postfix_lengths_tensor
-prefix_lengths_tensor = batch.prefix_lengths_tensor
+input_lengths = batch.input_lengths_tensor
+cache_lengths_tensor = batch.cache_lengths_tensor
 max_s = batch.max_current_length
 lm_head_indices = batch.prefill_head_indices
@@ -359,19 +359,19 @@ class VlmCausalLM(FlashCausalLM):
 if PREFIX_CACHING:
 block_tables = block_tables_to_ragged(
 block_tables=block_tables,
-postfix_lengths=batch.postfix_lengths,
-prefix_lengths=batch.prefix_lengths,
+input_lengths=batch.input_lengths,
+cache_lengths=batch.cache_lengths,
 )
 with self._forward_context(
 block_tables=block_tables,
 cu_seqlen_prefill=cu_seqlen_prefill,
-postfix_lengths_tensor=postfix_lengths,
-prefix_lengths_tensor=prefix_lengths_tensor,
+input_lengths_tensor=input_lengths,
+cache_lengths_tensor=cache_lengths_tensor,
 ):
-max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
+max_k = (input_lengths + cache_lengths_tensor).max().item()
 seqlen = Seqlen(
-postfix_lengths=postfix_lengths,
-prefix_lengths=prefix_lengths_tensor,
+input_lengths=input_lengths,
+cache_lengths=cache_lengths_tensor,
 cu_seqlen_q=cu_seqlen_prefill,
 max_q=max_s,
 max_k=max_k,
@@ -408,8 +408,8 @@ class VlmCausalLM(FlashCausalLM):
 if ATTENTION == "flashinfer":
 block_tables = block_tables_to_ragged(
 block_tables=block_tables,
-postfix_lengths=batch.postfix_lengths,
-prefix_lengths=batch.prefix_lengths,
+input_lengths=batch.input_lengths,
+cache_lengths=batch.cache_lengths,
 )
 cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
 else:
@@ -418,18 +418,18 @@ class VlmCausalLM(FlashCausalLM):
 ] = block_tables
 cuda_graph["slots"].fill_(-1)
 cuda_graph["slots"][: slots.shape[0]] = slots
-cuda_graph["postfix_lengths"].zero_()
-cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
-cuda_graph["prefix_lengths"].zero_()
-cuda_graph["prefix_lengths"][
-: prefix_lengths_tensor.shape[0]
-] = prefix_lengths_tensor
+cuda_graph["input_lengths"].zero_()
+cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+cuda_graph["cache_lengths"].zero_()
+cuda_graph["cache_lengths"][
+: cache_lengths_tensor.shape[0]
+] = cache_lengths_tensor
 with self._forward_context(
 block_tables=cuda_graph["block_tables"],
 cu_seqlen_prefill=None,
-postfix_lengths_tensor=cuda_graph["postfix_lengths"],
-prefix_lengths_tensor=cuda_graph["prefix_lengths"],
+input_lengths_tensor=cuda_graph["input_lengths"],
+cache_lengths_tensor=cuda_graph["cache_lengths"],
 state=cuda_graph["state"],
 ):
 # Replay the graph