maybe patching vlms?

OlivierDehaene 2024-09-25 14:54:59 +02:00
parent 838756eb18
commit e4f9110e14


@@ -294,7 +294,7 @@ class VlmCausalLM(FlashCausalLM):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
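
The rename is the heart of the commit: the old `input_lengths` did double duty, meaning "tokens fed through this forward pass" in some places and "total tokens visible to attention" in others (the old code even reassigned it as `input_lengths = input_lengths + prefix_lens_tensor` further down). Splitting it into `postfix_lengths` (new tokens) and `prefix_lengths` (tokens already in the KV cache) makes the invariant explicit. A minimal sketch of that bookkeeping, with made-up values:

```python
import torch

# Per-sequence length bookkeeping implied by the rename (values are
# illustrative). The prefix is what already sits in the KV cache; the
# postfix is what this forward pass actually feeds through the model.
prefix_lengths = torch.tensor([0, 48, 112])   # cached tokens
postfix_lengths = torch.tensor([64, 16, 1])   # tokens processed this step

# What the old code called `input_lengths` after adding the prefix:
# the total number of keys each sequence's queries can attend to.
input_lengths = prefix_lengths + postfix_lengths
assert input_lengths.tolist() == [64, 64, 113]
```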
@@ -311,11 +311,11 @@ class VlmCausalLM(FlashCausalLM):
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
-            ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
-            ).reshape(-1)
+            postfix_lengths = (
+                postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+            prefix_lengths_tensor = (
+                batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)

             # Add Copy the block tables for all members
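
This hunk is the speculative-decoding path: every sequence is duplicated `new_length` times so one forward pass can score each speculative token. The `unsqueeze(-1).expand(B, new_length) + arange_int` pattern gives copy i a postfix length i tokens longer than the base, since it sits i positions deeper into the sequence; note the prefix expansion has no `+ arange_int`, because the cached prefix is identical for every copy. A self-contained sketch of the pattern, with made-up sizes:

```python
import torch

# Sketch of the expand-plus-arange pattern above (sizes are made up).
B, new_length = 2, 3
postfix_lengths = torch.tensor([5, 9])
arange_int = torch.arange(new_length, dtype=torch.int32)

# Copy i of each sequence is i tokens longer than the base length.
expanded = (
    postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
assert expanded.tolist() == [5, 6, 7, 9, 10, 11]
```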
@@ -336,8 +336,8 @@ class VlmCausalLM(FlashCausalLM):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
+            prefix_lengths_tensor = batch.prefix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
@@ -357,23 +357,23 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = input_lengths + prefix_lens_tensor
+            input_lengths = postfix_lengths + prefix_lengths_tensor
             if PREFIX_CACHING:
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
-                    input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    postfix_lengths=batch.postfix_lengths,
+                    prefix_lengths=batch.prefix_lengths,
                 )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
-                input_lengths_tensor=input_lengths,
-                prefix_lens_tensor=prefix_lens_tensor,
+                postfix_lengths_tensor=postfix_lengths,
+                prefix_lengths_tensor=prefix_lengths_tensor,
             ):
-                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
                 seqlen = Seqlen(
-                    input_lengths=input_lengths,
-                    prefix_lengths=prefix_lens_tensor,
+                    postfix_lengths=postfix_lengths,
+                    prefix_lengths=prefix_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
                     max_q=max_s,
                     max_k=max_k,
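
Under prefix caching, queries span only the postfix while keys span prefix plus postfix, which is why `max_q` (taken from `max_s`) and `max_k` are tracked separately and `max_k` is recomputed from the two length tensors. A simplified stand-in for the `Seqlen` container, to illustrate the relationship (not the real class):

```python
from dataclasses import dataclass
from typing import Optional
import torch

# Simplified stand-in for the Seqlen metadata built above; it only shows
# why max_q and max_k differ under prefix caching. Not TGI's real class.
@dataclass
class SeqlenSketch:
    postfix_lengths: torch.Tensor   # query span per sequence (new tokens)
    prefix_lengths: torch.Tensor    # cached tokens per sequence
    cu_seqlen_q: Optional[torch.Tensor]
    max_q: int                      # longest query span
    max_k: int                      # longest key span (prefix + postfix)

postfix_lengths = torch.tensor([64, 16])
prefix_lengths = torch.tensor([0, 112])
max_k = (postfix_lengths + prefix_lengths).max().item()  # 128: keys include the cache
seqlen = SeqlenSketch(postfix_lengths, prefix_lengths, None, max_q=64, max_k=max_k)
```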
@@ -410,8 +410,8 @@ class VlmCausalLM(FlashCausalLM):
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
-                    input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    postfix_lengths=batch.postfix_lengths,
+                    prefix_lengths=batch.prefix_lengths,
                 )
                 cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
             else:
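
`block_tables_to_ragged` now receives the split lengths as well. For the flashinfer backend, the padded `(batch, max_blocks)` table must be flattened into a single ragged tensor holding, per sequence, only the blocks its prefix plus postfix actually covers. A sketch of what such a conversion could look like; the block size and ceil-rounding are assumptions for illustration, not the real helper:

```python
import torch

BLOCK_SIZE = 16  # assumed paged-KV block size, for illustration only

def block_tables_to_ragged_sketch(block_tables, postfix_lengths, prefix_lengths):
    # Keep, per sequence, only the blocks covering prefix + postfix tokens,
    # then concatenate everything into one flat ("ragged") tensor.
    ragged = []
    for row, postfix, prefix in zip(block_tables, postfix_lengths, prefix_lengths):
        total = int(prefix) + int(postfix)
        n_blocks = (total + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceil division
        ragged.append(row[:n_blocks])
    return torch.cat(ragged)

# Two sequences padded to 4 blocks each; only the used blocks survive.
tables = torch.tensor([[3, 7, 0, 0], [5, 6, 9, 0]])
flat = block_tables_to_ragged_sketch(tables, torch.tensor([1, 1]), torch.tensor([20, 40]))
assert flat.tolist() == [3, 7, 5, 6, 9]
```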
@@ -420,13 +420,22 @@ class VlmCausalLM(FlashCausalLM):
                 ] = block_tables
             cuda_graph["slots"].fill_(-1)
             cuda_graph["slots"][: slots.shape[0]] = slots
-            cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-                input_lengths + prefix_lens_tensor
-            )
+            cuda_graph["postfix_lengths"].zero_()
+            cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
+            cuda_graph["prefix_lengths"].zero_()
+            cuda_graph["prefix_lengths"][
+                : prefix_lengths_tensor.shape[0]
+            ] = prefix_lengths_tensor

-            # Replay the graph
-            cuda_graph["graph"].replay()
+            with self._forward_context(
+                block_tables=cuda_graph["block_tables"],
+                cu_seqlen_prefill=None,
+                postfix_lengths_tensor=cuda_graph["postfix_lengths"],
+                prefix_lengths_tensor=cuda_graph["prefix_lengths"],
+                state=cuda_graph["state"],
+            ):
+                # Replay the graph
+                cuda_graph["graph"].replay()

             # Slice output to the correct shape
             speculative_logits = (
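
The behavioral change in this last hunk, beyond the renames, is that the CUDA-graph replay now runs inside `self._forward_context(...)` with the graph's own `state`, instead of bare. A captured graph replays fixed kernels, so any per-step attention metadata derived from the prefix/postfix buffers updated just above has to be refreshed before those kernels execute. A hedged sketch of the general pattern, with made-up names rather than TGI's API:

```python
from contextlib import contextmanager
import torch

class AttentionState:
    """Stand-in for a backend state object (e.g. a flashinfer wrapper)."""
    def plan(self, input_lengths: torch.Tensor) -> None:
        # A real backend would rebuild its scheduling metadata here so the
        # captured kernels read up-to-date lengths on the next replay.
        self.input_lengths = input_lengths

@contextmanager
def forward_context(postfix_lengths, prefix_lengths, state):
    # Refresh the backend's view of sequence lengths before any kernel,
    # eager or replayed, runs under this context.
    state.plan(postfix_lengths + prefix_lengths)
    yield

state = AttentionState()
with forward_context(torch.tensor([1, 1]), torch.tensor([64, 128]), state):
    pass  # cuda_graph["graph"].replay() would go here in the real code
```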