From df0a453693379487c84cbe0961a97cdb8b31ba84 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Fri, 17 May 2024 07:53:27 +0000
Subject: [PATCH] fixes on review

---
 server/text_generation_server/models/flash_causal_lm.py | 7 ++++---
 server/text_generation_server/utils/flash_attn.py       | 6 ++++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 92d8aa5c..b0ac9ece 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -13,6 +13,7 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
@@ -836,10 +837,10 @@ class FlashCausalLM(Model):
                     for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                 ]
             else:
-                tuning_sequences = [1, 2, 4, 8, 16, 32]
+                tuning_sequences = CUDA_GRAPHS
 
             tunableop_filepath = os.path.join(
-                "/data",
+                HUGGINGFACE_HUB_CACHE,
                 f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
             )
 
@@ -853,7 +854,7 @@ class FlashCausalLM(Model):
                 )
                 torch.cuda.tunable.read_file(tunableop_filepath)
 
-            os.makedirs("/data", exist_ok=True)
+            os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
 
             for seqlen in tuning_sequences:
                 logger.info(f"Warming up TunableOp for seqlen={seqlen}")
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index c5fd7830..9ac5655c 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -173,6 +173,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=True,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -194,7 +195,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             False,
             None,
         )
@@ -210,6 +211,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=True,
     ):
         output, _ = triton_attention(
             q,
@@ -220,7 +222,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
             cu_seqlens,
             max_s,
             max_s,
-            True,
+            causal,
             softmax_scale,
         )
         return output
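
Note: for reviewers, a minimal standalone sketch of the TunableOp selection logic this patch touches: the default warmup lengths now follow CUDA_GRAPHS instead of the hard-coded [1, 2, 4, 8, 16, 32], and the per-rank TunableOp CSV lives under HUGGINGFACE_HUB_CACHE rather than /data. The CUDA_GRAPHS values and the tunableop_config helper below are illustrative stand-ins, not the server's actual code.

import os
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

# Hypothetical stand-in for the server's CUDA_GRAPHS setting.
CUDA_GRAPHS = [1, 2, 4, 8, 16, 32]


def tunableop_config(model_id: str, world_size: int, rank: int):
    """Pick tuning sequence lengths and the per-rank TunableOp CSV path."""
    if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
        # Explicit override via environment variable, comma-separated.
        tuning_sequences = [
            int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
        ]
    else:
        # After this patch: default to the CUDA graph batch sizes.
        tuning_sequences = CUDA_GRAPHS

    tunableop_filepath = os.path.join(
        HUGGINGFACE_HUB_CACHE,
        f"tunableop_{model_id.replace('/', '-')}_tp{world_size}_rank{rank}.csv",
    )
    # The cache directory is created up front so the CSV can be written later.
    os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
    return tuning_sequences, tunableop_filepath


if __name__ == "__main__":
    seqlens, path = tunableop_config("org/model", world_size=1, rank=0)
    print(seqlens, path)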