This commit is contained in:
OlivierDehaene 2024-10-16 10:23:21 +02:00
parent fc41f0784a
commit 594a2b4a3d
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
6 changed files with 21 additions and 5 deletions

View File

@@ -327,7 +327,8 @@ ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV ROCM_USE_SKINNY_GEMM=1
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh

View File

@@ -218,7 +218,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
FROM ${PLATFORM} AS final
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV CUDA_GRAPHS=0
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]

View File

@@ -406,6 +406,7 @@ def launcher(event_loop):
print(" ".join(args), file=sys.stderr)
env["LOG_LEVEL"] = "info,text_generation_router=debug"
env["PREFILL_CHUNKING"] = "1"
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
@@ -504,6 +505,7 @@ def launcher(event_loop):
env = {
"LOG_LEVEL": "info,text_generation_router=debug",
"PREFILL_CHUNKING": "1",
}
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"

View File

@@ -68,7 +68,7 @@ fn get_config(
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let compute_capability = gpu::get_cuda_capability();
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
if prefix_caching.is_none() {
@@ -1678,7 +1678,7 @@ fn main() -> Result<(), LauncherError> {
};
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
std::env::set_var("PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let max_input_tokens = {

View File

@@ -11,6 +11,7 @@ PREFIX_CACHING = os.getenv("PREFIX_CACHING", default_prefix_caching).lower() in
"1",
"true",
}
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (

View File

@@ -7,7 +7,12 @@ from collections import defaultdict
from transformers import PreTrainedTokenizerBase
from loguru import logger
from text_generation_server.models.globals import ATTENTION, PREFIX_CACHING, BLOCK_SIZE
from text_generation_server.models.globals import (
ATTENTION,
PREFIX_CACHING,
BLOCK_SIZE,
PREFILL_CHUNKING,
)
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import set_support_chunking
@@ -65,6 +70,8 @@ class Model(ABC):
speculate = get_speculate()
self.speculate = speculate
support_chunking = support_chunking and PREFILL_CHUNKING
if speculate != 0 and support_chunking:
log_master(
logger.warning,
@@ -79,6 +86,10 @@ class Model(ABC):
)
support_chunking = False
log_master(
logger.info, f"Using experimental prefill chunking = {support_chunking}"
)
self.support_chunking = support_chunking
set_support_chunking(support_chunking)