Set default value of ATTENTION to paged

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2025-06-10 07:42:11 +00:00
parent c065c58818
commit 07a0e2f7e6
3 changed files with 4 additions and 8 deletions


@@ -57,7 +57,7 @@ ARG PYTORCH_VERSION
 FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
-ENV ATTENTION=default
+ENV ATTENTION=paged
 ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
 ENV PT_HPU_LAZY_MODE=1


@@ -35,14 +35,10 @@ __all__ = [
     "Seq2SeqLM",
     "get_model_with_lora_adapters",
 ]
-from text_generation_server.models.globals import ATTENTION
 VLM_BATCH_TYPES = set()
-FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
-FLASH_ATTENTION = False
-if ATTENTION == "paged":
-    FLASH_ATTENTION = True
+FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
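
In models/__init__.py the flash-support flag is no longer derived from ATTENTION: it is set to True up front, and the guarded import of FlashCausalLM that follows determines whether flash models are actually usable. A minimal sketch of that guarded-import pattern, assuming an except branch (not part of this hunk) that clears the flag on import failure:

# Sketch only: the except branch is an assumption, since this hunk ends at
# the try line; the real fallback handling may differ.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
except ImportError:
    # Without the flash-attention model classes, callers can still check
    # FLASH_ATTENTION before choosing a model implementation.
    FLASH_ATTENTION = False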


@@ -4,14 +4,14 @@ from loguru import logger
 from text_generation_server.utils.log import log_master
 REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
-ATTENTION = os.getenv("ATTENTION", "default")
+ATTENTION = os.getenv("ATTENTION", "paged")
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
 PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
     "1",
     "true",
 }
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "default"}
+_expected = {"paged"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"