diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index 0fb05a98..faa57c11 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -316,10 +316,15 @@ impl State {
                         + self.speculate
                         - 1;

-                    match block_allocator
-                        .allocate(tokens, entry.request.input_ids.clone())
-                        .await
-                    {
+                    // If users want the prefill logprobs, we cannot reuse the cache.
+                    // So no input_ids for the radix tree.
+                    let input_ids = if entry.request.decoder_input_details {
+                        None
+                    } else {
+                        entry.request.input_ids.clone()
+                    };
+
+                    match block_allocator.allocate(tokens, input_ids).await {
                         None => {
                             // Entry is over budget
                             // Add it back to the front
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 3285a527..dd4203e0 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -468,7 +468,6 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = block_tables_tensor.to(device)
         prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)

-        log_master(logger.info, f"Block tables {block_tables}")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -1915,7 +1914,6 @@ class FlashCausalLM(Model):
         # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)

         if cu_seqlen_prefill is not None:
-            log_master(logger.info, f"Prefix lens {prefix_lens}")
             return use_prefill_with_paged_kv_state(
                 state=(
                     state if state is not None else self.prefill_with_paged_kv_state
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 7de54aa4..2ed1a119 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
+    block_tables_to_ragged,
 )
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
@@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM):
         trust_remote_code: bool,
         **kwargs,
     ):
+        if PREFIX_CACHING:
+            raise NotImplementedError("Vlm do not work with prefix caching yet")
         if processor_kwargs is None:
             processor_kwargs = {}
         self.processor = processor_class.from_pretrained(
@@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)

             # Add Copy the block tables for all members
             block_tables = (
@@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices

@@ -349,43 +357,68 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = Seqlen(input_lengths=input_lengths)
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
                 block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-                pixel_values=batch.pixel_values,
-                pixel_attention_mask=batch.pixel_attention_mask,
-                image_sizes=batch.image_sizes,
-            )
-            if batch.prefill_cache_indices is not None:
-                batch.prefill_cache_indices = None
-            if batch.pixel_values is not None:
-                batch.pixel_values = None
-            if batch.pixel_attention_mask is not None:
-                batch.pixel_attention_mask = None
-            if batch.image_sizes is not None:
-                batch.image_sizes = None
-            return logits, speculative_logits
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths=batch.input_lengths,
+                input_lengths_tensor=input_lengths,
+                prefix_lens=batch.prefix_lens,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
+                input_lengths = Seqlen(input_lengths=input_lengths)
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                    pixel_values=batch.pixel_values,
+                    pixel_attention_mask=batch.pixel_attention_mask,
+                    image_sizes=batch.image_sizes,
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                if batch.pixel_values is not None:
+                    batch.pixel_values = None
+                if batch.pixel_attention_mask is not None:
+                    batch.pixel_attention_mask = None
+                if batch.image_sizes is not None:
+                    batch.image_sizes = None
+                return logits, speculative_logits

         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
         cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-        cuda_graph["block_tables"][
-            : block_tables.shape[0], : block_tables.shape[1]
-        ] = block_tables
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )

         # Replay the graph
         cuda_graph["graph"].replay()