text-generation-inference/server/text_generation_server/models/globals.py
Daniël de Kok 8deeaca4ff
Add support for prefix caching to the v3 router (#2392)
This change adds support for prefix caching to the v3 router. This
is broken up from the backend support to ease reviewing.

For now, prefix caching is only enabled with `USE_PREFIX_CACHING=1`;
in this case, the router switches to the `RadixAllocator`. This
allocator uses a radix trie to keep track of previously seen
prefills. If a new prefill is a prefix of a previously seen
prefill, the router sends a request with `prefix_len>0`, which
the backend can use to reuse KV blocks from the cache rather
than recomputing them.

Even though backend support is not added in this PR, the backend
still works with prefix caching enabled; the prefix lengths are
simply ignored.
2024-08-12 14:59:17 +02:00
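
The sketch below (illustrative Python, not the router's actual Rust `RadixAllocator`)
shows the idea described in the commit message: remember previously seen prefill token
sequences in a trie and report how many leading tokens of a new prefill were already
seen, which becomes `prefix_len`. For brevity it uses a plain per-token trie rather
than a compressed radix trie; the `PrefixTrie` name and its methods are hypothetical.

from typing import Dict, List


class PrefixTrie:
    """Illustrative prefix tracker; stores prefills as paths in a token trie."""

    def __init__(self):
        self.children: Dict[int, "PrefixTrie"] = {}

    def insert(self, tokens: List[int]) -> None:
        # Record a prefill so later requests can reuse its prefix.
        node = self
        for token in tokens:
            node = node.children.setdefault(token, PrefixTrie())

    def longest_prefix(self, tokens: List[int]) -> int:
        # Length of the longest previously inserted prefix of `tokens`.
        node = self
        matched = 0
        for token in tokens:
            if token not in node.children:
                break
            node = node.children[token]
            matched += 1
        return matched


# A new prefill sharing its first three tokens with an earlier one would be
# sent to the backend with prefix_len=3, so it can reuse the cached KV blocks.
trie = PrefixTrie()
trie.insert([1, 2, 3, 4])
assert trie.longest_prefix([1, 2, 3, 9]) == 3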


import torch
import os
from loguru import logger
from typing import Dict, Optional
from text_generation_server.utils.log import log_master

# Any non-empty value of USE_PREFIX_CACHING enables prefix caching.
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
    ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")

if PREFIX_CACHING and ATTENTION != "flashinfer":
    raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None

# This is overridden by the cli
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
    BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
    BLOCK_SIZE = 1
else:
    BLOCK_SIZE = 16
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
    try:
        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
    except Exception as e:
        raise RuntimeError(
            f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
        )
else:
    cuda_graphs = None

# Sorting the cuda graphs in descending order helps reduce the
# memory impact and results in less memory usage,
# e.g. CUDA_GRAPHS="1,2,4" becomes [4, 2, 1].
if cuda_graphs is not None:
    cuda_graphs.sort(reverse=True)

CUDA_GRAPHS = cuda_graphs
# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None


def set_adapter_to_index(adapter_to_index: Dict[str, int]):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index


def get_adapter_to_index():
    global ADAPTER_TO_INDEX
    return ADAPTER_TO_INDEX
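
A short usage sketch (hypothetical caller code, not part of `globals.py`): a caller
might register the adapter-to-index mapping once at startup and read the module-level
settings elsewhere. The adapter name below is made up for illustration.

from text_generation_server.models import globals as tgi_globals

# Register the mapping once at startup (adapter name is illustrative).
tgi_globals.set_adapter_to_index({"my-lora-adapter": 0})

# Later, look up the index and read the derived settings.
adapter_index = tgi_globals.get_adapter_to_index()["my-lora-adapter"]
print(adapter_index, tgi_globals.BLOCK_SIZE, tgi_globals.CUDA_GRAPHS)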