Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00

Commit 594a2b4a3d (parent fc41f0784a): rename

This commit renames the USE_PREFIX_CACHING environment variable to PREFIX_CACHING and threads a new PREFILL_CHUNKING flag through the Docker images, the launcher, the integration-test fixtures, and the Python server.
@@ -327,7 +327,8 @@ ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
 ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
 ENV VLLM_MOE_PADDING=0
 ENV ATTENTION=paged
-ENV USE_PREFIX_CACHING=0
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
 ENV ROCM_USE_SKINNY_GEMM=1
 
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
@@ -218,7 +218,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
 
 FROM ${PLATFORM} AS final
 ENV ATTENTION=paged
-ENV USE_PREFIX_CACHING=0
+ENV PREFIX_CACHING=0
+ENV PREFILL_CHUNKING=0
 ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
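Both image hunks above replace ENV USE_PREFIX_CACHING=0 with the new names, so a deployment that still exports the old variable at run time is silently ignored once the launcher only reads PREFIX_CACHING (see the launcher hunk below). A minimal, hypothetical compatibility check, not part of this commit:

import os
import warnings

def warn_on_renamed_env() -> None:
    # Hypothetical helper: flag stale configuration after the rename.
    if "USE_PREFIX_CACHING" in os.environ and "PREFIX_CACHING" not in os.environ:
        warnings.warn(
            "USE_PREFIX_CACHING is no longer read by text-generation-inference; "
            "set PREFIX_CACHING instead."
        )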
@@ -406,6 +406,7 @@ def launcher(event_loop):
         print(" ".join(args), file=sys.stderr)
 
         env["LOG_LEVEL"] = "info,text_generation_router=debug"
+        env["PREFILL_CHUNKING"] = "1"
 
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
@@ -504,6 +505,7 @@ def launcher(event_loop):
 
         env = {
             "LOG_LEVEL": "info,text_generation_router=debug",
+            "PREFILL_CHUNKING": "1",
         }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
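The two hunks above appear to be the two launch paths of the integration-test launcher fixture; both now export PREFILL_CHUNKING=1 so the chunking code path is exercised in tests. A rough sketch of how such an env dict typically reaches the launched process (the helper name and subprocess call are assumptions, not the fixture's exact code):

import os
import subprocess
import sys

def start_launcher(args: list[str], overrides: dict[str, str]) -> subprocess.Popen:
    # Sketch: merge the test overrides onto the inherited environment and
    # spawn the launcher binary; the real fixture also manages ports, models, etc.
    env = {**os.environ, **overrides}
    return subprocess.Popen(args, env=env, stdout=sys.stdout, stderr=sys.stderr)

# Example (assumed invocation):
# start_launcher(
#     ["text-generation-launcher", "--model-id", "some/model"],
#     {"LOG_LEVEL": "info,text_generation_router=debug", "PREFILL_CHUNKING": "1"},
# )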
@@ -68,7 +68,7 @@ fn get_config(
 
 fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
     let compute_capability = gpu::get_cuda_capability();
-    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
+    let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
         if prefix_caching.is_none() {
@@ -1678,7 +1678,7 @@ fn main() -> Result<(), LauncherError> {
     };
     let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
     tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
-    std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
+    std::env::set_var("PREFIX_CACHING", prefix_caching);
     std::env::set_var("ATTENTION", attention);
 
     let max_input_tokens = {
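In the launcher, resolve_attention now reads PREFIX_CACHING instead of USE_PREFIX_CACHING, and main re-exports the resolved value under the new name so the server process it spawns sees the same decision. A Python sketch of that precedence (explicit env var wins over the config-derived default; names are illustrative):

import os

def resolve_prefix_caching(config_default: str) -> str:
    # An explicit PREFIX_CACHING value from the environment takes precedence;
    # otherwise fall back to a default derived from the model config.
    value = os.getenv("PREFIX_CACHING", config_default)
    # Re-export so child processes observe the decision, mirroring the
    # launcher's std::env::set_var("PREFIX_CACHING", ...) call.
    os.environ["PREFIX_CACHING"] = value
    return value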
@@ -11,6 +11,7 @@ PREFIX_CACHING = os.getenv("PREFIX_CACHING", default_prefix_caching).lower() in
     "1",
     "true",
 }
+PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
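The server's globals module gains a PREFILL_CHUNKING boolean using the same truthy-set idiom as PREFIX_CACHING: only "1" or "true" (case-insensitive) enables it, and the default "0" leaves it off. A standalone sketch of that parse (env_flag is a hypothetical helper, not in the diff):

import os

def env_flag(name: str, default: str = "0") -> bool:
    # "1", "true", "True", "TRUE" -> True; anything else -> False.
    return os.getenv(name, default).lower() in {"1", "true"}

PREFILL_CHUNKING = env_flag("PREFILL_CHUNKING")  # disabled unless explicitly turned on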
@@ -7,7 +7,12 @@ from collections import defaultdict
 from transformers import PreTrainedTokenizerBase
 from loguru import logger
 
-from text_generation_server.models.globals import ATTENTION, PREFIX_CACHING, BLOCK_SIZE
+from text_generation_server.models.globals import (
+    ATTENTION,
+    PREFIX_CACHING,
+    BLOCK_SIZE,
+    PREFILL_CHUNKING,
+)
 from text_generation_server.models.types import Batch, Generation
 from text_generation_server.utils.log import log_master
 from text_generation_server.utils.prefill_chunking import set_support_chunking
@@ -65,6 +70,8 @@ class Model(ABC):
         speculate = get_speculate()
         self.speculate = speculate
 
+        support_chunking = support_chunking and PREFILL_CHUNKING
+
         if speculate != 0 and support_chunking:
             log_master(
                 logger.warning,
@@ -79,6 +86,10 @@ class Model(ABC):
             )
             support_chunking = False
 
+        log_master(
+            logger.info, f"Using experimental prefill chunking = {support_chunking}"
+        )
+
         self.support_chunking = support_chunking
         set_support_chunking(support_chunking)
 
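In Model.__init__, the model's own support_chunking capability is now ANDed with the global PREFILL_CHUNKING switch; the existing code shown above still forces it off when speculative decoding is active, then logs the final decision and propagates it via set_support_chunking. A condensed sketch of that gating (the function and its signature are illustrative):

def resolve_chunking(model_supports_chunking: bool,
                     prefill_chunking_enabled: bool,
                     speculate: int) -> bool:
    # Chunked prefill requires both model support and the global flag.
    support = model_supports_chunking and prefill_chunking_enabled
    # Speculative decoding takes priority and disables chunking.
    if speculate != 0 and support:
        support = False
    return support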