Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
maybe patching vlms?

commit e4f9110e14 (parent 838756eb18)
@@ -294,7 +294,7 @@ class VlmCausalLM(FlashCausalLM):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices

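For readers tracking the rename, a minimal illustration (made-up values, not code from this commit) of how the renamed quantities are assumed to relate: prefix_lengths counts tokens already held in the KV cache, postfix_lengths counts the tokens processed in this forward pass, and their sum is what the old code called input_lengths.

    import torch

    # Hypothetical per-sequence lengths, for illustration only.
    prefix_lengths = torch.tensor([32, 0, 128], dtype=torch.int32)   # cached tokens
    postfix_lengths = torch.tensor([5, 17, 1], dtype=torch.int32)    # tokens in this pass

    # Total attended length per sequence, matching the
    # `input_lengths = postfix_lengths + prefix_lengths_tensor` line further down.
    input_lengths = postfix_lengths + prefix_lengths
    assert input_lengths.tolist() == [37, 17, 129]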
@@ -311,11 +311,11 @@ class VlmCausalLM(FlashCausalLM):
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            postfix_lengths = (
+                postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            prefix_lengths_tensor = (
+                batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)

             # Add Copy the block tables for all members
@@ -336,8 +336,8 @@ class VlmCausalLM(FlashCausalLM):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
+            prefix_lengths_tensor = batch.prefix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices

@@ -357,23 +357,23 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = input_lengths + prefix_lens_tensor
+            input_lengths = postfix_lengths + prefix_lengths_tensor
             if PREFIX_CACHING:
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
-                    input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    postfix_lengths=batch.postfix_lengths,
+                    prefix_lengths=batch.prefix_lengths,
                 )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
-                input_lengths_tensor=input_lengths,
-                prefix_lens_tensor=prefix_lens_tensor,
+                postfix_lengths_tensor=postfix_lengths,
+                prefix_lengths_tensor=prefix_lengths_tensor,
             ):
-                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
                 seqlen = Seqlen(
-                    input_lengths=input_lengths,
-                    prefix_lengths=prefix_lens_tensor,
+                    postfix_lengths=postfix_lengths,
+                    prefix_lengths=prefix_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
                     max_q=max_s,
                     max_k=max_k,
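The PREFIX_CACHING branch above passes the renamed keyword arguments to block_tables_to_ragged. As a rough sketch only (not the helper from this repository), a ragged layout can be built by concatenating each sequence's padded block-table row up to its total length, here assuming one table entry per token:

    from typing import List
    import torch

    def ragged_block_tables_sketch(
        block_tables: torch.Tensor,   # [num_seqs, max_entries], padded
        postfix_lengths: List[int],   # tokens processed this pass, per sequence
        prefix_lengths: List[int],    # cached tokens, per sequence
    ) -> torch.Tensor:
        """Illustrative only: flatten padded rows into one ragged 1D tensor."""
        assert len(postfix_lengths) == len(prefix_lengths)
        total = sum(postfix_lengths) + sum(prefix_lengths)
        out = torch.empty(total, dtype=block_tables.dtype, device=block_tables.device)
        offset = 0
        for i, (post, pre) in enumerate(zip(postfix_lengths, prefix_lengths)):
            seq_len = post + pre
            out[offset : offset + seq_len] = block_tables[i, :seq_len]
            offset += seq_len
        return out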
@@ -410,8 +410,8 @@ class VlmCausalLM(FlashCausalLM):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                input_lengths=batch.input_lengths,
-                prefix_lens=batch.prefix_lens,
+                postfix_lengths=batch.postfix_lengths,
+                prefix_lengths=batch.prefix_lengths,
             )
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
         else:
@@ -420,11 +420,20 @@ class VlmCausalLM(FlashCausalLM):
             ] = block_tables
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-            input_lengths + prefix_lens_tensor
-        )
+        cuda_graph["postfix_lengths"].zero_()
+        cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
+        cuda_graph["prefix_lengths"].zero_()
+        cuda_graph["prefix_lengths"][
+            : prefix_lengths_tensor.shape[0]
+        ] = prefix_lengths_tensor

-        # Replay the graph
-        cuda_graph["graph"].replay()
+        with self._forward_context(
+            block_tables=cuda_graph["block_tables"],
+            cu_seqlen_prefill=None,
+            postfix_lengths_tensor=cuda_graph["postfix_lengths"],
+            prefix_lengths_tensor=cuda_graph["prefix_lengths"],
+            state=cuda_graph["state"],
+        ):
+            # Replay the graph
+            cuda_graph["graph"].replay()

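The added lines copy the batch's prefix/postfix lengths into the graph's static buffers and then replay inside the same _forward_context used on the eager path. The general static-buffer-then-replay pattern, sketched with the standard torch.cuda graph API (a generic example, not this repository's code):

    import torch

    # Assumes a CUDA device is available.
    static_input = torch.zeros(8, device="cuda")
    static_output = torch.zeros(8, device="cuda")

    # Warm up on a side stream before capture, as recommended by PyTorch.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            static_output.copy_(static_input * 2)
    torch.cuda.current_stream().wait_stream(s)

    # Capture the work once; it must only touch the static tensors.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output.copy_(static_input * 2)

    # Per step: refresh the static inputs in place, then replay.
    static_input.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
    graph.replay()  # analogous to cuda_graph["graph"].replay() above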