Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

Commit df0a453693 (parent c9455730d7): fixes on review
@@ -13,6 +13,7 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
@@ -836,10 +837,10 @@ class FlashCausalLM(Model):
                     for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                 ]
             else:
-                tuning_sequences = [1, 2, 4, 8, 16, 32]
+                tuning_sequences = CUDA_GRAPHS
 
             tunableop_filepath = os.path.join(
-                "/data",
+                HUGGINGFACE_HUB_CACHE,
                 f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
             )
 
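Note on the hunk above: a minimal, standalone sketch of the tuning-sequence selection after this change. The CUDA_GRAPHS value below is a placeholder (in TGI it comes from the server's CUDA-graph configuration), and the head of the list comprehension is assumed from context; neither is taken from the commit.

import os

# Placeholder: in TGI, CUDA_GRAPHS is the list of batch sizes configured for
# CUDA graph capture; the value here is illustrative only.
CUDA_GRAPHS = [1, 2, 4, 8, 16, 32]

# An explicit PYTORCH_TUNABLEOP_SEQLENS override still takes precedence;
# otherwise TunableOp now tunes the same shapes that CUDA graphs will serve,
# instead of a hard-coded list.
if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
    tuning_sequences = [
        int(val)
        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
    ]
else:
    tuning_sequences = CUDA_GRAPHS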
@@ -853,7 +854,7 @@ class FlashCausalLM(Model):
                 )
                 torch.cuda.tunable.read_file(tunableop_filepath)
 
-            os.makedirs("/data", exist_ok=True)
+            os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
 
             for seqlen in tuning_sequences:
                 logger.info(f"Warming up TunableOp for seqlen={seqlen}")
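For reference, a minimal sketch (not part of the commit) of where the TunableOp results file lands once the hard-coded "/data" is replaced by HUGGINGFACE_HUB_CACHE. The model_id, world_size, and rank values are made-up examples standing in for self.model_id, self.world_size, and self.rank.

import os

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

# Example values only; in FlashCausalLM these come from the model instance.
model_id = "org/some-model"
world_size, rank = 1, 0

# Same filename scheme as the diff: one CSV of tuned GEMMs per model,
# tensor-parallel degree, and rank, stored under the Hugging Face hub cache
# rather than a fixed /data mount.
tunableop_filepath = os.path.join(
    HUGGINGFACE_HUB_CACHE,
    f"tunableop_{model_id.replace('/', '-')}_tp{world_size}_rank{rank}.csv",
)

os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
print(tunableop_filepath)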
@@ -173,6 +173,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=True,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -194,7 +195,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             False,
             None,
         )
@@ -210,6 +211,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=True,
     ):
         output, _ = triton_attention(
             q,
@@ -220,7 +222,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
             cu_seqlens,
             max_s,
             max_s,
-            True,
+            causal,
             softmax_scale,
         )
         return output
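To summarize the attention hunks: both ROCm wrappers (CK and Triton) gain a causal=True keyword and forward it to the underlying kernel instead of hard-coding True. Below is a rough sketch of the Triton branch after the change. triton_attention is a stand-in stub for the Triton flash-attention kernel the real module imports, and the argument names ahead of the cu_seqlens/max_s lines shown in the hunks are assumed from context.

def triton_attention(q, k, v, out, cu_seqlens_q, cu_seqlens_k,
                     max_seqlen_q, max_seqlen_k, causal, softmax_scale):
    # Stand-in for the Triton flash-attention kernel used by TGI on ROCm;
    # the real kernel writes the attention result into `out`.
    return out, None


def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
    causal=True,  # new keyword: callers may now request non-causal attention
):
    output, _ = triton_attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        cu_seqlens,
        max_s,
        max_s,
        causal,  # previously hard-coded to True
        softmax_scale,
    )
    return output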