mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 23:42:06 +00:00
This change adds support for prefix caching to the v3 router. It is split out from the backend support to ease reviewing. For now, prefix caching is only enabled with `USE_PREFIX_CACHING=1`; in that case, the router switches to `RadixAllocator`. This allocator uses a radix trie to keep track of previously seen prefills. If a new prefill is a prefix of a previously seen prefill, the router sends a request with `prefix_len > 0`, which the backend can use to reuse KV blocks from the cache rather than recomputing them. Even though backend support is not added in this PR, the backend still works with prefix caching enabled: the prefix lengths are simply ignored.
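The `RadixAllocator` itself lives in the Rust router and is not part of the file below. As a rough illustration of the idea only, here is a minimal Python sketch using a hypothetical token-level `PrefixIndex` class (a plain trie rather than a compressed, block-aware radix trie) to show how a `prefix_len` for a new prefill could be derived from previously seen prefills.

from typing import Dict, List


class RadixNode:
    """One node of a (hypothetical) token-level trie over past prefills."""

    def __init__(self) -> None:
        self.children: Dict[int, "RadixNode"] = {}


class PrefixIndex:
    """Tracks previously seen prefills and reports reusable prefix lengths."""

    def __init__(self) -> None:
        self.root = RadixNode()

    def insert(self, tokens: List[int]) -> None:
        # Record a prefill so later requests can match against it.
        node = self.root
        for t in tokens:
            node = node.children.setdefault(t, RadixNode())

    def prefix_len(self, tokens: List[int]) -> int:
        # Length of the longest leading run of `tokens` already in the index;
        # the router would send this as `prefix_len` so the backend can skip
        # recomputing the corresponding KV blocks.
        node = self.root
        matched = 0
        for t in tokens:
            if t not in node.children:
                break
            node = node.children[t]
            matched += 1
        return matched


index = PrefixIndex()
index.insert([1, 2, 3, 4, 5])
print(index.prefix_len([1, 2, 3, 9]))  # -> 3

In the router's allocator, matching presumably happens at KV-block granularity rather than per token; this sketch ignores block sizes and cache eviction entirely.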
63 lines
1.8 KiB
Python
import torch
import os

from loguru import logger
from typing import Dict, Optional

from text_generation_server.utils.log import log_master

# Prefix caching is opt-in: only USE_PREFIX_CACHING=1 (or "true") enables it.
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")

# flashinfer is the only attention backend that supports prefix caching, so it
# becomes the default whenever prefix caching is enabled.
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
    ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")

if PREFIX_CACHING and ATTENTION != "flashinfer":
    raise RuntimeError("Prefix caching is only supported with flashinfer")

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None

# This is overridden by the cli
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
    BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
    BLOCK_SIZE = 1
else:
    BLOCK_SIZE = 16

cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
    try:
        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
    except Exception as e:
        raise RuntimeError(
            f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
        )
else:
    cuda_graphs = None
# sorting the cuda graphs in descending order helps reduce the
# memory impact and results in less memory usage
if cuda_graphs is not None:
    cuda_graphs.sort(reverse=True)

CUDA_GRAPHS = cuda_graphs
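# Example (illustrative): launching with CUDA_GRAPHS="1,2,4,8" yields
# [1, 2, 4, 8] from the parse above, sorted to [8, 4, 2, 1] before being
# exposed as CUDA_GRAPHS.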

# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None


def set_adapter_to_index(adapter_to_index: Dict[str, int]):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index


def get_adapter_to_index():
    global ADAPTER_TO_INDEX
    return ADAPTER_TO_INDEX