Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
Set default value of ATTENTION as paged

Signed-off-by: yuanwu <yuan.wu@intel.com>

This commit is contained in:
commit 07a0e2f7e6 (parent c065c58818)
@@ -57,7 +57,7 @@ ARG PYTORCH_VERSION
 FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
 
-ENV ATTENTION=default
+ENV ATTENTION=paged
 ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
 ENV PT_HPU_LAZY_MODE=1
 
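With this change the Gaudi image itself exports ATTENTION=paged, so the server picks up the paged path even when the variable is not set at runtime. A minimal sketch of how such an image-level ENV default is consumed on the Python side; the variable names come from the hunks in this commit, while the final print line is purely illustrative:

    import os

    # Sketch only: os.getenv sees the image-level ENV (or a value passed via
    # `docker run -e`); the second argument is the in-code fallback, which this
    # commit also switches to "paged" in globals.py below.
    ATTENTION = os.getenv("ATTENTION", "paged")
    PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {"1", "true"}
    print(f"attention={ATTENTION}, prefix_caching={PREFIX_CACHING}")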
@@ -35,13 +35,9 @@ __all__ = [
     "Seq2SeqLM",
     "get_model_with_lora_adapters",
 ]
-from text_generation_server.models.globals import ATTENTION
 
 VLM_BATCH_TYPES = set()
-FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-FLASH_ATTENTION = False
-if ATTENTION == "paged":
-    FLASH_ATTENTION = True
+FLASH_ATTENTION = True
 
 try:
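The removed lines derived FLASH_ATTENTION from a check on ATTENTION == "paged"; since paged is now the only accepted mode, the flag can be set unconditionally. A rough sketch of the resulting module-level state (model imports and the trailing try block elided, not the full file):

    # After the hunk above: the import of ATTENTION from globals and the
    # FLASH_ATT_ERROR_MESSAGE constant are gone, and flash/paged attention is
    # always reported as available on this backend.
    VLM_BATCH_TYPES = set()
    FLASH_ATTENTION = True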
@@ -4,14 +4,14 @@ from loguru import logger
 from text_generation_server.utils.log import log_master
 
 REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
-ATTENTION = os.getenv("ATTENTION", "default")
+ATTENTION = os.getenv("ATTENTION", "paged")
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
 PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
     "1",
     "true",
 }
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "default"}
+_expected = {"paged"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
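Narrowing _expected to {"paged"} means the old value "default" is no longer accepted. A small standalone snippet that mirrors (rather than imports) the validation above shows the effect; setting the environment variable inside the script is only for demonstration:

    import os

    # Hypothetical standalone check mirroring the tightened validation in
    # globals.py: any value other than "paged", including the previous
    # "default", now trips the assertion.
    os.environ["ATTENTION"] = "default"

    ATTENTION = os.getenv("ATTENTION", "paged")
    _expected = {"paged"}
    try:
        assert ATTENTION in _expected, f"Attention is not valid {ATTENTION}, expected {_expected}"
    except AssertionError as err:
        print(err)  # Attention is not valid default, expected {'paged'}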