Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 12:54:52 +00:00)

Commit 99b6b5c795: Fixing prefix caching.
Parent: b2933b72d0
@@ -89,6 +89,8 @@ impl Allocator for RadixAllocator {
         let prefix_len = blocks.len();
         let suffix_len = tokens - prefix_len as u32;
 
+        tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
+
         match self.alloc_or_reclaim(suffix_len as usize) {
             Some(suffix_blocks) => blocks.extend(suffix_blocks),
             None => {
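The new tracing line reports how much of a request is served from the radix cache versus how many blocks still have to be allocated. A minimal Python sketch of that accounting (the hunk itself is Rust; `split_request` is an illustrative helper, not code from the repo), assuming one block per token as the arithmetic above implies:

```python
# Illustrative only: mirrors the prefix/suffix arithmetic in the Rust hunk,
# assuming one block per token (prefix_len = blocks.len()).
def split_request(tokens: int, cached_blocks: list) -> tuple:
    prefix_len = len(cached_blocks)   # tokens already covered by the radix cache
    suffix_len = tokens - prefix_len  # tokens that still need freshly allocated blocks
    return prefix_len, suffix_len


prefix_len, suffix_len = split_request(tokens=20, cached_blocks=list(range(12)))
print(f"Prefix {prefix_len} - Suffix {suffix_len}")  # Prefix 12 - Suffix 8
```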
@@ -76,7 +76,7 @@ def paged_attention(
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
     if ATTENTION == "flashinfer":
-        from text_generation_server.layers.attention.flash_infer import decode_state
+        from text_generation_server.layers.attention.flashinfer import decode_state
 
         return decode_state.get().forward(
             query.contiguous(),
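This hunk, like several below, only renames the module from `flash_infer` to `flashinfer`. If code outside the repo still imports the old name, a compatibility shim such as the following sketch (an assumption, not part of this commit) keeps both spellings working during a migration:

```python
# Sketch of a compatibility import (not part of this commit).
# Prefer the new module name and fall back to the old one if it still exists.
try:
    from text_generation_server.layers.attention.flashinfer import decode_state
except ImportError:
    from text_generation_server.layers.attention.flash_infer import decode_state
```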
@@ -234,7 +234,7 @@ if ATTENTION == "flashinfer":
         softcap=0.0,
     ):
         assert window_size_left == -1, "Windowing is not supported with flash infer"
-        from text_generation_server.layers.attention.flash_infer import (
+        from text_generation_server.layers.attention.flashinfer import (
             prefill_with_paged_kv_state,
         )
 
@@ -468,6 +468,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = block_tables_tensor.to(device)
         prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
 
+        log_master(logger.info, f"Block tables {block_tables}")
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
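`log_master` is used throughout these hunks with the call shape `log_master(log_fn, msg)`. A minimal stand-in with the same shape (a sketch, not the repo's helper) makes the intent clear: in multi-shard runs, only the master rank should emit the message.

```python
import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Hypothetical stand-in for the log_master helper used above: forward the
# message to the given logging function only on the master rank so that
# sharded runs do not print the same line once per GPU.
def log_master(log_fn, msg: str) -> None:
    if int(os.environ.get("RANK", "0")) == 0:
        log_fn(msg)


log_master(logger.info, "Block tables [[0, 1, 2], [3, 4]]")
```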
@@ -1028,7 +1029,7 @@ class FlashCausalLM(Model):
         self.kv_cache = []
 
         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_prefill_state,
                 create_decode_state,
                 create_prefill_with_paged_kv_state,
@@ -1166,7 +1167,7 @@ class FlashCausalLM(Model):
         self.cuda_graphs[bs]["graph"] = graph
 
         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
             )
 
@@ -1411,7 +1412,6 @@ class FlashCausalLM(Model):
             ).view(-1)
             prefix_lens_tensor = (
                 batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
-                + arange_int
             ).view(-1)
 
             # Add Copy the block tables for all members
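Dropping `+ arange_int` changes what the flattened prefix-length tensor holds when speculative ids are present: the cached prefix length of a sequence is the same for every speculative position, so it is only broadcast, not offset. A toy comparison with assumed shapes (not TGI code):

```python
import torch

B, new_length = 2, 3  # two sequences, three speculative positions each
prefix_lens_tensor = torch.tensor([16, 0], dtype=torch.int32)
arange_int = torch.arange(new_length, dtype=torch.int32).unsqueeze(0).expand(B, new_length)

# Old behaviour: a per-position offset was added to the prefix length.
with_offset = (prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
print(with_offset)     # tensor([16, 17, 18,  0,  1,  2], dtype=torch.int32)

# New behaviour: the prefix length is simply repeated per position.
without_offset = prefix_lens_tensor.unsqueeze(-1).expand(B, new_length).reshape(-1)
print(without_offset)  # tensor([16, 16, 16,  0,  0,  0], dtype=torch.int32)
```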
@@ -1452,6 +1452,13 @@ class FlashCausalLM(Model):
             cuda_graph = None
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
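When prefix caching is on, the padded `[batch, max_blocks]` block-table tensor is converted into a ragged, flattened layout before entering the forward context. The sketch below shows the general idea with made-up per-sequence counts (prefix plus input length); it is an illustration, not the repo's `block_tables_to_ragged`:

```python
import torch
from typing import List


# Illustration only: flatten a padded block table into a ragged 1-D tensor in
# which sequence i contributes its own number of valid entries. The count used
# here (prefix_len + input_length) is an assumption for the example.
def ragged_block_tables(
    block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
) -> torch.Tensor:
    pieces = []
    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
        pieces.append(block_tables[i, : prefix_len + input_length])
    return torch.cat(pieces)


padded = torch.tensor([[7, 8, 9, 0], [3, 4, 0, 0]], dtype=torch.int32)
print(ragged_block_tables(padded, input_lengths=[1, 2], prefix_lens=[2, 0]))
# tensor([7, 8, 9, 3, 4], dtype=torch.int32)
```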
@@ -1496,9 +1503,9 @@ class FlashCausalLM(Model):
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][
-            : input_lengths.shape[0]
-        ] = input_lengths  # + prefix_lens_tensor
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
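The CUDA-graph path replays fixed-size buffers, so per-step values are copied into padded static tensors before replay; with prefix caching the recorded length has to be the full sequence length, cached prefix included. A toy version of that copy with assumed sizes (not TGI code):

```python
import torch

max_batch_size = 4
static_input_lengths = torch.zeros(max_batch_size, dtype=torch.int32)  # captured by the graph

input_lengths = torch.tensor([5, 9], dtype=torch.int32)        # tokens processed this step
prefix_lens_tensor = torch.tensor([32, 0], dtype=torch.int32)  # tokens served from the cache

static_input_lengths.zero_()
# Attention must see cached prefix + new tokens, hence the sum.
static_input_lengths[: input_lengths.shape[0]] = input_lengths + prefix_lens_tensor
print(static_input_lengths)  # tensor([37,  9,  0,  0], dtype=torch.int32)
```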
@@ -1736,7 +1743,7 @@ class FlashCausalLM(Model):
             left = 0
 
             if n_accepted_ids > 1:
-                log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
+                log_master(logger.info, f"Speculated ids {n_accepted_ids - 1}")
 
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
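For context, `n_accepted_ids` counts how many of a request's candidate tokens survived speculative verification, and the loop after the log line walks the flattened output with a running index. A toy walk-through with made-up values (not TGI code):

```python
import torch

next_token_ids = torch.tensor([11, 12, 13, 21, 31, 32])  # flattened over three requests
accepted = [3, 1, 2]  # n_accepted_ids per request; n - 1 of them were speculated

index = 0
for n_accepted_ids in accepted:
    if n_accepted_ids > 1:
        print(f"Speculated ids {n_accepted_ids - 1}")
    print(next_token_ids[index : index + n_accepted_ids].tolist())
    index += n_accepted_ids
```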
@@ -1900,25 +1907,25 @@ class FlashCausalLM(Model):
         if ATTENTION != "flashinfer":
             return nullcontext()
 
-        from text_generation_server.layers.attention.flash_infer import (
+        from text_generation_server.layers.attention.flashinfer import (
             use_decode_state,
-            use_prefill_state,
             use_prefill_with_paged_kv_state,
         )
 
         # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
 
         if cu_seqlen_prefill is not None:
-            if True:  # has_prefix_lens:
+            log_master(logger.info, f"Prefix lens {prefix_lens}")
             return use_prefill_with_paged_kv_state(
                 state=(
                     state if state is not None else self.prefill_with_paged_kv_state
                 ),
-                block_tables=block_tables_to_ragged(
-                    block_tables=block_tables,
-                    input_lengths=input_lengths,
-                    prefix_lens=prefix_lens,
-                ),
+                # block_tables=block_tables_to_ragged(
+                #     block_tables=block_tables,
+                #     input_lengths=input_lengths,
+                #     prefix_lens=prefix_lens,
+                # ),
+                block_tables=block_tables,
                 cu_seqlens=cu_seqlen_prefill,
                 input_lengths=input_lengths_tensor,
                 num_heads=self.num_heads,
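The `_forward_context` changes keep the same overall dispatch: a no-op context for non-flashinfer backends, the paged-KV prefill state when `cu_seqlen_prefill` is set, and the decode state otherwise. A simplified sketch of that shape with placeholder state objects (assumed names, not the repo's API):

```python
from contextlib import contextmanager, nullcontext

ATTENTION = "flashinfer"


@contextmanager
def use_state(name: str):
    # Placeholder for use_prefill_with_paged_kv_state / use_decode_state.
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")


def forward_context(*, cu_seqlen_prefill, attention=ATTENTION):
    if attention != "flashinfer":
        return nullcontext()
    if cu_seqlen_prefill is not None:
        return use_state("prefill_with_paged_kv_state")
    return use_state("decode_state")


with forward_context(cu_seqlen_prefill=[0, 5]):
    pass  # the model forward would run inside this context
```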
@@ -1926,14 +1933,6 @@ class FlashCausalLM(Model):
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
             )
-            else:
-                return use_prefill_state(
-                    state=state if state is not None else self.prefill_state,
-                    cu_seqlens=cu_seqlen_prefill,
-                    num_heads=self.num_heads,
-                    num_kv_heads=self.num_kv_heads,
-                    head_size=self.head_size,
-                )
         else:
             assert input_lengths_tensor is not None
             return use_decode_state(