Mirror of https://github.com/huggingface/text-generation-inference.git
Fixing prefix caching.
parent b2933b72d0
commit 99b6b5c795
@@ -89,6 +89,8 @@ impl Allocator for RadixAllocator {
         let prefix_len = blocks.len();
         let suffix_len = tokens - prefix_len as u32;

+        tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
+
         match self.alloc_or_reclaim(suffix_len as usize) {
             Some(suffix_blocks) => blocks.extend(suffix_blocks),
             None => {
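
The hunk above only adds the tracing line; the surrounding context shows the accounting it reports: blocks returned by the radix cache cover the cached prefix, and fresh blocks are only needed for the remaining suffix. A minimal Python sketch of that arithmetic, with hypothetical names and assuming one token per block, as the tokens - prefix_len expression implies:

# Hypothetical sketch, not the repository's code. Assumes one cached block
# corresponds to one cached (prefix) token, as the hunk's arithmetic implies.
def split_allocation(tokens, cached_blocks):
    prefix_len = len(cached_blocks)   # tokens already served by the radix cache
    suffix_len = tokens - prefix_len  # tokens that still need fresh blocks
    assert suffix_len >= 0, "cache hit cannot exceed the requested tokens"
    return prefix_len, suffix_len

# Example: 12 requested tokens with 5 cached blocks -> allocate only 7 new blocks.
print(split_allocation(12, [3, 4, 5, 6, 7]))  # (5, 7)
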
@@ -76,7 +76,7 @@ def paged_attention(
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
     if ATTENTION == "flashinfer":
-        from text_generation_server.layers.attention.flash_infer import decode_state
+        from text_generation_server.layers.attention.flashinfer import decode_state

         return decode_state.get().forward(
             query.contiguous(),
@@ -234,7 +234,7 @@ if ATTENTION == "flashinfer":
         softcap=0.0,
     ):
         assert window_size_left == -1, "Windowing is not supported with flash infer"
-        from text_generation_server.layers.attention.flash_infer import (
+        from text_generation_server.layers.attention.flashinfer import (
             prefill_with_paged_kv_state,
         )

@@ -468,6 +468,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = block_tables_tensor.to(device)
         prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)

+        log_master(logger.info, f"Block tables {block_tables}")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -1028,7 +1029,7 @@ class FlashCausalLM(Model):
         self.kv_cache = []

         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_prefill_state,
                 create_decode_state,
                 create_prefill_with_paged_kv_state,
@@ -1166,7 +1167,7 @@ class FlashCausalLM(Model):
         self.cuda_graphs[bs]["graph"] = graph

         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
             )

@@ -1411,7 +1412,6 @@ class FlashCausalLM(Model):
             ).view(-1)
             prefix_lens_tensor = (
                 batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
-                + arange_int
             ).view(-1)

             # Add Copy the block tables for all members
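
If the single dropped line in this hunk is the "+ arange_int" term (which the -1411,7 +1412,6 counts suggest), the intent reads as: under speculative decoding each slot advances the input length by one, but the number of tokens served from the prefix cache stays fixed within a sequence. A hedged illustration with made-up values, not the repository's code:

import torch

# Assumed shapes: B sequences, new_length speculative slots per sequence.
B, new_length = 2, 3
input_lengths = torch.tensor([5, 7])   # new tokens per sequence (made-up values)
prefix_lens = torch.tensor([4, 0])     # cached prefix per sequence (made-up values)
arange_int = torch.arange(new_length)

# Each speculative slot sees one more input token than the previous slot.
input_lengths_tensor = (
    input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)                             # tensor([5, 6, 7, 7, 8, 9])

# Prefix lengths are only broadcast, never advanced per slot.
prefix_lens_tensor = prefix_lens.unsqueeze(-1).expand(B, new_length).reshape(-1)
# tensor([4, 4, 4, 0, 0, 0])
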
@@ -1452,6 +1452,13 @@ class FlashCausalLM(Model):
             cuda_graph = None

         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
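
The PREFIX_CACHING branch added above flattens the per-request block tables into a single ragged sequence before entering the forward context. A rough sketch of the kind of flattening such a helper performs; the name, shapes, and one-token-per-block assumption here are illustrative, not the actual block_tables_to_ragged implementation:

from typing import List
import torch

def ragged_block_tables(
    block_tables: torch.Tensor,   # (num_requests, max_blocks), padded
    input_lengths: List[int],     # new tokens per request
    prefix_lens: List[int],       # cached (prefix) tokens per request
) -> torch.Tensor:
    # Each request contributes the block ids covering its cached prefix plus
    # its new input tokens; the pieces are concatenated into one flat tensor.
    assert len(input_lengths) == len(prefix_lens)
    pieces = []
    for i, (inp, pref) in enumerate(zip(input_lengths, prefix_lens)):
        seq_len = inp + pref      # total tokens backed by KV blocks
        pieces.append(block_tables[i, :seq_len])
    return torch.cat(pieces)
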
@@ -1496,9 +1503,9 @@ class FlashCausalLM(Model):
             cuda_graph["slots"].fill_(-1)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][
-                : input_lengths.shape[0]
-            ] = input_lengths  # + prefix_lens_tensor
+            cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+                input_lengths + prefix_lens_tensor
+            )

             with self._forward_context(
                 block_tables=cuda_graph["block_tables"],
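
The rewritten assignment above now writes prefix-inclusive lengths into the graph's pre-allocated buffer. Since CUDA graphs replay over fixed memory, the static tensor has to be refreshed in place before every replay; a small sketch of that pattern with hypothetical buffer names:

import torch

MAX_BATCH_SIZE = 8
# Buffer captured once when the graph is recorded; only its contents may change.
static_input_lengths = torch.zeros(MAX_BATCH_SIZE, dtype=torch.int32)

def refresh_lengths(input_lengths: torch.Tensor, prefix_lens_tensor: torch.Tensor):
    # Zero stale entries, then fill the live rows with the full
    # (cached prefix + new input) length the attention kernel should see.
    static_input_lengths.zero_()
    static_input_lengths[: input_lengths.shape[0]] = input_lengths + prefix_lens_tensor
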
@@ -1736,7 +1743,7 @@ class FlashCausalLM(Model):
                 left = 0

                 if n_accepted_ids > 1:
-                    log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
+                    log_master(logger.info, f"Speculated ids {n_accepted_ids - 1}")

                 current_stopped = False
                 for j in range(index, index + n_accepted_ids):
@@ -1900,40 +1907,32 @@ class FlashCausalLM(Model):
         if ATTENTION != "flashinfer":
             return nullcontext()

-        from text_generation_server.layers.attention.flash_infer import (
+        from text_generation_server.layers.attention.flashinfer import (
             use_decode_state,
             use_prefill_state,
             use_prefill_with_paged_kv_state,
         )

-        # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)

         if cu_seqlen_prefill is not None:
-            if True:  # has_prefix_lens:
-                return use_prefill_with_paged_kv_state(
-                    state=(
-                        state if state is not None else self.prefill_with_paged_kv_state
-                    ),
-                    block_tables=block_tables_to_ragged(
-                        block_tables=block_tables,
-                        input_lengths=input_lengths,
-                        prefix_lens=prefix_lens,
-                    ),
-                    cu_seqlens=cu_seqlen_prefill,
-                    input_lengths=input_lengths_tensor,
-                    num_heads=self.num_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_size=self.head_size,
-                    page_size=BLOCK_SIZE,
-                )
-            else:
-                return use_prefill_state(
-                    state=state if state is not None else self.prefill_state,
-                    cu_seqlens=cu_seqlen_prefill,
-                    num_heads=self.num_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_size=self.head_size,
-                )
+            log_master(logger.info, f"Prefix lens {prefix_lens}")
+            return use_prefill_with_paged_kv_state(
+                state=(
+                    state if state is not None else self.prefill_with_paged_kv_state
+                ),
+                # block_tables=block_tables_to_ragged(
+                #     block_tables=block_tables,
+                #     input_lengths=input_lengths,
+                #     prefix_lens=prefix_lens,
+                # ),
+                block_tables=block_tables,
+                cu_seqlens=cu_seqlen_prefill,
+                input_lengths=input_lengths_tensor,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                page_size=BLOCK_SIZE,
+            )
         else:
             assert input_lengths_tensor is not None
             return use_decode_state(