Mirror of https://github.com/huggingface/text-generation-inference.git
Fixup:
- Remove logs
- Disable VLMs (they do not work)
- Disable prefix caching when the user wants prefill logprobs
Commit 95155a212b (parent 97c504136c)
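
On the third bullet: prefill logprobs are produced by the prefill forward pass, and tokens served from a cached prefix never go through that pass, so no logprobs exist for them. A toy sketch of the consequence (illustrative numbers only, not TGI code):

# Which prompt positions can receive a prefill logprob?
prompt_len = 4          # e.g. ["The", "quick", "brown", "fox"]
cached_prefix_len = 2   # "The quick" reused from the prefix cache

# Only positions that are actually prefilled produce logits, hence logprobs.
positions_with_logprobs = list(range(cached_prefix_len, prompt_len))
print(positions_with_logprobs)  # [2, 3]

# Returning logprobs for *every* prompt token therefore requires
# cached_prefix_len == 0 for that request, which is what this commit enforces.
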
@@ -316,10 +316,15 @@ impl State {
                     + self.speculate
                     - 1;

-                match block_allocator
-                    .allocate(tokens, entry.request.input_ids.clone())
-                    .await
-                {
+                // If users wants the prefill logprobs, we cannot reuse the cache.
+                // So no input_ids for the radix tree.
+                let input_ids = if entry.request.decoder_input_details {
+                    None
+                } else {
+                    entry.request.input_ids.clone()
+                };
+
+                match block_allocator.allocate(tokens, input_ids).await {
                     None => {
                         // Entry is over budget
                         // Add it back to the front
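
The effect of passing input_ids = None to the allocator, sketched in Python with a hypothetical prefix cache (the router's real radix-tree allocator is more involved; the names below are made up for illustration): with no token ids to look up, nothing can be matched, the whole prompt is prefilled, and every prompt token gets a logprob.

from typing import List, Optional

class ToyPrefixCache:
    """Hypothetical radix-style prefix cache, for illustration only."""

    def __init__(self) -> None:
        self.cached_prompts: List[List[int]] = []

    def insert(self, input_ids: List[int]) -> None:
        self.cached_prompts.append(list(input_ids))

    def matched_prefix_len(self, input_ids: Optional[List[int]]) -> int:
        if input_ids is None:
            # No lookup key (prefill logprobs requested): nothing is reused.
            return 0
        best = 0
        for cached in self.cached_prompts:
            common = 0
            for a, b in zip(cached, input_ids):
                if a != b:
                    break
                common += 1
            best = max(best, common)
        return best

cache = ToyPrefixCache()
cache.insert([1, 2, 3, 4, 5])
print(cache.matched_prefix_len([1, 2, 3, 9]))  # 3 tokens could be reused
print(cache.matched_prefix_len(None))          # 0 -- full prompt is prefilled
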
@@ -468,7 +468,6 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = block_tables_tensor.to(device)
         prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)

-        log_master(logger.info, f"Block tables {block_tables}")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -1915,7 +1914,6 @@ class FlashCausalLM(Model):
         # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)

         if cu_seqlen_prefill is not None:
-            log_master(logger.info, f"Prefix lens {prefix_lens}")
             return use_prefill_with_paged_kv_state(
                 state=(
                     state if state is not None else self.prefill_with_paged_kv_state
@@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
+    block_tables_to_ragged,
 )
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
@@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM):
         trust_remote_code: bool,
         **kwargs,
     ):
+        if PREFIX_CACHING:
+            raise NotImplementedError("Vlm do not work with prefix caching yet")
         if processor_kwargs is None:
             processor_kwargs = {}
         self.processor = processor_class.from_pretrained(
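
Because the guard sits in the constructor, a deployment with prefix caching enabled fails fast as soon as a VLM is requested. A hedged sketch of that behaviour with a hypothetical minimal class (the real VlmCausalLM constructor takes many more arguments):

PREFIX_CACHING = True  # stand-in for text_generation_server.models.globals.PREFIX_CACHING

class ToyVlm:
    def __init__(self) -> None:
        if PREFIX_CACHING:
            raise NotImplementedError("Vlm do not work with prefix caching yet")

try:
    ToyVlm()
except NotImplementedError as err:
    print(f"refused to start: {err}")
# refused to start: Vlm do not work with prefix caching yet
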
@@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)

             # Add Copy the block tables for all members
             block_tables = (
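
For readers unfamiliar with the padding trick above, a minimal sketch (standalone tensors and made-up values, not the real batch object) of what unsqueeze(-1).expand(B, new_length) followed by reshape(-1) does: each request's prefix length is repeated once per padded position, so the flattened tensor lines up with the padded CUDA-graph batch.

import torch

B, new_length = 2, 3  # hypothetical: 2 requests, padded to 3 positions each
prefix_lens_tensor = torch.tensor([4, 0], dtype=torch.int32)

# Repeat each request's prefix length across its padded positions, then flatten.
expanded = prefix_lens_tensor.unsqueeze(-1).expand(B, new_length).reshape(-1)
print(expanded)  # tensor([4, 4, 4, 0, 0, 0], dtype=torch.int32)
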
@@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices

@@ -349,6 +357,21 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths=batch.input_lengths,
+                input_lengths_tensor=input_lengths,
+                prefix_lens=batch.prefix_lens,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
                 input_lengths = Seqlen(input_lengths=input_lengths)
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
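
block_tables_to_ragged is imported from flash_causal_lm; a hedged sketch of the idea only (toy function and block size, not the actual implementation): flashinfer expects one flat, "ragged" list of block ids in which each request contributes exactly the blocks its total context (new tokens plus cached prefix) occupies, instead of a rectangular padded table.

import math
from typing import List

BLOCK_SIZE = 16  # assumption for this sketch; the server's block size may differ

def toy_block_tables_to_ragged(
    block_tables: List[List[int]],  # padded per-request block ids
    input_lengths: List[int],       # new tokens per request
    prefix_lens: List[int],         # cached tokens per request
) -> List[int]:
    """Illustration only: flatten a padded block table into a ragged layout."""
    ragged: List[int] = []
    for blocks, in_len, pre_len in zip(block_tables, input_lengths, prefix_lens):
        needed = math.ceil((in_len + pre_len) / BLOCK_SIZE)
        ragged.extend(blocks[:needed])  # keep only the blocks this request uses
    return ragged

# Two requests sharing a padded table of width 4:
print(toy_block_tables_to_ragged(
    block_tables=[[0, 1, 2, 0], [3, 4, 0, 0]],
    input_lengths=[20, 10],
    prefix_lens=[16, 0],
))  # [0, 1, 2, 3]
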
@@ -379,13 +402,23 @@ class VlmCausalLM(FlashCausalLM):
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
         cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
             cuda_graph["block_tables"][
                 : block_tables.shape[0], : block_tables.shape[1]
             ] = block_tables
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )

         # Replay the graph
         cuda_graph["graph"].replay()
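
The hunk above follows the usual CUDA-graph replay pattern: the graph was captured on fixed-size "static" tensors, so each new batch is copied into the front of those buffers, the padded tail is reset (fill_(-1), zero_()), and the captured kernels are replayed unchanged. A minimal, hedged PyTorch sketch of that pattern (generic buffers, requires a CUDA device, not TGI's actual cuda_graph dict; real code also warms up on a side stream before capturing):

import torch

assert torch.cuda.is_available(), "CUDA graphs need a GPU"
device = torch.device("cuda")

# "Static" buffers sized for the largest batch the graph was captured with.
MAX_TOKENS = 8
static_input = torch.zeros(MAX_TOKENS, dtype=torch.int64, device=device)
static_out = torch.zeros(MAX_TOKENS, dtype=torch.int64, device=device)

# Capture a trivial computation that reads and writes the static buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out.copy_(static_input * 2)

# Replay for a smaller batch: copy real data into the front of the buffers,
# neutralize the padded tail, then replay the captured kernels as-is.
input_ids = torch.tensor([1, 2, 3], device=device)
static_input.zero_()
static_input[: input_ids.shape[0]] = input_ids
graph.replay()
print(static_out[: input_ids.shape[0]])  # tensor([2, 4, 6], device='cuda:0')
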