mirror of https://github.com/huggingface/text-generation-inference.git

commit df0a453693
parent c9455730d7

    fixes on review
@@ -13,6 +13,7 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
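The only change in this hunk is the new import. For context, a minimal sketch of what it provides, assuming a standard huggingface_hub installation: HUGGINGFACE_HUB_CACHE is the resolved Hub cache directory (typically ~/.cache/huggingface/hub unless overridden through the environment), which the hunks below use as the storage location for TunableOp results.

# Sketch, not part of the diff: what the new import resolves to.
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

# Typically the expanded "~/.cache/huggingface/hub", unless the cache location
# is overridden through the environment (e.g. HF_HUB_CACHE or HF_HOME).
print(HUGGINGFACE_HUB_CACHE)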
@@ -836,10 +837,10 @@ class FlashCausalLM(Model):
                     for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                 ]
             else:
-                tuning_sequences = [1, 2, 4, 8, 16, 32]
+                tuning_sequences = CUDA_GRAPHS
 
             tunableop_filepath = os.path.join(
-                "/data",
+                HUGGINGFACE_HUB_CACHE,
                 f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
             )
 
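A quick sketch of the resulting path construction (model_id, world_size and rank are hypothetical here; in the server they come from the FlashCausalLM instance): the tuned-GEMM CSV now lands in the Hub cache instead of the hard-coded /data directory, one file per model and tensor-parallel rank, while the default tuning lengths fall back to the already-configured CUDA_GRAPHS list.

import os
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

# Hypothetical values for illustration only.
model_id = "mistralai/Mistral-7B-v0.1"
world_size, rank = 2, 0

# Same construction as in the hunk above: one CSV of tuned GEMMs per model and rank.
tunableop_filepath = os.path.join(
    HUGGINGFACE_HUB_CACHE,
    f"tunableop_{model_id.replace('/', '-')}_tp{world_size}_rank{rank}.csv",
)
print(tunableop_filepath)
# e.g. .../hub/tunableop_mistralai-Mistral-7B-v0.1_tp2_rank0.csv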
@@ -853,7 +854,7 @@ class FlashCausalLM(Model):
                 )
                 torch.cuda.tunable.read_file(tunableop_filepath)
 
-            os.makedirs("/data", exist_ok=True)
+            os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
 
             for seqlen in tuning_sequences:
                 logger.info(f"Warming up TunableOp for seqlen={seqlen}")
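Tying the two TunableOp hunks together, a minimal sketch of the flow after this change (prepare_tunableop is a hypothetical helper name, and the existence check on the CSV comes from the surrounding context that the hunk only partially shows): reuse a previously tuned CSV from the Hub cache if present, make sure the cache directory exists, then run one warmup pass per tuning sequence length.

import os
import torch
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

def prepare_tunableop(tunableop_filepath: str, tuning_sequences: list) -> None:
    # Reuse a previous tuning result if one exists for this model and rank.
    if os.path.isfile(tunableop_filepath):
        torch.cuda.tunable.read_file(tunableop_filepath)

    # The output directory is now the Hub cache, so make sure it exists
    # before TunableOp writes a new CSV into it.
    os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

    for seqlen in tuning_sequences:
        # The server runs its own warmup helper for each length; elided here.
        print(f"Warming up TunableOp for seqlen={seqlen}")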
@@ -173,6 +173,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=True,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -194,7 +195,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             False,
             None,
         )
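The two CK hunks add a causal keyword, defaulting to True so existing call sites keep their behaviour, and forward it to the kernel where a literal True used to be hard-coded. A generic sketch of the pattern with a hypothetical stand-in for the kernel entry point (the real call is to the CK varlen forward, whose full argument list is not shown in the hunk):

def ck_varlen_fwd(q, k, v, out, cu_seqlens, max_s, softmax_scale, causal):
    # Hypothetical stand-in for the ROCm CK varlen forward kernel.
    ...

def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale,
              window_size_left=-1, causal=True):
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")
    # Forward the caller's choice instead of a hard-coded True, so the same
    # wrapper can serve both causal and non-causal attention.
    return ck_varlen_fwd(q, k, v, out, cu_seqlens, max_s, softmax_scale, causal)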
@@ -210,6 +211,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=True,
     ):
         output, _ = triton_attention(
             q,
@@ -220,7 +222,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
             cu_seqlens,
             max_s,
             max_s,
-            True,
+            causal,
             softmax_scale,
         )
         return output
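The Triton path gets the same treatment: the new keyword defaults to True and replaces the literal True previously passed positionally to triton_attention. From a caller's perspective (the call shapes below are illustrative, not taken from the repository), the causal decode path is unchanged and non-causal attention becomes a one-argument opt-in.

def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale,
              window_size_left=-1, causal=True):
    ...  # dispatches to the CK or Triton kernel as in the hunks above

# Default: causal masking, identical to the pre-commit behaviour.
#   attention(q, k, v, out, cu_seqlens, max_s, softmax_scale)
# New: callers that need full (non-causal) attention can now request it.
#   attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, causal=False)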