Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

Commit 3924b87a04: "rename to cache and input lengths"
Parent: 8188deac22
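This commit renames the batch bookkeeping fields across the flash-attention causal LM code: `prefix_lengths` becomes `cache_lengths` (tokens already present in the KV cache for a request) and `postfix_lengths` becomes `input_lengths` (tokens fed to the model in the current forward pass), together with the tensor variants (`*_lengths_tensor`), the CUDA-graph buffers, and `max_postfix_length` -> `max_input_length`. As a quick orientation, here is a minimal sketch of constructing the renamed `Seqlen`; it is not part of the diff, and the import path and values are assumptions for illustration only.

# Minimal sketch, not from the commit; import path assumed, adjust to where Seqlen actually lives.
import torch
from text_generation_server.layers.attention import Seqlen  # assumed location

input_lengths = torch.tensor([3, 5], dtype=torch.int32)   # tokens processed this step
cache_lengths = torch.tensor([7, 0], dtype=torch.int32)   # tokens already in the KV cache

seqlen = Seqlen(
    input_lengths=input_lengths,
    cache_lengths=cache_lengths,
    cu_seqlen_q=None,   # the flashinfer/flashdecoding branch computes this when omitted
    max_q=5,
    max_k=10,           # max(cache_lengths + input_lengths)
)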
@@ -9,8 +9,8 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:

     @dataclass
     class Seqlen:
-        postfix_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
+        input_lengths: torch.Tensor
+        cache_lengths: torch.Tensor
         cu_seqlen_q: Optional[torch.Tensor]
         cu_seqlen_k: Optional[torch.Tensor]
         max_q: int
@@ -18,16 +18,16 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:

         def __init__(
             self,
-            postfix_lengths,
-            prefix_lengths,
+            input_lengths,
+            cache_lengths,
             cu_seqlen_q=None,
             max_q=None,
             max_k=None,
         ):
-            self.postfix_lengths = postfix_lengths
-            self.prefix_lengths = prefix_lengths
-            device = self.postfix_lengths.device
-            shape = self.postfix_lengths.shape
+            self.input_lengths = input_lengths
+            self.cache_lengths = cache_lengths
+            device = self.input_lengths.device
+            shape = self.input_lengths.shape
             if cu_seqlen_q is None:
                 cu_seqlen_q = torch.arange(
                     shape[0] + 1,
@@ -43,7 +43,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
             # cuda graphs don't like this and this is necessary to clamp within mistral
             # Although FA2 might not want the clamping
             # cu_seqlen_k[0] = 0
-            total = self.postfix_lengths + self.prefix_lengths
+            total = self.input_lengths + self.cache_lengths
             torch.cumsum(total, -1, out=cu_seqlen_k[1:])

             self.cu_seqlen_q = cu_seqlen_q
@@ -59,8 +59,8 @@ else:

     @dataclass
     class Seqlen:
-        postfix_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
+        input_lengths: torch.Tensor
+        cache_lengths: torch.Tensor
         cu_seqlen_q: torch.Tensor
         max_q: int
         max_k: int
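The flashinfer/flashdecoding branch above derives `cu_seqlen_k` from the renamed fields: each sequence's key length is `cache_length + input_length`, and `cu_seqlen_k` is the running sum with a leading zero. A standalone sketch of that computation follows, with illustrative values that are not from the commit.

# Standalone sketch of the cu_seqlen_k computation shown above; values are made up.
import torch

input_lengths = torch.tensor([3, 5, 1], dtype=torch.int32)
cache_lengths = torch.tensor([7, 0, 2], dtype=torch.int32)

total = input_lengths + cache_lengths                  # per-sequence key length
cu_seqlen_k = torch.zeros(total.shape[0] + 1, dtype=torch.int32)
torch.cumsum(total, -1, out=cu_seqlen_k[1:])           # tensor([ 0, 10, 15, 18])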
@@ -150,7 +150,7 @@ class FlashCausalLMBatch(Batch):
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     slots: Optional[torch.Tensor]

-    max_postfix_length: int
+    max_input_length: int
     max_current_length: int

     # Whether this batch contains at least one request that is prefilling
@@ -181,13 +181,13 @@ class FlashCausalLMBatch(Batch):
     all_input_ids_tensor: torch.Tensor

     # Lengths of all generations present in the batch
-    postfix_lengths: List[int]
+    input_lengths: List[int]
     # size [b], containing the number of blocks that can be retrieved from the cache
-    prefix_lengths: List[int]
+    cache_lengths: List[int]
     prompt_lengths: List[int]
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
-    postfix_lengths_tensor: Optional[torch.Tensor]
-    prefix_lengths_tensor: Optional[torch.Tensor]
+    input_lengths_tensor: Optional[torch.Tensor]
+    cache_lengths_tensor: Optional[torch.Tensor]
     prompt_lengths_tensor: torch.Tensor

     prefix_offsets: List[Optional[int]]
@@ -252,8 +252,8 @@ class FlashCausalLMBatch(Batch):
     ) -> "FlashCausalLMBatch":
         speculate = get_speculate()

-        prefix_lengths = []
-        postfix_lengths = []
+        cache_lengths = []
+        input_lengths = []
         prompt_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -267,7 +267,7 @@ class FlashCausalLMBatch(Batch):
         top_n_tokens = []

         num_blocks = 0
-        max_postfix_length = 0
+        max_input_length = 0
         max_current_length = 0
         max_length = 0
         max_blocks = 0
@@ -284,28 +284,26 @@ class FlashCausalLMBatch(Batch):
             prompt_length = len(tokenized_input)
             prompt_lengths.append(prompt_length)

-            prefix_length = r.prefix_len
-            postfix_length = r.postfix_len
+            cache_length = r.prefix_len
+            input_length = r.postfix_len
             assert (
-                prefix_length <= prompt_length
-            ), f"Prefix {prefix_length} vs input {prompt_length}"
-            if prefix_length == prompt_length:
+                cache_length <= prompt_length
+            ), f"Prefix {cache_length} vs input {prompt_length}"
+            if cache_length == prompt_length:
                 assert False, "unreachable"
-            if prefix_length + postfix_length < prompt_length:
+            if cache_length + input_length < prompt_length:
                 # FIXME: speculate is not supported for context chunking at the moment
                 assert speculate == 0
                 assert get_support_chunking()
-                assert postfix_length > 0
+                assert input_length > 0

-            prefix_ids.append(tokenized_input[:prefix_length])
-            postfix_ids = tokenized_input[
-                prefix_length : prefix_length + postfix_length
-            ]
+            prefix_ids.append(tokenized_input[:cache_length])
+            postfix_ids = tokenized_input[cache_length : cache_length + input_length]

             assert (
-                len(postfix_ids) == postfix_length
+                len(postfix_ids) == input_length
             ), "Rust and Python tokenizers are not aligned"
-            postfix_lengths.append(postfix_length)
+            input_lengths.append(input_length)

             prefix_offsets.append(prompt_length - 5)
             read_offsets.append(prompt_length)
@@ -341,13 +339,13 @@ class FlashCausalLMBatch(Batch):

             block_tables.append(request_blocks)

-            prefix_lengths.append(prefix_length)
+            cache_lengths.append(cache_length)
             num_blocks += len(request_blocks)

             # Update
             max_blocks = max(max_blocks, len(request_blocks))
-            max_postfix_length = max(max_postfix_length, postfix_length)
-            max_current_length = max(max_current_length, prefix_length + postfix_length)
+            max_input_length = max(max_input_length, input_length)
+            max_current_length = max(max_current_length, cache_length + input_length)
             max_length = max(
                 max_length,
                 prompt_length + max_new_tokens + speculative_length,
@@ -390,13 +388,13 @@ class FlashCausalLMBatch(Batch):
             input_ids=all_postfix_ids,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
-            prefix_lengths=prefix_lengths,
-            max_postfix_length=max_postfix_length,
+            cache_lengths=cache_lengths,
+            max_input_length=max_input_length,
             max_current_length=max_current_length,
             prefilling=True,
             prefilling_mask=[True] * len(pb.requests),
             prefill_logprob_tokens=[None] * len(pb.requests),
-            postfix_lengths=postfix_lengths,
+            input_lengths=input_lengths,
             prompt_lengths=prompt_lengths,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
@@ -420,8 +418,8 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
-            prefix_lengths_tensor=None,
-            postfix_lengths_tensor=None,
+            cache_lengths_tensor=None,
+            input_lengths_tensor=None,
             adapter_meta=None,
         )

@@ -460,7 +458,7 @@ class FlashCausalLMBatch(Batch):

         # Create on CPU to only move to GPU once instead of at every copy
         slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
-        max_postfix_length = 0
+        max_input_length = 0
         max_current_length = 0

         requests = []
@@ -470,8 +468,8 @@ class FlashCausalLMBatch(Batch):
         input_ids = []

         prompt_lengths = []
-        postfix_lengths = []
-        prefix_lengths = []
+        input_lengths = []
+        cache_lengths = []
         prefix_offsets = []
         read_offsets = []

@@ -499,19 +497,19 @@ class FlashCausalLMBatch(Batch):
             prefilling_mask.append(request_prefilling)

             # Get length
-            request_postfix_length = self.postfix_lengths[idx]
-            request_prefix_length = self.prefix_lengths[idx]
-            max_postfix_length = max(max_postfix_length, request_postfix_length)
+            request_input_length = self.input_lengths[idx]
+            request_cache_length = self.cache_lengths[idx]
+            max_input_length = max(max_input_length, request_input_length)
             max_current_length = max(
-                max_current_length, request_prefix_length + request_postfix_length
+                max_current_length, request_cache_length + request_input_length
             )

             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])

             prompt_lengths.append(self.prompt_lengths[idx])
-            postfix_lengths.append(request_postfix_length)
-            prefix_lengths.append(request_prefix_length)
+            input_lengths.append(request_input_length)
+            cache_lengths.append(request_cache_length)
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])

@@ -544,12 +542,12 @@ class FlashCausalLMBatch(Batch):
             # Set slice
             slot_filtering_indices[
                 self.slot_indices[idx] : self.slot_indices[idx]
-                + request_postfix_length
+                + request_input_length
                 + remaining_tokens
                 - 1
             ] = True

-            cumulative_max_length += request_postfix_length + remaining_tokens - 1
+            cumulative_max_length += request_input_length + remaining_tokens - 1

             max_blocks = max(max_blocks, len(request_block_table))

@@ -567,17 +565,17 @@ class FlashCausalLMBatch(Batch):
             position_ids = None
             slot_indices = None
             slots = None
-            prefix_lengths_tensor = None
-            postfix_lengths_tensor = None
+            cache_lengths_tensor = None
+            input_lengths_tensor = None
             adapter_meta = None
         else:
             # Index into tensors
             input_ids = self.input_ids[indices]
             position_ids = self.position_ids[indices]
             adapter_indices = self.adapter_meta.adapter_indices[indices]
-            postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
+            input_lengths_tensor = self.input_lengths_tensor[indices]
             slots = self.slots[slot_filtering_indices]
-            prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
+            cache_lengths_tensor = self.cache_lengths_tensor[indices]

             # Move to GPU now that we have the whole tensor
             slot_indices = slot_indices.to(device)
@@ -605,7 +603,7 @@ class FlashCausalLMBatch(Batch):
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            max_postfix_length=max_postfix_length,
+            max_input_length=max_input_length,
             max_current_length=max_current_length,
             prefilling=self.prefilling,
             prefilling_mask=prefilling_mask,
@@ -615,10 +613,10 @@ class FlashCausalLMBatch(Batch):
             prefill_logprob_tokens=prefill_logprob_tokens,
             prompt_lengths=prompt_lengths,
             prompt_lengths_tensor=prompt_lengths_tensor,
-            postfix_lengths=postfix_lengths,
-            postfix_lengths_tensor=postfix_lengths_tensor,
-            prefix_lengths=prefix_lengths,
-            prefix_lengths_tensor=prefix_lengths_tensor,
+            input_lengths=input_lengths,
+            input_lengths_tensor=input_lengths_tensor,
+            cache_lengths=cache_lengths,
+            cache_lengths_tensor=cache_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -647,7 +645,7 @@ class FlashCausalLMBatch(Batch):
         total_slots = 0
         max_blocks = 0
         max_length = 0
-        max_postfix_length = 0
+        max_input_length = 0
         max_current_length = 0
         for b in batches:
             total_batch_size += len(b)
@@ -659,7 +657,7 @@ class FlashCausalLMBatch(Batch):
             speculative_length = (
                 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
             )
-            max_postfix_length = max(max_postfix_length, b.max_postfix_length)
+            max_input_length = max(max_input_length, b.max_input_length)
             max_current_length = max(max_current_length, b.max_current_length)
             max_length = max(
                 max_length,
@@ -680,8 +678,8 @@ class FlashCausalLMBatch(Batch):
             position_ids = None
             slots = None
             slot_indices = None
-            prefix_lengths_tensor = None
-            postfix_lengths_tensor = None
+            cache_lengths_tensor = None
+            input_lengths_tensor = None
             adapter_meta = None
             adapter_segment_builder = None
         else:
@@ -689,10 +687,10 @@ class FlashCausalLMBatch(Batch):
             position_ids = batches[0].position_ids.new_empty(total_batch_size)
             slots = batches[0].slots.new_empty(total_slots)
             slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
-            postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
+            input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
                 total_batch_size
             )
-            prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
+            cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
                 total_batch_size
             )
             total_indices_size = sum(
@@ -718,12 +716,12 @@ class FlashCausalLMBatch(Batch):
             )

         block_tables = []
-        prefix_lengths = []
+        cache_lengths = []
         all_input_ids = []
         prefix_ids = []

         prompt_lengths = []
-        postfix_lengths = []
+        input_lengths = []
         prefix_offsets = []
         read_offsets = []

@@ -773,9 +771,7 @@ class FlashCausalLMBatch(Batch):
                 slot_indices[start_index:end_index] = (
                     batch.slot_indices + cumulative_slots
                 )
-                postfix_lengths_tensor[start_index:end_index] = (
-                    batch.postfix_lengths_tensor
-                )
+                input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
                 slots[slots_start_index:slots_end_index] = batch.slots

                 # Copy over adapter indices
@@ -793,9 +789,7 @@ class FlashCausalLMBatch(Batch):
                     batch.adapter_meta.adapter_segments,
                     batch.adapter_meta.segment_indices,
                 )
-                prefix_lengths_tensor[start_index:end_index] = (
-                    batch.prefix_lengths_tensor
-                )
+                cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor

             # Update
             cumulative_slots += len(batch.slots)
@@ -806,12 +800,12 @@ class FlashCausalLMBatch(Batch):

             prefilling_mask.extend(batch.prefilling_mask)
             block_tables.extend(batch.block_tables)
-            prefix_lengths.extend(batch.prefix_lengths)
+            cache_lengths.extend(batch.cache_lengths)
             all_input_ids.extend(batch.all_input_ids)
             prefix_ids.extend(batch.prefix_ids)

             prompt_lengths.extend(batch.prompt_lengths)
-            postfix_lengths.extend(batch.postfix_lengths)
+            input_lengths.extend(batch.input_lengths)
             prefix_offsets.extend(batch.prefix_offsets)
             read_offsets.extend(batch.read_offsets)

@@ -860,10 +854,10 @@ class FlashCausalLMBatch(Batch):
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
-            prefix_lengths=prefix_lengths,
-            prefix_lengths_tensor=prefix_lengths_tensor,
+            cache_lengths=cache_lengths,
+            cache_lengths_tensor=cache_lengths_tensor,
             slots=slots,
-            max_postfix_length=max_postfix_length,
+            max_input_length=max_input_length,
             max_current_length=max_current_length,
             prefilling=prefilling,
             prefilling_mask=prefilling_mask,
@@ -873,8 +867,8 @@ class FlashCausalLMBatch(Batch):
             prefill_logprob_tokens=prefill_logprob_tokens,
             prompt_lengths=prompt_lengths,
             prompt_lengths_tensor=prompt_lengths_tensor,
-            postfix_lengths=postfix_lengths,
-            postfix_lengths_tensor=postfix_lengths_tensor,
+            input_lengths=input_lengths,
+            input_lengths_tensor=input_lengths_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -918,30 +912,30 @@ class FlashCausalLMBatch(Batch):

         for i, (
             r,
-            prefix_length,
-            postfix_length,
+            cache_length,
+            input_length,
             prompt_length,
             request_prefilling,
             blocks,
         ) in enumerate(
             zip(
                 self.requests,
-                self.prefix_lengths,
-                self.postfix_lengths,
+                self.cache_lengths,
+                self.input_lengths,
                 self.prompt_lengths,
                 self.prefilling_mask,
                 self.block_tables,
             )
         ):
-            next_chunk_length = postfix_length
+            next_chunk_length = input_length
             # Position ids
             request_position_ids = torch.arange(
-                prefix_length, prefix_length + postfix_length, dtype=torch.int32
+                cache_length, cache_length + input_length, dtype=torch.int32
             )
             position_ids.append(request_position_ids)

             # Add cumulative lengths of all previous inputs
-            cu_seqlen_prefill.append(cumulative_length + postfix_length)
+            cu_seqlen_prefill.append(cumulative_length + input_length)

             if not r.slots:
                 request_slots = [
@@ -952,18 +946,18 @@ class FlashCausalLMBatch(Batch):
             else:
                 request_slots = r.slots

-            request_slots = request_slots[prefix_length:]
+            request_slots = request_slots[cache_length:]
             request_slot_indices = torch.arange(
                 cumulative_slot_tokens,
-                cumulative_slot_tokens + postfix_length,
+                cumulative_slot_tokens + input_length,
                 dtype=torch.int64,
             )

             # Create tensor to slice into the kv tensor in prefill
             if sliding_window is not None:
                 request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, postfix_length - sliding_window),
-                    cumulative_length + postfix_length,
+                    cumulative_length + max(0, input_length - sliding_window),
+                    cumulative_length + input_length,
                     dtype=torch.int64,
                 )

@@ -976,16 +970,14 @@ class FlashCausalLMBatch(Batch):
             if prefill_logprobs:
                 prefill_head_indices.append(request_position_ids + cumulative_length)
                 prefill_next_token_indices.append(
-                    prefill_out_cumulative_length + postfix_length - 1
+                    prefill_out_cumulative_length + input_length - 1
                 )
-                prefill_cu_outlens.append(
-                    prefill_out_cumulative_length + postfix_length
-                )
-                prefill_out_cumulative_length += postfix_length
+                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
+                prefill_out_cumulative_length += input_length
             else:
                 prefill_head_indices.append(
                     torch.tensor(
-                        [cumulative_length + postfix_length - 1],
+                        [cumulative_length + input_length - 1],
                         dtype=torch.int32,
                     )
                 )
@@ -1038,8 +1030,8 @@ class FlashCausalLMBatch(Batch):
         self.prefill_cache_indices = (
             prefill_cache_indices.to(device) if sliding_window is not None else None
         )
-        self.postfix_lengths_tensor = torch.tensor(
-            self.postfix_lengths, dtype=torch.int32, device=device
+        self.input_lengths_tensor = torch.tensor(
+            self.input_lengths, dtype=torch.int32, device=device
         )

         if all_prefill_logprobs:
@@ -1059,8 +1051,8 @@ class FlashCausalLMBatch(Batch):
         self.prefill_head_indices = prefill_head_indices
         self.prefill_next_token_indices = prefill_next_token_indices
         self.slots = torch.tensor(slots, dtype=torch.int64, device=device)
-        self.prefix_lengths_tensor = torch.tensor(
-            self.prefix_lengths, dtype=torch.int32, device=device
+        self.cache_lengths_tensor = torch.tensor(
+            self.cache_lengths, dtype=torch.int32, device=device
         )
         adapter_indices = torch.cat(adapter_indices_list).to(
             dtype=torch.int64, device=device
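With the new names, the per-request bookkeeping that appears later in this diff reads naturally: after a forward pass the tokens that were just processed move into the cache, and the next input is either the next prompt chunk (while still prefilling) or the freshly accepted tokens (once decoding). The sketch below is illustrative only, not the actual method; it just paraphrases that update.

# Illustrative only: length bookkeeping for one request after a forward pass,
# following the cache_length / input_length naming introduced by this commit.
def advance(cache_length, input_length, prompt_length, next_chunk_length, n_accepted_ids):
    cache_length += input_length          # processed tokens are now in the KV cache
    if cache_length < prompt_length:
        input_length = next_chunk_length  # still prefilling: next prompt chunk
    else:
        input_length = n_accepted_ids     # decoding: the newly accepted token(s)
    return cache_length, input_length

# A 12-token prompt prefilled in chunks of 5, 5, then 2:
print(advance(0, 5, 12, 5, 1))   # (5, 5)
print(advance(5, 5, 12, 2, 1))   # (10, 2)
print(advance(10, 2, 12, 0, 1))  # (12, 1) -> decode steps from here on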
@@ -1276,12 +1268,12 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
         slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-        postfix_lengths = [max_s] * bs
-        prefix_lengths = [0] * bs
-        postfix_lengths_tensor = (
+        input_lengths = [max_s] * bs
+        cache_lengths = [0] * bs
+        input_lengths_tensor = (
             torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
         )
-        prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
         block_tables = torch.arange(
             max_bt, dtype=torch.int32, device=self.device
         ).repeat(bs)
@@ -1290,8 +1282,8 @@ class FlashCausalLM(Model):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                postfix_lengths=postfix_lengths,
-                prefix_lengths=prefix_lengths,
+                input_lengths=input_lengths,
+                cache_lengths=cache_lengths,
             )
             from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
@@ -1319,8 +1311,8 @@ class FlashCausalLM(Model):
             "kv_cache": self.kv_cache,
             "block_tables": block_tables,
             "slots": slots,
-            "postfix_lengths": postfix_lengths_tensor,
-            "prefix_lengths": prefix_lengths_tensor,
+            "input_lengths": input_lengths_tensor,
+            "cache_lengths": cache_lengths_tensor,
             "state": state,
             "graph": graph,
         }
@@ -1330,13 +1322,13 @@ class FlashCausalLM(Model):
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=None,
-            postfix_lengths_tensor=postfix_lengths_tensor,
+            input_lengths_tensor=input_lengths_tensor,
             state=state,
-            prefix_lengths_tensor=prefix_lengths_tensor,
+            cache_lengths_tensor=cache_lengths_tensor,
         ):
             seqlen = Seqlen(
-                postfix_lengths=postfix_lengths_tensor,
-                prefix_lengths=prefix_lengths_tensor,
+                input_lengths=input_lengths_tensor,
+                cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=None,
                 max_q=1,
                 max_k=max_s,
@@ -1359,8 +1351,8 @@ class FlashCausalLM(Model):

         with torch.cuda.graph(graph, pool=MEM_POOL):
             seqlen = Seqlen(
-                postfix_lengths=postfix_lengths_tensor,
-                prefix_lengths=prefix_lengths_tensor,
+                input_lengths=input_lengths_tensor,
+                cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=None,
                 max_q=1,
                 max_k=max_s,
@@ -1517,8 +1509,8 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)

         # Dummy value, some models (starcoder2) don't accept `None`.
-        postfix_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        prefix_lengths_tensor = torch.zeros(
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        cache_lengths_tensor = torch.zeros(
             seqlen, dtype=torch.int32, device=self.device
         )
         cu_seqlen_prefill = torch.tensor(
@@ -1526,8 +1518,8 @@ class FlashCausalLM(Model):
         )
         max_s = seqlen
         seqlen = Seqlen(
-            postfix_lengths=postfix_lengths,
-            prefix_lengths=prefix_lengths_tensor,
+            input_lengths=input_lengths,
+            cache_lengths=cache_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
             max_q=1,
             max_k=seqlen,
@@ -1558,7 +1550,7 @@ class FlashCausalLM(Model):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            postfix_lengths = batch.postfix_lengths_tensor
+            input_lengths = batch.input_lengths_tensor
             max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

@@ -1575,11 +1567,11 @@ class FlashCausalLM(Model):
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            postfix_lengths = (
-                postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lengths_tensor = (
-                batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length)
+            cache_lengths_tensor = (
+                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)

             # Add Copy the block tables for all members
@@ -1600,8 +1592,8 @@ class FlashCausalLM(Model):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            postfix_lengths = batch.postfix_lengths_tensor
-            prefix_lengths_tensor = batch.prefix_lengths_tensor
+            input_lengths = batch.input_lengths_tensor
+            cache_lengths_tensor = batch.cache_lengths_tensor
             max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

@@ -1623,19 +1615,19 @@ class FlashCausalLM(Model):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                postfix_lengths=batch.postfix_lengths,
-                prefix_lengths=batch.prefix_lengths,
+                input_lengths=batch.input_lengths,
+                cache_lengths=batch.cache_lengths,
             )
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=cu_seqlen_prefill,
-            postfix_lengths_tensor=postfix_lengths,
-            prefix_lengths_tensor=prefix_lengths_tensor,
+            input_lengths_tensor=input_lengths,
+            cache_lengths_tensor=cache_lengths_tensor,
         ):
-            max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
+            max_k = (input_lengths + cache_lengths_tensor).max().item()
             seqlen = Seqlen(
-                postfix_lengths=postfix_lengths,
-                prefix_lengths=prefix_lengths_tensor,
+                input_lengths=input_lengths,
+                cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
                 max_q=max_s,
                 max_k=max_k,
@@ -1664,8 +1656,8 @@ class FlashCausalLM(Model):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                postfix_lengths=batch.postfix_lengths,
-                prefix_lengths=batch.prefix_lengths,
+                input_lengths=batch.input_lengths,
+                cache_lengths=batch.cache_lengths,
             )
             # assert block_tables.shape[0] >= slots.shape[0]
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
@@ -1678,18 +1670,18 @@ class FlashCausalLM(Model):
         # so it doesn't matter if we override it with bogus values.
         cuda_graph["slots"].fill_(0)
         cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["postfix_lengths"].zero_()
-        cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
-        cuda_graph["prefix_lengths"].zero_()
-        cuda_graph["prefix_lengths"][
-            : prefix_lengths_tensor.shape[0]
-        ] = prefix_lengths_tensor
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["cache_lengths"].zero_()
+        cuda_graph["cache_lengths"][
+            : cache_lengths_tensor.shape[0]
+        ] = cache_lengths_tensor

         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
             cu_seqlen_prefill=None,
-            postfix_lengths_tensor=cuda_graph["postfix_lengths"],
-            prefix_lengths_tensor=cuda_graph["prefix_lengths"],
+            input_lengths_tensor=cuda_graph["input_lengths"],
+            cache_lengths_tensor=cuda_graph["cache_lengths"],
             state=cuda_graph["state"],
         ):
             # Replay the graph
@@ -1775,13 +1767,13 @@ class FlashCausalLM(Model):
             batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
             # We reverse to prioritize older requests
             # zip() is not reversible so reverse the underlying lists instead
-            for prefix_length, postfix_length, prompt_length in zip(
-                reversed(batch.prefix_lengths),
-                reversed(batch.postfix_lengths),
+            for cache_length, input_length, prompt_length in zip(
+                reversed(batch.cache_lengths),
+                reversed(batch.input_lengths),
                 reversed(batch.prompt_lengths),
             ):
                 remaining_prefill_tokens = max(
-                    prompt_length - prefix_length - postfix_length, 0
+                    prompt_length - cache_length - input_length, 0
                )
                 if remaining_prefill_tokens > 0:
                     next_chunk_length = max(
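The budget loop above computes, per request, how much of the prompt is still waiting to be prefilled: whatever is neither in the cache nor in the current input. A small worked example with made-up numbers:

# Illustrative numbers only.
prompt_length = 1000    # total prompt tokens for the request
cache_length = 600      # already in the KV cache
input_length = 150      # being processed in the current chunk
remaining_prefill_tokens = max(prompt_length - cache_length - input_length, 0)  # 250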
@@ -1842,8 +1834,8 @@ class FlashCausalLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.prompt_lengths,
-            batch.prefix_lengths,
-            batch.postfix_lengths,
+            batch.cache_lengths,
+            batch.input_lengths,
             batch.all_input_ids,
             accepted_ids,
         )
@@ -1858,14 +1850,14 @@ class FlashCausalLM(Model):
         cumulative_length = 0
         for i, (
             prompt_length,
-            prefix_length,
-            postfix_length,
+            cache_length,
+            input_length,
             all_input_ids,
             n_accepted_ids,
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
-            end_index = cumulative_length + postfix_length
+            end_index = cumulative_length + input_length

             if prefill:
                 # Indexing metadata
@@ -1899,17 +1891,17 @@ class FlashCausalLM(Model):

             # Represent whether this request is still prefilling
             # If it is, the tokens we decoded should be ignored
-            accept_tokens = prefix_length + postfix_length >= prompt_length
+            accept_tokens = cache_length + input_length >= prompt_length

             if accept_tokens:
                 # Only save tokens if we are done prefilling for this request
                 for j in range(n_accepted_ids):
-                    batch.all_input_ids_tensor[
-                        i, prefix_length + postfix_length + j
-                    ] = next_input_ids[index]
+                    batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
+                        next_input_ids[index]
+                    )
                     index += 1

-            cumulative_length += postfix_length
+            cumulative_length += input_length

         # Update values
         # These values can be updated without a GPU -> CPU sync
@@ -1917,8 +1909,8 @@ class FlashCausalLM(Model):
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
-        batch.prefix_lengths_tensor += batch.postfix_lengths_tensor
-        batch.postfix_lengths_tensor = accepted_ids
+        batch.cache_lengths_tensor += batch.input_lengths_tensor
+        batch.input_lengths_tensor = accepted_ids
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices

@@ -1959,24 +1951,24 @@ class FlashCausalLM(Model):
             request_prefilling,
             next_token_id,
             all_input_ids,
-            prefix_length,
-            postfix_length,
+            cache_length,
+            input_length,
             next_chunk_length,
         ) in enumerate(
             zip(
                 batch.prefilling_mask,
                 next_token_ids,
                 batch.all_input_ids,
-                batch.prefix_lengths,
-                batch.postfix_lengths,
+                batch.cache_lengths,
+                batch.input_lengths,
                 next_chunk_lengths,
             )
         ):
             if request_prefilling:
-                next_prefix_length = prefix_length + postfix_length
+                next_cache_length = cache_length + input_length
                 # Get new prompt IDs to prefill
                 postfix_ids = all_input_ids[
-                    next_prefix_length : next_prefix_length + next_chunk_length
+                    next_cache_length : next_cache_length + next_chunk_length
                 ]
             else:
                 # This request is done prefilling, the new id is the one selected the sampling method
@@ -1996,8 +1988,8 @@ class FlashCausalLM(Model):
         iterator = zip(
             batch.requests,
             batch.prompt_lengths,
-            batch.prefix_lengths,
-            batch.postfix_lengths,
+            batch.cache_lengths,
+            batch.input_lengths,
             batch.prefix_offsets,
             batch.read_offsets,
             batch.stopping_criterias,
@@ -2012,15 +2004,15 @@ class FlashCausalLM(Model):
             batch_top_token_logprobs,
         )

-        # Reset max_postfix_length
-        batch.max_postfix_length = 0
+        # Reset max_input_length
+        batch.max_input_length = 0
         # For each member of the batch
         index = 0
         for i, (
             request,
             prompt_length,
-            prefix_length,
-            postfix_length,
+            cache_length,
+            input_length,
             prefix_offset,
             read_offset,
             stopping_criteria,
@@ -2084,9 +2076,9 @@ class FlashCausalLM(Model):
                 # Make sure that we do not stop as even though this request did not create a token, it is still
                 # processing
                 stopped = False
-                new_postfix_length = next_chunk_lengths[i]
+                new_input_length = next_chunk_lengths[i]
             else:
-                new_postfix_length = n_accepted_ids
+                new_input_length = n_accepted_ids
             # Append next token to all tokens
             next_token_texts = []
             left = 0
@@ -2198,14 +2190,12 @@ class FlashCausalLM(Model):
             )

             # Update values
-            current_prefix_length = prefix_length + postfix_length
-            batch.prefix_lengths[i] = current_prefix_length
-            current_postfix_length = new_postfix_length
-            batch.max_postfix_length = max(
-                batch.max_postfix_length, current_postfix_length
-            )
-            batch.postfix_lengths[i] = current_postfix_length
-            current_length = current_prefix_length + current_postfix_length
+            current_cache_length = cache_length + input_length
+            batch.cache_lengths[i] = current_cache_length
+            current_input_length = new_input_length
+            batch.max_input_length = max(batch.max_input_length, current_input_length)
+            batch.input_lengths[i] = current_input_length
+            current_length = current_cache_length + current_input_length
             batch.max_current_length = max(batch.max_current_length, current_length)

             batch.prefix_offsets[i] = prefix_offset
@@ -2235,8 +2225,8 @@ class FlashCausalLM(Model):
         *,
         block_tables: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
-        postfix_lengths_tensor: torch.Tensor,
-        prefix_lengths_tensor: torch.Tensor,
+        input_lengths_tensor: torch.Tensor,
+        cache_lengths_tensor: torch.Tensor,
         state: Optional[Any] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
@@ -2247,7 +2237,7 @@ class FlashCausalLM(Model):
             use_prefill_with_paged_kv_state,
         )

-        # has_prefix_lengths = any(prefix_length > 0 for prefix_length in prefix_lengths)
+        # has_cache_lengths = any(cache_length > 0 for cache_length in cache_lengths)

        if cu_seqlen_prefill is not None:
            return use_prefill_with_paged_kv_state(
@@ -2256,12 +2246,12 @@ class FlashCausalLM(Model):
                 ),
                 # block_tables=block_tables_to_ragged(
                 #     block_tables=block_tables,
-                #     postfix_lengths=postfix_lengths,
-                #     prefix_lengths=prefix_lengths,
+                #     input_lengths=input_lengths,
+                #     cache_lengths=cache_lengths,
                 # ),
                 block_tables=block_tables,
                 cu_seqlens=cu_seqlen_prefill,
-                input_lengths=postfix_lengths_tensor + prefix_lengths_tensor,
+                input_lengths=input_lengths_tensor + cache_lengths_tensor,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
@@ -2270,10 +2260,10 @@ class FlashCausalLM(Model):
                 window_left=self.sliding_window,
             )
         else:
-            assert postfix_lengths_tensor is not None
+            assert input_lengths_tensor is not None
             return use_decode_state(
                 state=state if state is not None else self.decode_state,
-                input_lengths=postfix_lengths_tensor + prefix_lengths_tensor,
+                input_lengths=input_lengths_tensor + cache_lengths_tensor,
                 block_tables=block_tables,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
@@ -2285,21 +2275,19 @@ class FlashCausalLM(Model):


 def block_tables_to_ragged(
-    *, block_tables: torch.Tensor, postfix_lengths: List[int], prefix_lengths: List[int]
+    *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int]
 ) -> torch.Tensor:
     """Convert block table to ragged format compatible with FlashInfer."""
-    assert len(postfix_lengths) == len(prefix_lengths)
+    assert len(input_lengths) == len(cache_lengths)

-    total_len = sum(postfix_lengths) + sum(prefix_lengths)
+    total_len = sum(input_lengths) + sum(cache_lengths)
     block_tables_ragged = torch.empty(
         total_len, dtype=torch.int32, device=block_tables.device
     )

     offset = 0
-    for i, (input_length, prefix_length) in enumerate(
-        zip(postfix_lengths, prefix_lengths)
-    ):
-        seq_len = prefix_length + input_length
+    for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)):
+        seq_len = cache_length + input_length
         block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
         offset += seq_len
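A hedged usage sketch of the renamed helper defined just above; the block-table contents and lengths are made up for illustration.

# Illustrative call to block_tables_to_ragged with the renamed keyword arguments.
import torch

block_tables = torch.arange(12, dtype=torch.int32).reshape(2, 6)  # dummy block ids
ragged = block_tables_to_ragged(
    block_tables=block_tables,
    input_lengths=[3, 2],   # tokens processed this step, per sequence
    cache_lengths=[1, 0],   # tokens already cached, per sequence
)
# ragged holds block_tables[0][:4] followed by block_tables[1][:2] (total length 6)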
@@ -285,7 +285,7 @@ class MllamaCausalLM(VlmCausalLM):
             max_k = (input_lengths + prefix_lens_tensor).max().item()
             seqlen = Seqlen(
                 input_lengths=input_lengths,
-                prefix_lengths=prefix_lens_tensor,
+                cache_lengths=prefix_lens_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
                 max_q=max_s,
                 max_k=max_k,
@@ -294,7 +294,7 @@ class VlmCausalLM(FlashCausalLM):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            postfix_lengths = batch.postfix_lengths_tensor
+            input_lengths = batch.input_lengths_tensor
             max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

@@ -311,11 +311,11 @@ class VlmCausalLM(FlashCausalLM):
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            postfix_lengths = (
-                postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lengths_tensor = (
-                batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length)
+            cache_lengths_tensor = (
+                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)

             # Add Copy the block tables for all members
@@ -336,8 +336,8 @@ class VlmCausalLM(FlashCausalLM):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            postfix_lengths = batch.postfix_lengths_tensor
-            prefix_lengths_tensor = batch.prefix_lengths_tensor
+            input_lengths = batch.input_lengths_tensor
+            cache_lengths_tensor = batch.cache_lengths_tensor
             max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

@@ -359,19 +359,19 @@ class VlmCausalLM(FlashCausalLM):
         if PREFIX_CACHING:
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                postfix_lengths=batch.postfix_lengths,
-                prefix_lengths=batch.prefix_lengths,
+                input_lengths=batch.input_lengths,
+                cache_lengths=batch.cache_lengths,
             )
         with self._forward_context(
             block_tables=block_tables,
             cu_seqlen_prefill=cu_seqlen_prefill,
-            postfix_lengths_tensor=postfix_lengths,
-            prefix_lengths_tensor=prefix_lengths_tensor,
+            input_lengths_tensor=input_lengths,
+            cache_lengths_tensor=cache_lengths_tensor,
         ):
-            max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
+            max_k = (input_lengths + cache_lengths_tensor).max().item()
             seqlen = Seqlen(
-                postfix_lengths=postfix_lengths,
-                prefix_lengths=prefix_lengths_tensor,
+                input_lengths=input_lengths,
+                cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
                 max_q=max_s,
                 max_k=max_k,
@@ -408,8 +408,8 @@ class VlmCausalLM(FlashCausalLM):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                postfix_lengths=batch.postfix_lengths,
-                prefix_lengths=batch.prefix_lengths,
+                input_lengths=batch.input_lengths,
+                cache_lengths=batch.cache_lengths,
             )
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
         else:
@@ -418,18 +418,18 @@ class VlmCausalLM(FlashCausalLM):
             ] = block_tables
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["postfix_lengths"].zero_()
-        cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
-        cuda_graph["prefix_lengths"].zero_()
-        cuda_graph["prefix_lengths"][
-            : prefix_lengths_tensor.shape[0]
-        ] = prefix_lengths_tensor
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["cache_lengths"].zero_()
+        cuda_graph["cache_lengths"][
+            : cache_lengths_tensor.shape[0]
+        ] = cache_lengths_tensor

         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
             cu_seqlen_prefill=None,
-            postfix_lengths_tensor=cuda_graph["postfix_lengths"],
-            prefix_lengths_tensor=cuda_graph["prefix_lengths"],
+            input_lengths_tensor=cuda_graph["input_lengths"],
+            cache_lengths_tensor=cuda_graph["cache_lengths"],
             state=cuda_graph["state"],
         ):
             # Replay the graph