fixes on review

This commit is contained in:
fxmarty 2024-05-17 07:53:27 +00:00
parent c9455730d7
commit df0a453693
2 changed files with 8 additions and 5 deletions

View File

@@ -13,6 +13,7 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
@@ -836,10 +837,10 @@ class FlashCausalLM(Model):
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
]
else:
tuning_sequences = [1, 2, 4, 8, 16, 32]
tuning_sequences = CUDA_GRAPHS
tunableop_filepath = os.path.join(
"/data",
HUGGINGFACE_HUB_CACHE,
f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
)
@@ -853,7 +854,7 @@ class FlashCausalLM(Model):
)
torch.cuda.tunable.read_file(tunableop_filepath)
os.makedirs("/data", exist_ok=True)
os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
for seqlen in tuning_sequences:
logger.info(f"Warming up TunableOp for seqlen={seqlen}")

View File

@@ -173,6 +173,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
@@ -194,7 +195,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
0.0,
softmax_scale,
False,
True,
causal,
False,
None,
)
@@ -210,6 +211,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
output, _ = triton_attention(
q,
@@ -220,7 +222,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
cu_seqlens,
max_s,
max_s,
True,
causal,
softmax_scale,
)
return output