fixes on review

fxmarty 2024-05-17 07:53:27 +00:00
parent c9455730d7
commit df0a453693
2 changed files with 8 additions and 5 deletions

server/text_generation_server/models/flash_causal_lm.py

@@ -13,6 +13,7 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
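Note: the new HUGGINGFACE_HUB_CACHE import replaces the hard-coded "/data" directory used further down. A minimal sketch of what it resolves to (assuming huggingface_hub is installed; the joined filename is illustrative, not from the commit):

    import os
    from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

    # Defaults to ~/.cache/huggingface/hub unless overridden through the
    # environment (e.g. HF_HOME or HUGGINGFACE_HUB_CACHE).
    print(HUGGINGFACE_HUB_CACHE)
    print(os.path.join(HUGGINGFACE_HUB_CACHE, "tunableop_example.csv"))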
@@ -836,10 +837,10 @@ class FlashCausalLM(Model):
                 for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
             ]
         else:
-            tuning_sequences = [1, 2, 4, 8, 16, 32]
+            tuning_sequences = CUDA_GRAPHS
         tunableop_filepath = os.path.join(
-            "/data",
+            HUGGINGFACE_HUB_CACHE,
             f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
         )
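This hunk makes the TunableOp warmup default to the CUDA_GRAPHS batch sizes as its tuning sequence lengths instead of a separate hard-coded list, and places the tuning results under the hub cache rather than "/data". A sketch of the per-model, per-rank CSV name built above (the model_id, world_size, and rank values are made up for illustration):

    model_id = "mistralai/Mistral-7B-v0.1"  # hypothetical values
    world_size, rank = 2, 0
    filename = f"tunableop_{model_id.replace('/', '-')}_tp{world_size}_rank{rank}.csv"
    # -> tunableop_mistralai-Mistral-7B-v0.1_tp2_rank0.csv

Replacing "/" keeps the hub-style model id filesystem-safe, and the tp/rank suffix keeps concurrent tensor-parallel ranks from overwriting each other's tuning results.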
@@ -853,7 +854,7 @@ class FlashCausalLM(Model):
             )
             torch.cuda.tunable.read_file(tunableop_filepath)
-            os.makedirs("/data", exist_ok=True)
+            os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
             for seqlen in tuning_sequences:
                 logger.info(f"Warming up TunableOp for seqlen={seqlen}")
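For context, a standalone sketch of the read/warmup/write flow this code relies on, assuming a PyTorch build (>= 2.3, CUDA or ROCm) that ships torch.cuda.tunable; the path and GEMM shape are illustrative:

    import os
    import torch

    csv_path = "/tmp/tunableop_demo.csv"  # stand-in for tunableop_filepath

    torch.cuda.tunable.enable(True)          # turn TunableOp on
    torch.cuda.tunable.tuning_enable(True)   # allow tuning of new shapes
    torch.cuda.tunable.set_filename(csv_path)
    if os.path.isfile(csv_path):
        torch.cuda.tunable.read_file(csv_path)  # reuse earlier results

    # Warming up a GEMM shape makes TunableOp benchmark candidate kernels
    # once and record the winner for later runs.
    a = torch.randn(1, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    _ = a @ b

    torch.cuda.tunable.write_file(csv_path)  # persist tuned solutions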

server/text_generation_server/utils/flash_attn.py

@@ -173,6 +173,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=True,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -194,7 +195,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             False,
             None,
         )
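The CK path previously hard-coded causal masking (the positional True in the varlen forward call); it now forwards the new causal keyword, which defaults to True so existing call sites keep their behavior. Hypothetical call sites (the wrapper name attention and the exact argument list are assumed from context, not taken from the diff):

    attention(q, k, v, out, cu_seqlens, max_s, softmax_scale)                # causal by default
    attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, causal=False)  # e.g. bidirectional attention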
@@ -210,6 +211,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=True,
     ):
         output, _ = triton_attention(
             q,
@@ -220,7 +222,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
             cu_seqlens,
             max_s,
             max_s,
-            True,
+            causal,
             softmax_scale,
         )
         return output
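The Triton path gets the same treatment: the hard-coded True passed to triton_attention becomes the forwarded causal flag. A small sketch of the general pattern applied in both backends (all names here are illustrative):

    def attention_wrapper(q, k, v, causal: bool = True):
        # Lifting a hard-coded positional flag into a defaulted keyword
        # lets new callers opt out of causal masking without touching
        # existing call sites.
        return backend_kernel(q, k, v, causal)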