diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 7f7d2e4d..937811d7 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -294,7 +294,7 @@ class VlmCausalLM(FlashCausalLM):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
@@ -311,11 +311,11 @@ class VlmCausalLM(FlashCausalLM):
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            postfix_lengths = (
+                postfix_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            prefix_lengths_tensor = (
+                batch.prefix_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)
 
             # Add Copy the block tables for all members
@@ -336,8 +336,8 @@ class VlmCausalLM(FlashCausalLM):
             kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
+            postfix_lengths = batch.postfix_lengths_tensor
+            prefix_lengths_tensor = batch.prefix_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
@@ -357,23 +357,23 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = input_lengths + prefix_lens_tensor
+            input_lengths = postfix_lengths + prefix_lengths_tensor
             if PREFIX_CACHING:
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
-                    input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    postfix_lengths=batch.postfix_lengths,
+                    prefix_lengths=batch.prefix_lengths,
                 )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
-                input_lengths_tensor=input_lengths,
-                prefix_lens_tensor=prefix_lens_tensor,
+                postfix_lengths_tensor=postfix_lengths,
+                prefix_lengths_tensor=prefix_lengths_tensor,
             ):
-                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
                 seqlen = Seqlen(
-                    input_lengths=input_lengths,
-                    prefix_lengths=prefix_lens_tensor,
+                    postfix_lengths=postfix_lengths,
+                    prefix_lengths=prefix_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
                     max_q=max_s,
                     max_k=max_k,
@@ -410,8 +410,8 @@ class VlmCausalLM(FlashCausalLM):
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
-                    input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    postfix_lengths=batch.postfix_lengths,
+                    prefix_lengths=batch.prefix_lengths,
                 )
                 cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
             else:
@@ -420,13 +420,22 @@ class VlmCausalLM(FlashCausalLM):
                 ] = block_tables
             cuda_graph["slots"].fill_(-1)
             cuda_graph["slots"][: slots.shape[0]] = slots
-            cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-                input_lengths + prefix_lens_tensor
-            )
+            cuda_graph["postfix_lengths"].zero_()
+            cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
+            cuda_graph["prefix_lengths"].zero_()
+            cuda_graph["prefix_lengths"][
+                : prefix_lengths_tensor.shape[0]
+            ] = prefix_lengths_tensor
 
-            # Replay the graph
-            cuda_graph["graph"].replay()
+            with self._forward_context(
+                block_tables=cuda_graph["block_tables"],
+                cu_seqlen_prefill=None,
+                postfix_lengths_tensor=cuda_graph["postfix_lengths"],
+                prefix_lengths_tensor=cuda_graph["prefix_lengths"],
+                state=cuda_graph["state"],
+            ):
+                # Replay the graph
+                cuda_graph["graph"].replay()
 
         # Slice output to the correct shape
         speculative_logits = (
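
Note on the rename: as the diff suggests, `prefix_lengths` appears to track the portion of each sequence already held in the KV cache (prefix caching), while `postfix_lengths` covers the tokens processed in the current pass; the CUDA-graph replay is also now wrapped in the same `_forward_context` used on the eager path. The sketch below is illustrative only, with made-up tensor values, and simply demonstrates the relationship the diff relies on when computing `max_k`: the attention kernels must cover the full cached-plus-new length.

    import torch

    # Hypothetical per-request lengths for a batch of three sequences.
    prefix_lengths_tensor = torch.tensor([128, 0, 64], dtype=torch.int32)  # already in the KV cache
    postfix_lengths = torch.tensor([16, 512, 1], dtype=torch.int32)        # new tokens this step

    # Mirrors `max_k = (postfix_lengths + prefix_lengths_tensor).max().item()` from the diff:
    # the longest total (cached + new) sequence bounds the key/value length.
    max_k = (postfix_lengths + prefix_lengths_tensor).max().item()
    print(max_k)  # 512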