mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00

commit f8d37c14d9 ("apply suggestions"), parent b7e98ba635
@@ -117,13 +117,15 @@ RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install

 ARG GITHUB_TOKEN

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends wget && \
     rm -rf /var/lib/apt/lists/* && \
-    wget --header "Authorization: token ${GITHUB_TOKEN}" https://raw.githubusercontent.com/fxmarty/patched_hipruntime/main/libamdhip64.so.6.2.41130
+    wget --header "Authorization: token ${GITHUB_TOKEN}" https://raw.githubusercontent.com/fxmarty/patched_hipruntime/main/libamdhip64.so.6

-ENV LD_PRELOAD="/libamdhip64.so.6.2.41130"
+ENV LD_PRELOAD="/libamdhip64.so.6"

 # Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
-# Disabled for now as it is currently not stable with ROCm 6.1.
-# ENV HIP_FORCE_DEV_KERNARG=1
+ENV HIP_FORCE_DEV_KERNARG=1

 # On MI300, performance of flash attention with Triton FA is very competitive (actually better than CK)
 ENV ROCM_USE_FLASH_ATTN_V2_TRITON=1

 FROM base AS kernel-builder
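Note on the hunk above: LD_PRELOAD makes the patched HIP runtime win symbol resolution over the copy shipped in the ROCm image. A minimal sketch for verifying that the override actually took effect, assuming a Linux container laid out as in this Dockerfile (this script is illustrative, not part of the commit):

# check_preload.py: confirm which libamdhip64 the process resolved.
import ctypes

# Force the HIP runtime to load; a preloaded copy takes precedence.
ctypes.CDLL("libamdhip64.so.6", mode=ctypes.RTLD_GLOBAL)

# /proc/self/maps lists every shared object mapped into this process,
# so the path printed shows which libamdhip64 was actually picked up.
with open("/proc/self/maps") as maps:
    print({line.split()[-1] for line in maps if "libamdhip64" in line})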
@@ -813,14 +813,15 @@ class FlashCausalLM(Model):
             self.device,
         )

-        if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+        if SYSTEM == "rocm":
+            if os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+                if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
+                    torch.cuda.tunable.tuning_enable(True)

                 if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
                     tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
                 else:
-                    tuning_sequences = list(range(1, 8))
+                    tuning_sequences = [1, 2, 4, 8, 16, 32]

                 tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
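One caveat about the gating above: os.environ.get(name, False) returns the raw string whenever the variable is set, so PYTORCH_TUNABLEOP_TUNING=0 still takes the truthy branch. A stricter parse could look like the sketch below; env_flag is a hypothetical helper, not part of this commit or of text-generation-inference:

import os

import torch


def env_flag(name: str, default: bool) -> bool:
    # Treat only explicit truthy strings as True, unlike
    # os.environ.get(name, False), which is truthy even for "0".
    val = os.environ.get(name)
    if val is None:
        return default
    return val.strip().lower() in ("1", "true", "yes", "on")


if env_flag("PYTORCH_TUNABLEOP_ENABLED", False):
    if env_flag("PYTORCH_TUNABLEOP_TUNING", True):
        # torch.cuda.tunable is available in recent PyTorch builds.
        torch.cuda.tunable.tuning_enable(True)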
@@ -837,6 +838,8 @@ class FlashCausalLM(Model):
                     self.tunableop_warmup(seqlen)
                     torch.cuda.tunable.write_file(tunableop_filepath)
+                    torch.cuda.tunable.tuning_enable(False)
+            else:
+                logger.info("PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.")

         if CUDA_GRAPHS:
             try:
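The hunk above completes the TunableOp life cycle: tune during warmup, persist the results, then freeze tuning before serving. A self-contained sketch of that cycle, assuming a recent PyTorch with torch.cuda.tunable and a CUDA/ROCm device; warmup_gemm here is a stand-in for the model forward pass that self.tunableop_warmup runs (its body is not shown in this diff):

import torch


def warmup_gemm(seqlen: int, hidden: int = 4096) -> None:
    # Any GEMM executed while tuning is enabled gets benchmarked and cached.
    a = torch.randn(seqlen, hidden, device="cuda", dtype=torch.float16)
    b = torch.randn(hidden, hidden, device="cuda", dtype=torch.float16)
    torch.matmul(a, b)


torch.cuda.tunable.enable(True)
torch.cuda.tunable.tuning_enable(True)
for seqlen in [1, 2, 4, 8, 16, 32]:
    warmup_gemm(seqlen)
torch.cuda.tunable.write_file("tunableop_results.csv")  # persist tuned kernels
torch.cuda.tunable.tuning_enable(False)  # serve with tuning frozen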
@@ -292,6 +292,16 @@ def _attn_fwd_inner(
             num_stages=1,
             num_warps=4,
         ),
+        triton.Config(
+            {
+                "BLOCK_M": 128,
+                "BLOCK_N": 64,
+                "waves_per_eu": 1,
+                "PRE_LOAD_V": False,
+            },
+            num_stages=1,
+            num_warps=4,
+        ),
     ],
     key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
 )
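For context on the hunk above: triton.autotune benchmarks every triton.Config the first time it sees a new combination of the key arguments, then caches the winner; knobs such as waves_per_eu are compile hints understood by the ROCm Triton backend rather than kernel parameters. A toy, self-contained example of the same mechanism (the kernel and all names below are illustrative, not from the source):

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        # Each Config is compiled and timed; the fastest is cached per key.
        triton.Config({"BLOCK": 64}, num_stages=1, num_warps=4),
        triton.Config({"BLOCK": 128}, num_stages=1, num_warps=4),
    ],
    key=["n_elements"],  # re-tune whenever the problem size changes
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)


x = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
# The grid depends on the config the autotuner picks, so derive it from meta.
scale_kernel[lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)](x, out, x.numel())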