- Remove logs
- Disable VLMs (they do not work)
- Disable prefix caching when the user wants prefill logprobs.
Nicolas Patry 2024-08-14 18:26:29 +02:00
parent 97c504136c
commit 95155a212b
3 changed files with 70 additions and 34 deletions

File 1 of 3:

@@ -316,10 +316,15 @@ impl State {
                         + self.speculate
                         - 1;

-                    match block_allocator
-                        .allocate(tokens, entry.request.input_ids.clone())
-                        .await
-                    {
+                    // If users wants the prefill logprobs, we cannot reuse the cache.
+                    // So no input_ids for the radix tree.
+                    let input_ids = if entry.request.decoder_input_details {
+                        None
+                    } else {
+                        entry.request.input_ids.clone()
+                    };
+
+                    match block_allocator.allocate(tokens, input_ids).await {
                         None => {
                             // Entry is over budget
                             // Add it back to the front
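The router hunk above is where the third bullet of the commit message lands: when prefill logprobs are requested (decoder_input_details), no input_ids are offered to the radix tree, so nothing is reused from the prefix cache and every prompt token is actually recomputed and scored. A rough Python rendering of that decision, with hypothetical names, purely illustrative:

from typing import List, Optional


def prefix_cache_input_ids(
    input_ids: List[int], wants_prefill_logprobs: bool
) -> Optional[List[int]]:
    """Illustrative only: mirror the router's gating of the radix/prefix cache.

    If the user asks for prefill logprobs, every prompt token must be run
    through the model so its logprob exists, so nothing is handed to the
    prefix cache; otherwise the prompt tokens may be shared.
    """
    if wants_prefill_logprobs:
        return None  # no prefix reuse: all prefill tokens must be scored
    return input_ids


assert prefix_cache_input_ids([1, 2, 3], wants_prefill_logprobs=True) is None
assert prefix_cache_input_ids([1, 2, 3], wants_prefill_logprobs=False) == [1, 2, 3]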

File 2 of 3:

@@ -468,7 +468,6 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = block_tables_tensor.to(device)
         prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)

-        log_master(logger.info, f"Block tables {block_tables}")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -1915,7 +1914,6 @@ class FlashCausalLM(Model):
         # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)

         if cu_seqlen_prefill is not None:
-            log_master(logger.info, f"Prefix lens {prefix_lens}")
             return use_prefill_with_paged_kv_state(
                 state=(
                     state if state is not None else self.prefill_with_paged_kv_state
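Both removed lines were debug logging added while bringing up prefix caching ("Remove logs" in the commit message). For context, log_master is understood here to emit only from the master rank so a sharded deployment does not repeat every message once per shard; a minimal sketch of a helper with that shape (an assumption about its internals, not the project's exact code):

import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sketch")

# Assumption: each process of a sharded launch knows its rank (env var here),
# and only rank 0 actually emits the message.
RANK = int(os.getenv("RANK", "0"))


def log_master(log, msg: str):
    if RANK == 0:
        log(msg)


log_master(logger.info, "printed once, on the master rank only")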

File 3 of 3:

@@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
+    block_tables_to_ragged,
 )
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
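Two imports are added: the PREFIX_CACHING and ATTENTION flags used below, and block_tables_to_ragged, which flattens the padded per-request block-table matrix into the ragged 1-D layout the flashinfer path works with. A hedged sketch of that kind of flattening (illustrative shapes and slicing, not the function's actual implementation):

from typing import List

import torch


def flatten_block_tables(
    block_tables: torch.Tensor,  # [num_requests, max_blocks], zero-padded
    input_lengths: List[int],
    prefix_lens: List[int],
) -> torch.Tensor:
    """Concatenate each request's used block ids into one ragged 1-D tensor."""
    ragged = []
    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
        used = prefix_len + input_length  # entries actually in use for request i
        ragged.append(block_tables[i, :used])
    return torch.cat(ragged)


# Two requests sharing one padded table: 3 and 2 used entries, no cached prefix.
padded = torch.tensor([[7, 8, 9, 0], [4, 5, 0, 0]], dtype=torch.int32)
print(flatten_block_tables(padded, input_lengths=[3, 2], prefix_lens=[0, 0]).tolist())
# [7, 8, 9, 4, 5]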
@@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM):
         trust_remote_code: bool,
         **kwargs,
     ):
+        if PREFIX_CACHING:
+            raise NotImplementedError("Vlm do not work with prefix caching yet")
         if processor_kwargs is None:
             processor_kwargs = {}
         self.processor = processor_class.from_pretrained(
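The new guard fails fast at construction time instead of producing broken generations later, which is what "Disable VLMs (they do not work)" refers to. A minimal sketch of the pattern, assuming the flag is just a boolean derived from an environment variable (the real flag comes from text_generation_server.models.globals):

import os

# Assumed source of the flag for this sketch only.
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0") == "1"


class VlmModel:
    """Stand-in for VlmCausalLM.__init__'s early capability check."""

    def __init__(self):
        if PREFIX_CACHING:
            # Refuse the unsupported combination up front, with a clear error,
            # rather than failing in a harder-to-debug way during inference.
            raise NotImplementedError("Vlm do not work with prefix caching yet")


if __name__ == "__main__":
    try:
        VlmModel()
        print("constructed without prefix caching")
    except NotImplementedError as e:
        print(f"refused: {e}")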
@@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)

             # Add Copy the block tables for all members
             block_tables = (
@@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
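In the speculative-decoding branch each request expands into new_length rows, so the new prefix_lens_tensor is repeated per row exactly like input_lengths already is; in the non-speculative branch it is taken from the batch unchanged. A small shape check of the unsqueeze/expand/reshape pattern used above:

import torch

B, new_length = 2, 3  # 2 requests, speculative length + 1 = 3 rows per request
prefix_lens_tensor = torch.tensor([5, 0], dtype=torch.int32)

# Repeat each request's prefix length once per speculative position,
# then flatten back to one value per row of the expanded batch.
expanded = (
    prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)

print(expanded.tolist())  # [5, 5, 5, 0, 0, 0]
print(expanded.shape)     # torch.Size([6]) == B * new_length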
@@ -349,6 +357,21 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = Seqlen(input_lengths=input_lengths)
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths=batch.input_lengths,
+                input_lengths_tensor=input_lengths,
+                prefix_lens=batch.prefix_lens,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
+                input_lengths = Seqlen(input_lengths=input_lengths)
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
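Before calling the model, the lengths handed to the attention metadata become input length plus cached prefix length, and the call is wrapped in self._forward_context, which on the flashinfer backend prepares the prefill/decode state and is otherwise expected to act as a no-op context. A hedged sketch of that wrapper pattern (stand-in names, not the class's real method):

from contextlib import contextmanager

import torch

ATTENTION = "paged"  # assumption for the sketch; "flashinfer" would trigger the setup


@contextmanager
def forward_context(
    *, input_lengths_tensor: torch.Tensor, prefix_lens_tensor: torch.Tensor
):
    """Illustrative stand-in for self._forward_context."""
    if ATTENTION == "flashinfer":
        # The real method would (re)build flashinfer's prefill/decode state here,
        # sized by the total lengths (cached prefix + new input).
        pass
    yield


input_lengths = torch.tensor([4, 7], dtype=torch.int32)
prefix_lens_tensor = torch.tensor([16, 0], dtype=torch.int32)
total_lengths = input_lengths + prefix_lens_tensor  # what attention actually sees

with forward_context(
    input_lengths_tensor=total_lengths, prefix_lens_tensor=prefix_lens_tensor
):
    print(total_lengths.tolist())  # [20, 7]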
@@ -379,13 +402,23 @@ class VlmCausalLM(FlashCausalLM):
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
         cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-        cuda_graph["block_tables"][
-            : block_tables.shape[0], : block_tables.shape[1]
-        ] = block_tables
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )

         # Replay the graph
         cuda_graph["graph"].replay()