diff --git a/Dockerfile_gaudi b/Dockerfile_gaudi
index 442eb6b7..02885405 100644
--- a/Dockerfile_gaudi
+++ b/Dockerfile_gaudi
@@ -57,7 +57,7 @@ ARG PYTORCH_VERSION
 
 FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
 
-ENV ATTENTION=default
+ENV ATTENTION=paged
 ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
 ENV PT_HPU_LAZY_MODE=1
diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py
index f86be127..1e5ed0eb 100644
--- a/backends/gaudi/server/text_generation_server/models/__init__.py
+++ b/backends/gaudi/server/text_generation_server/models/__init__.py
@@ -35,14 +35,10 @@ __all__ = [
     "Seq2SeqLM",
     "get_model_with_lora_adapters",
 ]
-from text_generation_server.models.globals import ATTENTION
 
 VLM_BATCH_TYPES = set()
 
-FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
-FLASH_ATTENTION = False
-if ATTENTION == "paged":
-    FLASH_ATTENTION = True
+FLASH_ATTENTION = True
 
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
diff --git a/backends/gaudi/server/text_generation_server/models/globals.py b/backends/gaudi/server/text_generation_server/models/globals.py
index cd221e14..cdde67ca 100644
--- a/backends/gaudi/server/text_generation_server/models/globals.py
+++ b/backends/gaudi/server/text_generation_server/models/globals.py
@@ -4,14 +4,14 @@ from loguru import logger
 from text_generation_server.utils.log import log_master
 
 REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
-ATTENTION = os.getenv("ATTENTION", "default")
+ATTENTION = os.getenv("ATTENTION", "paged")
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
 PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
     "1",
     "true",
 }
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "default"}
+_expected = {"paged"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
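
For reference, a minimal sketch of how the attention backend is resolved after this change, based only on the hunks above; the loguru/log_master plumbing and the surrounding module imports are omitted for brevity:

```python
import os

# Mirrors the new default in globals.py: ATTENTION falls back to "paged"
# when the environment variable is unset (previously "default").
ATTENTION = os.getenv("ATTENTION", "paged")

# "paged" is now the only accepted value; "default" is no longer valid.
_expected = {"paged"}
assert (
    ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"

# models/__init__.py no longer derives this from ATTENTION;
# it is hard-coded to True.
FLASH_ATTENTION = True
```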