Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 12:54:52 +00:00)
Fixup:

- Remove logs
- Disable VLMs (they do not work)
- Disable prefix caching when the user wants prefill logprobs.
parent 97c504136c
commit 95155a212b
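The third bullet is the substantive one: when a client asks for prefill logprobs (decoder_input_details), every prompt token must actually be run through the model, so its KV cannot be served from a shared prefix cache. A minimal sketch of that conflict, using a hypothetical radix_key helper (not TGI code):

from typing import List, Optional

def radix_key(input_ids: List[int], want_prefill_logprobs: bool) -> Optional[List[int]]:
    """Return the key used to look up a shared prefix, or None to opt out."""
    if want_prefill_logprobs:
        # Opting out forces a full prefill, so every prompt token's
        # logprob is actually computed rather than skipped via the cache.
        return None
    return input_ids

assert radix_key([1, 2, 3], want_prefill_logprobs=True) is None
assert radix_key([1, 2, 3], want_prefill_logprobs=False) == [1, 2, 3]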
@@ -316,10 +316,15 @@ impl State {
             + self.speculate
             - 1;
 
-        match block_allocator
-            .allocate(tokens, entry.request.input_ids.clone())
-            .await
-        {
+        // If the user wants the prefill logprobs, we cannot reuse the cache.
+        // So no input_ids for the radix tree.
+        let input_ids = if entry.request.decoder_input_details {
+            None
+        } else {
+            entry.request.input_ids.clone()
+        };
+
+        match block_allocator.allocate(tokens, input_ids).await {
             None => {
                 // Entry is over budget
                 // Add it back to the front
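For illustration, a sketch of the allocate contract this Rust hunk leans on, with assumed semantics and a made-up SketchAllocator class (the real allocator lives in the router and keys a radix tree over block contents): passing None for input_ids means the blocks are never registered, so nothing is shared.

from typing import List, Optional

class SketchAllocator:
    """Illustration only: assumed semantics, not TGI's block allocator."""

    def __init__(self) -> None:
        self.radix = {}  # stand-in for the radix tree: prefix -> blocks

    def allocate(self, tokens: int, input_ids: Optional[List[int]]) -> List[int]:
        blocks = list(range(tokens))  # pretend these are freshly reserved blocks
        if input_ids is not None:
            # Cacheable path: later requests sharing this prefix may reuse it.
            self.radix[tuple(input_ids)] = blocks
        return blocks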
@@ -468,7 +468,6 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = block_tables_tensor.to(device)
         prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
 
-        log_master(logger.info, f"Block tables {block_tables}")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -1915,7 +1914,6 @@ class FlashCausalLM(Model):
         # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
 
         if cu_seqlen_prefill is not None:
-            log_master(logger.info, f"Prefix lens {prefix_lens}")
             return use_prefill_with_paged_kv_state(
                 state=(
                     state if state is not None else self.prefill_with_paged_kv_state
@@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
+    block_tables_to_ragged,
 )
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
@@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM):
         trust_remote_code: bool,
         **kwargs,
     ):
+        if PREFIX_CACHING:
+            raise NotImplementedError("Vlm do not work with prefix caching yet")
         if processor_kwargs is None:
             processor_kwargs = {}
         self.processor = processor_class.from_pretrained(
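PREFIX_CACHING is imported from text_generation_server.models.globals (see the import hunk above). A plausible shape for that flag, assuming it is parsed once from an environment variable; the exact variable name here is an assumption, not confirmed by this diff:

import os

# Assumption: a module-level flag resolved from the environment in
# text_generation_server.models.globals; the variable name is hypothetical.
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "false").lower() in {"1", "true"}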
@@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)
 
             # Copy the block tables for all members
             block_tables = (
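A worked example of the expand/reshape pattern this hunk adds: each sequence's prefix length is repeated once per position (new_length covers the speculated tokens), so the flattened tensor lines up element for element with the flattened input_lengths.

import torch

prefix_lens_tensor = torch.tensor([0, 5, 2])  # per-sequence cached prefix lengths
B, new_length = 3, 2
flat = prefix_lens_tensor.unsqueeze(-1).expand(B, new_length).reshape(-1)
print(flat)  # tensor([0, 0, 5, 5, 2, 2])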
@@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
@@ -349,6 +357,21 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths=batch.input_lengths,
+                input_lengths_tensor=input_lengths,
+                prefix_lens=batch.prefix_lens,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
                 input_lengths = Seqlen(input_lengths=input_lengths)
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
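block_tables_to_ragged comes from flash_causal_lm (see the import hunk above): flashinfer expects one ragged 1-D list of blocks rather than a padded [batch, max_blocks] table. A sketch of the flattening it plausibly performs; the block-size arithmetic here is an assumption, not the real implementation:

import torch

def block_tables_to_ragged_sketch(block_tables, input_lengths, prefix_lens, block_size=16):
    ragged = []
    for row, in_len, pre_len in zip(block_tables, input_lengths, prefix_lens):
        # Blocks covering the cached prefix plus the new tokens for this sequence.
        n_blocks = (pre_len + in_len + block_size - 1) // block_size
        ragged.extend(row[:n_blocks].tolist())
    return torch.tensor(ragged, dtype=torch.int32)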
@@ -379,13 +402,23 @@ class VlmCausalLM(FlashCausalLM):
             # Static inputs are potentially padded
             cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
             cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+                cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+            else:
                 cuda_graph["block_tables"][
                     : block_tables.shape[0], : block_tables.shape[1]
                 ] = block_tables
             cuda_graph["slots"].fill_(-1)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+            cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+                input_lengths + prefix_lens_tensor
+            )
 
             # Replay the graph
             cuda_graph["graph"].replay()
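The padding discipline in this last hunk follows the standard CUDA-graph replay pattern: the graph captures kernels that read from fixed buffers, so each call copies fresh inputs into those buffers (and neutralizes the stale tail) before replay. A generic sketch of that pattern, not TGI's actual capture code:

import torch

# Assumes a graph was already captured reading from static_input.
static_input = torch.zeros(512, dtype=torch.long, device="cuda")

def replay_with(graph: torch.cuda.CUDAGraph, input_ids: torch.Tensor) -> None:
    n = input_ids.shape[0]
    static_input[:n] = input_ids   # copy real inputs into the captured buffer
    static_input[n:] = 0           # zero the padded tail so stale data is inert
    graph.replay()                 # rerun the captured kernels in place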