Mirror of https://github.com/huggingface/text-generation-inference.git
(synced 2025-04-27 13:02:12 +00:00)

Commit a509360619 (parent 17f5c3078b): trying to update to ROCm 6.1

Dockerfile_amd (172)
@@ -36,7 +36,7 @@ COPY launcher launcher
 RUN cargo build --release

 # Text Generation Inference base image for RoCm
-FROM rocm/dev-ubuntu-22.04:6.0.2 as base
+FROM rocm/dev-ubuntu-22.04:6.1 as base

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
@@ -86,109 +86,125 @@ RUN chmod +x ~/mambaforge.sh && \
     mamba init && \
     rm ~/mambaforge.sh

-RUN pip install torch numpy --index-url https://download.pytorch.org/whl/rocm6.0
+# Install flash-attention, torch dependencies
+RUN pip install numpy einops ninja --no-cache-dir

-FROM base AS kernel-builder
+RUN conda install intel::mkl-static intel::mkl-include
+RUN pip uninstall -y triton && \
+    git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
+    cd triton/python && \
+    pip install .

-# Build Triton
-FROM kernel-builder as triton-builder
-WORKDIR /usr/src
+# RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && git checkout d05863883b7b61eb5875abcb6cb6b32fa678beeb && pip install -r requirements.txt --no-cache-dir
+RUN git clone --depth 1 --recursive --single-branch --branch release/2.3 https://github.com/pytorch/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir

-COPY server/Makefile-triton Makefile
-RUN make build-triton-rocm
+ARG _GLIBCXX_USE_CXX11_ABI="1"
+ARG CMAKE_PREFIX_PATH="/opt/conda"
+ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+ARG BUILD_CAFFE2="0" \
+    BUILD_CAFFE2_OPS="0" \
+    USE_CUDA="0" \
+    USE_ROCM="1" \
+    BUILD_TEST="0" \
+    USE_FBGEMM="0" \
+    USE_NNPACK="0" \
+    USE_QNNPACK="0" \
+    USE_XNNPACK="0" \
+    USE_FLASH_ATTENTION="0" \
+    USE_MEM_EFF_ATTENTION="0"

-# Build vllm kernels
-FROM kernel-builder AS vllm-builder
-WORKDIR /usr/src
+# RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install

-COPY server/Makefile-vllm Makefile
+# FROM base AS kernel-builder

-# Build specific version of vllm
-RUN make build-vllm-rocm
+# # Build vllm kernels
+# FROM kernel-builder AS vllm-builder
+# WORKDIR /usr/src

-# Build Flash Attention v2 kernels
-FROM kernel-builder AS flash-att-v2-builder
-WORKDIR /usr/src
+# COPY server/Makefile-vllm Makefile

-COPY server/Makefile-flash-att-v2 Makefile
+# # Build specific version of vllm
+# RUN make build-vllm-rocm

-# Build specific version of flash attention v2
-RUN make build-flash-attention-v2-rocm
+# # Build Flash Attention v2 kernels
+# FROM kernel-builder AS flash-att-v2-builder
+# WORKDIR /usr/src

-# Build Transformers CUDA kernels (gpt-neox and bloom)
-FROM kernel-builder as custom-kernels-builder
-WORKDIR /usr/src
-COPY server/custom_kernels/ .
-RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
+# COPY server/Makefile-flash-att-v2 Makefile

-# Build exllama kernels
-FROM kernel-builder as exllama-kernels-builder
-WORKDIR /usr/src
-COPY server/exllama_kernels/ .
+# # Build specific version of flash attention v2
+# RUN make build-flash-attention-v2-rocm

-RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
+# # Build Transformers CUDA kernels (gpt-neox and bloom)
+# FROM kernel-builder as custom-kernels-builder
+# WORKDIR /usr/src
+# COPY server/custom_kernels/ .
+# RUN python setup.py build

-# Build exllama v2 kernels
-FROM kernel-builder as exllamav2-kernels-builder
-WORKDIR /usr/src
-COPY server/exllamav2_kernels/ .
+# # Build exllama kernels
+# FROM kernel-builder as exllama-kernels-builder
+# WORKDIR /usr/src
+# COPY server/exllama_kernels/ .

-RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
+# RUN python setup.py build

-FROM base as base-copy
+# # Build exllama v2 kernels
+# FROM kernel-builder as exllamav2-kernels-builder
+# WORKDIR /usr/src
+# COPY server/exllamav2_kernels/ .

-# Text Generation Inference base env
-ENV HUGGINGFACE_HUB_CACHE=/data \
-    HF_HUB_ENABLE_HF_TRANSFER=1 \
-    PORT=80 \
-    HIP_FORCE_DEV_KERNARG=1
+# RUN python setup.py build

-# Copy builds artifacts from triton builder
-COPY --from=triton-builder /usr/src/triton/python/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# FROM base as base-copy

-# Copy builds artifacts from vllm builder
-COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# # Text Generation Inference base env
+# ENV HUGGINGFACE_HUB_CACHE=/data \
+# HF_HUB_ENABLE_HF_TRANSFER=1 \
+# PORT=80 \
+# HIP_FORCE_DEV_KERNARG=1

-# Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# # Copy builds artifacts from vllm builder
+# COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

-# Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# # Copy build artifacts from flash attention v2 builder
+# COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

-# Copy build artifacts from exllama kernels builder
-COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# # Copy build artifacts from custom kernels builder
+# COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

-# Copy build artifacts from exllamav2 kernels builder
-COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# # Copy build artifacts from exllama kernels builder
+# COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

-# Install flash-attention dependencies
-RUN pip install einops --no-cache-dir
+# # Copy build artifacts from exllamav2 kernels builder
+# COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

-# Install server
-COPY proto proto
-COPY server server
-COPY server/Makefile server/Makefile
-RUN cd server && \
-    make gen-server && \
-    pip install -r requirements_rocm.txt && \
-    pip install ".[accelerate, peft, outlines]" --no-cache-dir
+# # Install server
+# COPY proto proto
+# COPY server server
+# COPY server/Makefile server/Makefile
+# # pip install -r requirements_rocm.txt && \
+# #pip install ".[accelerate, peft, outlines]" --no-cache-dir

-# Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
-# Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
-# Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+# # Install benchmarker
+# COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# # Install router
+# COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+# # Install launcher
+# COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

-# AWS Sagemaker compatible image
-FROM base-copy as sagemaker
-COPY sagemaker-entrypoint.sh entrypoint.sh
-RUN chmod +x entrypoint.sh
+# RUN cd server && \
+# make gen-server && \
+# pip install -r requirements_rocm.txt

-ENTRYPOINT ["./entrypoint.sh"]
+# # AWS Sagemaker compatible image
+# FROM base-copy as sagemaker
+# COPY sagemaker-entrypoint.sh entrypoint.sh
+# RUN chmod +x entrypoint.sh

-# Final image
-FROM base-copy
+# ENTRYPOINT ["./entrypoint.sh"]

-ENTRYPOINT ["text-generation-launcher"]
-CMD ["--json-output"]
+# # Final image
+# FROM base-copy
+
+# # ENTRYPOINT ["text-generation-launcher"]
+# # CMD ["--json-output"]
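Not part of the diff: in this work-in-progress state the base stage clones PyTorch release/2.3 and sets the ROCm build arguments (PYTORCH_ROCM_ARCH gfx90a/gfx942, USE_ROCM=1, and so on) while the actual build/install step and the old kernel-builder stages are commented out, and ROCm's Triton fork is installed from source via pip. Once that stack is actually built, a quick in-container check can confirm that torch is a ROCm build targeting the intended architectures. A minimal sketch, a hypothetical helper not shipped with TGI:

import torch
import triton

print("torch:", torch.__version__)        # should report a ROCm build
print("HIP runtime:", torch.version.hip)  # None would mean a CPU/CUDA wheel slipped in
assert torch.version.hip is not None, "torch was not built against ROCm"

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    # On ROCm builds, gcnArchName should match a PYTORCH_ROCM_ARCH target (gfx90a / gfx942).
    print("device:", props.name, getattr(props, "gcnArchName", "unknown arch"))

print("triton:", triton.__version__)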
@@ -1,8 +0,0 @@
-triton-rocm:
-	pip uninstall -y triton
-	pip install -U ninja cmake wheel packaging --no-cache-dir
-	git clone https://github.com/ROCm/triton.git triton
-
-build-triton-rocm: triton-rocm
-	cd triton && git fetch && git checkout b9e5290de8bf3a79c4e91ceed7e61b3c8d041b30
-	cd triton/python && python setup.py build
@@ -18,7 +18,6 @@ vllm-rocm:

 build-vllm-rocm: vllm-rocm
 	cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
-	cd vllm && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch
 	cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install

 install-vllm-rocm: build-vllm-rocm
@@ -768,7 +768,8 @@ class FlashCausalLM(Model):
             max_s = max_bt * get_cache_manager().block_size

             if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
-                logger.info("PyTorch TunableOp (https://github.com/pytorch/pytorch/tree/v2.3.0/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
+                torch.cuda.tunable.tuning_enable(False)
+
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
@@ -824,10 +825,16 @@ class FlashCausalLM(Model):
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

         if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+            if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
+                torch.cuda.tunable.tuning_enable(True)
+
+            logger.info("PyTorch TunableOp (https://github.com/pytorch/pytorch/tree/v2.3.0/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
             total_seqlens = list(range(2))
             for seqlen in total_seqlens:
                 logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen, max_s, max_bt)
+            torch.cuda.tunable.write_file()
+            torch.cuda.tunable.tuning_enable(False)

         return int(num_blocks * BLOCK_SIZE)

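Not TGI code, only a sketch of the intent behind the two hunks above: tuning is disabled for the initial memory-probing pass, re-enabled around the dedicated warmup shapes (unless PYTORCH_TUNABLEOP_TUNING turns it off), and then the tuned GEMM solutions are written out and tuning is switched off again so steady-state requests reuse the cached picks. The sketch assumes PyTorch >= 2.3 with the torch.cuda.tunable API on a ROCm device; tunableop_warmup_step is a hypothetical stand-in for TGI's tunableop_warmup.

import os
import torch

def tunableop_warmup_step(seqlen: int) -> None:
    # Hypothetical stand-in: run a GEMM shaped like a decode step so TunableOp
    # can benchmark and pick a solution for that shape.
    a = torch.randn(seqlen, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    _ = a @ b

if os.environ.get("PYTORCH_TUNABLEOP_ENABLED", "0") == "1":
    if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1") == "1":
        torch.cuda.tunable.tuning_enable(True)
    for seqlen in range(1, 3):
        tunableop_warmup_step(seqlen)
    torch.cuda.tunable.write_file()        # persist results (PYTORCH_TUNABLEOP_FILENAME)
    torch.cuda.tunable.tuning_enable(False)

Note that the sketch compares the environment variables to "1" explicitly, whereas the diff tests the raw strings for truthiness, so any non-empty value (including "0") enters those branches.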
@@ -211,12 +211,13 @@ def attention(
         # NOTE: The Triton kernel silently outputs wrong results when using MQA/GQA and not
         # repeating.
         # TODO: just a sketch. Kind of need to abstract this `attention` function to enable some customization and pass those - let's sync with Nicolas for which implem he'd like
-        num_heads = q.shape[1]
-        num_kv_heads = k.shape[1]
-        if num_kv_heads != num_heads:
-            # Interleave for MQA workaround.
-            k = repeat_kv(k, num_heads // num_kv_heads)
-            v = repeat_kv(v, num_heads // num_kv_heads)
+
+        # num_heads = q.shape[1]
+        # num_kv_heads = k.shape[1]
+        # if num_kv_heads != num_heads:
+        #     # Interleave for MQA workaround.
+        #     k = repeat_kv(k, num_heads // num_kv_heads)
+        #     v = repeat_kv(v, num_heads // num_kv_heads)

         output, _ = triton_attention(
             q,
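For context, not part of the diff: the commented-out workaround expanded K and V to one head per query head so the Triton kernel could treat MQA/GQA as plain multi-head attention; the HQ/HK changes to attn_fwd below make that copy unnecessary. A rough sketch of what a repeat_kv-style expansion does, assuming a [tokens, heads, head_dim] layout (the exact semantics of TGI's repeat_kv may differ):

import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # [num_tokens, num_kv_heads, head_dim] -> [num_tokens, num_kv_heads * n_rep, head_dim]
    return torch.repeat_interleave(x, n_rep, dim=1)

k = torch.randn(16, 2, 64)           # 2 KV heads
print(repeat_kv_sketch(k, 4).shape)  # torch.Size([16, 8, 64]), matching 8 query heads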
@@ -293,7 +293,7 @@ def _attn_fwd_inner(
             num_warps=4,
         ),
     ],
-    key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
+    key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
 )
 @triton.jit
 def attn_fwd(
@@ -330,8 +330,8 @@ def attn_fwd(
     philox_seed,
     philox_offset_base,
     encoded_softmax,
-    hq,
-    hk,
+    HQ: tl.constexpr,
+    HK:tl.constexpr,
     ACTUAL_BLOCK_DMODEL: tl.constexpr,
     MAX_SEQLENS_Q: tl.constexpr,
     MAX_SEQLENS_K: tl.constexpr,
@@ -414,14 +414,19 @@ def attn_fwd(
         # TODO: Should dropout and return encoded softmax be handled here?
         return

-    is_mqa = hq != hk
-    off_h_k = off_h_q % hk if is_mqa else off_h_q
+    # If MQA / GQA, set the K and V head offsets appropriately.
+    GROUP_SIZE: tl.constexpr = HQ // HK
+    if GROUP_SIZE != 1:
+        off_h_k = off_h_q // GROUP_SIZE
+    else:
+        off_h_k = off_h_q
+
     n_extra_tokens = 0
     if seqlen_k < BLOCK_N:
         n_extra_tokens = BLOCK_N - seqlen_k
     elif seqlen_k % BLOCK_N:
         n_extra_tokens = seqlen_k % BLOCK_N
-    padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
+    PADDED_HEAD:tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)

     # Compute pointers for all the tensors used in this kernel.
     q_offset = (off_z * stride_qz + off_h_q * stride_qh +
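A small worked example (not from the diff) of the new head mapping: with HQ query heads and HK key/value heads, the kernel now derives GROUP_SIZE = HQ // HK and maps each block of GROUP_SIZE consecutive query heads to one KV head, instead of the old modulo mapping, which interleaved them:

def kv_head_new(off_h_q: int, hq: int, hk: int) -> int:
    group_size = hq // hk
    return off_h_q // group_size if group_size != 1 else off_h_q

def kv_head_old(off_h_q: int, hq: int, hk: int) -> int:
    return off_h_q % hk if hq != hk else off_h_q

# HQ=8 query heads sharing HK=2 KV heads (GROUP_SIZE=4):
print([kv_head_new(q, 8, 2) for q in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
print([kv_head_old(q, 8, 2) for q in range(8)])  # [0, 1, 0, 1, 0, 1, 0, 1]

The grouped mapping matches the usual GQA convention, where each KV head serves a contiguous block of query heads, which is what lets the repeat_kv workaround earlier in this diff stay commented out.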
@@ -467,7 +472,7 @@ def attn_fwd(
         bias_ptr = None
     if ENABLE_DROPOUT:
         batch_philox_offset = philox_offset_base \
-            + (off_z * hq + off_h_q) \
+            + (off_z * HQ + off_h_q) \
             * seqlen_q * seqlen_k
     else:
         batch_philox_offset = 0
@@ -494,7 +499,7 @@ def attn_fwd(
     # have native e^x support in HW.
     qk_scale = sm_scale * 1.44269504089
     # Q is loaded once at the beginning and shared by all N blocks.
-    q = load_fn(Q_block_ptr, True, padded_head, "zero")
+    q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero")
     q = (q * qk_scale).to(Q_block_ptr.type.element_ty)

     # Here we compute how many full and masked blocks we have.
@@ -549,7 +554,7 @@ def attn_fwd(
             False,
             ENABLE_DROPOUT,
             RETURN_ENCODED_SOFTMAX,
-            padded_head,
+            PADDED_HEAD,
         )
         block_min = block_max
         block_max = n_blocks * BLOCK_N
@@ -595,7 +600,7 @@ def attn_fwd(
             True,
             ENABLE_DROPOUT,
             RETURN_ENCODED_SOFTMAX,
-            padded_head,
+            PADDED_HEAD,
         )
     # epilogue
     acc = acc / l_i[:, None]
@@ -729,16 +734,8 @@ class _attention(torch.autograd.Function):
         o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))

         # Get closest power of 2 over or equal to 32.
-        unpadded_head_dims = {32, 64, 128}
-        if head_size not in unpadded_head_dims:
-            padded_d_model = None
-            for i in unpadded_head_dims:
-                if i > head_size:
-                    padded_d_model = i
-                    break
-            assert padded_d_model is not None
-        else:
-            padded_d_model = head_size
+        padded_d_model = 1 << (head_size - 1).bit_length()
+        padded_d_model = max(padded_d_model, 16)

         grid = lambda META: (
             triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
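A quick check (not part of the diff) of the new padding rule: rather than looking head_size up in {32, 64, 128}, it rounds up to the next power of two and floors the result at 16, which also handles head sizes above 128 that would previously have tripped the assert:

def padded_d_model(head_size: int) -> int:
    padded = 1 << (head_size - 1).bit_length()  # next power of two >= head_size
    return max(padded, 16)

for head_size in (8, 16, 24, 64, 72, 96, 128, 256):
    print(head_size, "->", padded_d_model(head_size))
# 8 -> 16, 16 -> 16, 24 -> 32, 64 -> 64, 72 -> 128, 96 -> 128, 128 -> 128, 256 -> 256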
@@ -781,8 +778,8 @@ class _attention(torch.autograd.Function):
             philox_seed=philox_seed,
             philox_offset_base=philox_offset,
             encoded_softmax=encoded_softmax,
-            hq=nheads_q,
-            hk=nheads_k,
+            HQ=nheads_q,
+            HK=nheads_k,
             ACTUAL_BLOCK_DMODEL=head_size,
             MAX_SEQLENS_Q=max_seqlens_q,
             MAX_SEQLENS_K=max_seqlens_k,