Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 07:42:06 +00:00)
Set default value of ATTENTION as paged

Signed-off-by: yuanwu <yuan.wu@intel.com>

commit 07a0e2f7e6, parent c065c58818
@@ -57,7 +57,7 @@ ARG PYTORCH_VERSION

FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base

-ENV ATTENTION=default
+ENV ATTENTION=paged
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV PT_HPU_LAZY_MODE=1
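The Dockerfile change only sets the baked-in default; the same default is mirrored on the Python side in the last hunk, because os.getenv falls back to its second argument only when the variable is absent from the environment. A minimal sketch of that interaction, written as an illustrative standalone script rather than project code:

# Minimal sketch (illustrative, not project code): how the baked-in ENV value
# and the Python-side getenv default interact.
import os

# In an image built with ENV ATTENTION=paged the variable is already present,
# so the fallback below never fires; a runtime override such as
# `docker run -e ATTENTION=...` takes precedence over both.
attention = os.getenv("ATTENTION", "paged")
print(f"Using attention backend: {attention}")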
@@ -35,13 +35,9 @@ __all__ = [
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]
from text_generation_server.models.globals import ATTENTION

VLM_BATCH_TYPES = set()
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = False
if ATTENTION == "paged":
    FLASH_ATTENTION = True

try:
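This hunk sits in the region of text_generation_server/models/__init__.py where FLASH_ATTENTION is derived from ATTENTION, so with the new default the flash/paged path is taken unless the variable is overridden. A hedged sketch of how such a flag is typically consumed alongside FLASH_ATT_ERROR_MESSAGE; check_flash_support is a hypothetical helper, not part of the module, which instead dispatches model classes directly:

# Hedged sketch: consuming a FLASH_ATTENTION flag together with
# FLASH_ATT_ERROR_MESSAGE; check_flash_support is hypothetical.
FLASH_ATTENTION = True  # as derived from ATTENTION == "paged" in the hunk above
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."


def check_flash_support(model_name: str) -> None:
    # Raise a descriptive error when flash/paged attention kernels are unavailable.
    if not FLASH_ATTENTION:
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(model_name))


# With ATTENTION defaulting to "paged", FLASH_ATTENTION is True and the check
# passes silently.
check_flash_support("SomeFlashModel")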
@@ -4,14 +4,14 @@ from loguru import logger
from text_generation_server.utils.log import log_master

REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
-ATTENTION = os.getenv("ATTENTION", "default")
+ATTENTION = os.getenv("ATTENTION", "paged")
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
    "1",
    "true",
}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "default"}
+_expected = {"paged"}
assert (
    ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
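The globals side of the change both flips the getenv default and tightens the set of accepted values, so any non-paged setting now fails fast at import time rather than silently selecting another path. A runnable sketch of that validation pattern, trimmed to stand alone (a plain print stands in for the log_master/loguru wiring):

# Sketch of the env parsing and validation pattern shown in this hunk,
# self-contained for illustration.
import os

ATTENTION = os.getenv("ATTENTION", "paged")
PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {"1", "true"}
print(f"Using prefix caching = {PREFIX_CACHING}")

# Only "paged" is accepted after this change; e.g. running with
# ATTENTION=default now trips the assertion below.
_expected = {"paged"}
assert (
    ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"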