Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00
maybe patching vlms?
commit e4f9110e14
parent 838756eb18
@@ -294,7 +294,7 @@ class VlmCausalLM(FlashCausalLM):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
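Note: the rename makes the length bookkeeping explicit. With prefix caching, a sequence's total length splits into the prefix already held in the KV cache and the postfix actually computed in this forward pass; the old name input_lengths covered only the latter. A minimal sketch of the invariant, with made-up numbers:

    import torch

    # Hypothetical batch of three sequences.
    prefix_lengths = torch.tensor([0, 16, 32])   # tokens already in the KV cache
    postfix_lengths = torch.tensor([7, 4, 1])    # tokens computed this forward pass
    total_lengths = prefix_lengths + postfix_lengths
    assert (total_lengths == torch.tensor([7, 20, 33])).all()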
@@ -311,11 +311,11 @@ class VlmCausalLM(FlashCausalLM):
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            postfix_lengths = (
+                postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            prefix_lengths_tensor = (
+                batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)

             # Add Copy the block tables for all members
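Note: the unsqueeze/expand + arange pattern in this hunk fans each per-sequence entry out to new_length rows (one per speculative candidate), offsetting each row by its position before flattening. A standalone sketch with made-up values, assuming B = 2 sequences and new_length = 3 candidates:

    import torch

    B, new_length = 2, 3
    postfix_lengths = torch.tensor([5, 9])                    # one entry per sequence
    arange_int = torch.arange(new_length, dtype=torch.int64)  # offsets 0..new_length-1

    # (B,) -> (B, new_length) with per-candidate offsets, flattened to (B * new_length,)
    expanded = (
        postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
    ).view(-1)
    print(expanded)  # tensor([ 5,  6,  7,  9, 10, 11])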
@@ -336,8 +336,8 @@ class VlmCausalLM(FlashCausalLM):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
+            prefix_lengths_tensor = batch.prefix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
@@ -357,23 +357,23 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = input_lengths + prefix_lens_tensor
+            input_lengths = postfix_lengths + prefix_lengths_tensor
             if PREFIX_CACHING:
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
-                    input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    postfix_lengths=batch.postfix_lengths,
+                    prefix_lengths=batch.prefix_lengths,
                 )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
-                input_lengths_tensor=input_lengths,
-                prefix_lens_tensor=prefix_lens_tensor,
+                postfix_lengths_tensor=postfix_lengths,
+                prefix_lengths_tensor=prefix_lengths_tensor,
             ):
-                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
                 seqlen = Seqlen(
-                    input_lengths=input_lengths,
-                    prefix_lengths=prefix_lens_tensor,
+                    postfix_lengths=postfix_lengths,
+                    prefix_lengths=prefix_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
                     max_q=max_s,
                     max_k=max_k,
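Note: for reference, a hypothetical dataclass mirroring the keyword arguments at the new Seqlen call site. The real Seqlen is defined in TGI's attention layer; the field types below are assumptions inferred from this diff, not the actual definition:

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class Seqlen:
        # Hypothetical mirror of the call site above, not TGI's definition.
        postfix_lengths: torch.Tensor        # tokens computed this pass, per sequence
        prefix_lengths: torch.Tensor         # cached tokens, per sequence
        cu_seqlen_q: Optional[torch.Tensor]  # cumulative query lengths (prefill only)
        max_q: int                           # longest query in the batch
        max_k: int                           # longest key: max(prefix + postfix)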
@@ -410,8 +410,8 @@ class VlmCausalLM(FlashCausalLM):
         if ATTENTION == "flashinfer":
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
-                input_lengths=batch.input_lengths,
-                prefix_lens=batch.prefix_lens,
+                postfix_lengths=batch.postfix_lengths,
+                prefix_lengths=batch.prefix_lengths,
             )
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
         else:
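Note: block_tables_to_ragged collapses the padded (batch, max_blocks) block table into one ragged 1-D tensor, keeping only the entries each sequence actually uses, which is the layout the flashinfer path consumes. A simplified sketch under that assumption (not the actual TGI helper):

    import torch

    def block_tables_to_ragged_sketch(block_tables, postfix_lengths, prefix_lengths):
        # Keep only the first (prefix + postfix) entries of each padded row
        # and concatenate them into a single ragged tensor.
        total = sum(p + q for p, q in zip(prefix_lengths, postfix_lengths))
        out = torch.empty(total, dtype=torch.int32)
        offset = 0
        for i, (p, q) in enumerate(zip(prefix_lengths, postfix_lengths)):
            n = p + q
            out[offset : offset + n] = block_tables[i, :n]
            offset += n
        return out

    # Hypothetical usage: two sequences, padded table of width 4.
    bt = torch.tensor([[3, 7, 0, 0], [1, 4, 9, 0]], dtype=torch.int32)
    print(block_tables_to_ragged_sketch(bt, postfix_lengths=[1, 2], prefix_lengths=[1, 1]))
    # tensor([3, 7, 1, 4, 9], dtype=torch.int32)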
@@ -420,11 +420,20 @@ class VlmCausalLM(FlashCausalLM):
             ] = block_tables
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-            input_lengths + prefix_lens_tensor
-        )
+        cuda_graph["postfix_lengths"].zero_()
+        cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
+        cuda_graph["prefix_lengths"].zero_()
+        cuda_graph["prefix_lengths"][
+            : prefix_lengths_tensor.shape[0]
+        ] = prefix_lengths_tensor

         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
             cu_seqlen_prefill=None,
+            postfix_lengths_tensor=cuda_graph["postfix_lengths"],
+            prefix_lengths_tensor=cuda_graph["prefix_lengths"],
             state=cuda_graph["state"],
         ):
             # Replay the graph
             cuda_graph["graph"].replay()
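Note: the writes into cuda_graph["..."] above follow the standard CUDA-graph replay pattern: the graph is captured once against fixed-size static buffers, and each step zeroes and refills those buffers in place before replaying the recorded kernels. A generic sketch of the pattern (illustrative names, requires a CUDA device):

    import torch

    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")

    # Capture the kernel sequence once against the static buffers.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out.copy_(static_in * 2)

    # Per step: zero and refill the padded static buffers, then replay.
    new_batch = torch.arange(3, dtype=torch.float32, device="cuda")
    static_in.zero_()
    static_in[: new_batch.shape[0]] = new_batch
    graph.replay()
    print(static_out[:3])  # tensor([0., 2., 4.], device='cuda:0')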