diff --git a/Dockerfile_amd b/Dockerfile_amd
index 4bb6407a..b84d4edd 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -327,7 +327,8 @@ ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
 ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
 ENV VLLM_MOE_PADDING=0
 ENV ATTENTION=paged
-ENV USE_PREFIX_CACHING=0
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
 ENV ROCM_USE_SKINNY_GEMM=1
 
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
diff --git a/Dockerfile_intel b/Dockerfile_intel
index 9b5dd20a..96f24248 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -218,7 +218,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
 
 FROM ${PLATFORM} AS final
 ENV ATTENTION=paged
-ENV USE_PREFIX_CACHING=0
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
 ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 8f312942..356fa5e3 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -406,6 +406,7 @@ def launcher(event_loop):
         print(" ".join(args), file=sys.stderr)
 
         env["LOG_LEVEL"] = "info,text_generation_router=debug"
+        env["PREFILL_CHUNKING"] = "1"
 
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
@@ -504,6 +505,7 @@ def launcher(event_loop):
 
         env = {
             "LOG_LEVEL": "info,text_generation_router=debug",
+            "PREFILL_CHUNKING": "1",
         }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index e6ef6c2d..d2f1c0e3 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -68,7 +68,7 @@ fn get_config(
 
 fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
     let compute_capability = gpu::get_cuda_capability();
-    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
+    let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
         if prefix_caching.is_none() {
@@ -1678,7 +1678,7 @@ fn main() -> Result<(), LauncherError> {
     };
     let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
     tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
-    std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
+    std::env::set_var("PREFIX_CACHING", prefix_caching);
     std::env::set_var("ATTENTION", attention);
 
     let max_input_tokens = {
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 0b60549a..8be92fbf 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -11,6 +11,7 @@ PREFIX_CACHING = os.getenv("PREFIX_CACHING", default_prefix_caching).lower() in {
     "1",
     "true",
 }
+PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 1da6e3e3..b3630013 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -7,7 +7,12 @@ from collections import defaultdict
 from transformers import PreTrainedTokenizerBase
 from loguru import logger
 
-from text_generation_server.models.globals import ATTENTION, PREFIX_CACHING, BLOCK_SIZE
+from text_generation_server.models.globals import (
+    ATTENTION,
+    PREFIX_CACHING,
+    BLOCK_SIZE,
+    PREFILL_CHUNKING,
+)
 from text_generation_server.models.types import Batch, Generation
 from text_generation_server.utils.log import log_master
 from text_generation_server.utils.prefill_chunking import set_support_chunking
@@ -65,6 +70,8 @@ class Model(ABC):
         speculate = get_speculate()
         self.speculate = speculate
 
+        support_chunking = support_chunking and PREFILL_CHUNKING
+
         if speculate != 0 and support_chunking:
             log_master(
                 logger.warning,
@@ -79,6 +86,10 @@ class Model(ABC):
             )
             support_chunking = False
 
+        log_master(
+            logger.info, f"Using experimental prefill chunking = {support_chunking}"
+        )
+
         self.support_chunking = support_chunking
         set_support_chunking(support_chunking)
 