wip fix tunableop

fxmarty 2024-05-02 08:15:52 +00:00
parent a509360619
commit 2677bf856a
3 changed files with 124 additions and 160 deletions

Changed file 1 of 3: the ROCm Dockerfile

@@ -95,8 +95,8 @@ RUN pip uninstall -y triton && \
     cd triton/python && \
     pip install .
-# RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && git checkout d05863883b7b61eb5875abcb6cb6b32fa678beeb && pip install -r requirements.txt --no-cache-dir
-RUN git clone --depth 1 --recursive --single-branch --branch release/2.3 https://github.com/pytorch/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
+RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
+# RUN git clone --depth 1 --recursive --single-branch --branch release/2.3 https://github.com/pytorch/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
 ARG _GLIBCXX_USE_CXX11_ABI="1"
 ARG CMAKE_PREFIX_PATH="/opt/conda"
@@ -113,98 +113,102 @@ ARG BUILD_CAFFE2="0" \
     USE_FLASH_ATTENTION="0" \
     USE_MEM_EFF_ATTENTION="0"
-# RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
-# FROM base AS kernel-builder
-# # Build vllm kernels
-# FROM kernel-builder AS vllm-builder
-# WORKDIR /usr/src
-# COPY server/Makefile-vllm Makefile
-# # Build specific version of vllm
-# RUN make build-vllm-rocm
-# # Build Flash Attention v2 kernels
-# FROM kernel-builder AS flash-att-v2-builder
-# WORKDIR /usr/src
-# COPY server/Makefile-flash-att-v2 Makefile
-# # Build specific version of flash attention v2
-# RUN make build-flash-attention-v2-rocm
-# # Build Transformers CUDA kernels (gpt-neox and bloom)
-# FROM kernel-builder as custom-kernels-builder
-# WORKDIR /usr/src
-# COPY server/custom_kernels/ .
-# RUN python setup.py build
-# # Build exllama kernels
-# FROM kernel-builder as exllama-kernels-builder
-# WORKDIR /usr/src
-# COPY server/exllama_kernels/ .
-# RUN python setup.py build
-# # Build exllama v2 kernels
-# FROM kernel-builder as exllamav2-kernels-builder
-# WORKDIR /usr/src
-# COPY server/exllamav2_kernels/ .
-# RUN python setup.py build
-# FROM base as base-copy
-# # Text Generation Inference base env
-# ENV HUGGINGFACE_HUB_CACHE=/data \
-#     HF_HUB_ENABLE_HF_TRANSFER=1 \
-#     PORT=80 \
-#     HIP_FORCE_DEV_KERNARG=1
-# # Copy builds artifacts from vllm builder
-# COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-# # Copy build artifacts from flash attention v2 builder
-# COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-# # Copy build artifacts from custom kernels builder
-# COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-# # Copy build artifacts from exllama kernels builder
-# COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-# # Copy build artifacts from exllamav2 kernels builder
-# COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
-# # Install server
-# COPY proto proto
-# COPY server server
-# COPY server/Makefile server/Makefile
-# # pip install -r requirements_rocm.txt && \
-# #pip install ".[accelerate, peft, outlines]" --no-cache-dir
-# # Install benchmarker
-# COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
-# # Install router
-# COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
-# # Install launcher
-# COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
-# RUN cd server && \
-#     make gen-server && \
-#     pip install -r requirements_rocm.txt
-# # AWS Sagemaker compatible image
-# FROM base-copy as sagemaker
-# COPY sagemaker-entrypoint.sh entrypoint.sh
-# RUN chmod +x entrypoint.sh
-# ENTRYPOINT ["./entrypoint.sh"]
-# # Final image
-# FROM base-copy
-# # ENTRYPOINT ["text-generation-launcher"]
-# # CMD ["--json-output"]
-# NOTE: Temporarily, for TGI, please mount a volume and install locally the server with `cd /tgi/server && pip install ".[accelerate, peft, outlines]" --no-cache-dir`
+RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
+# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
+ENV HIP_FORCE_DEV_KERNARG=1
+FROM base AS kernel-builder
+# Build vllm kernels
+FROM kernel-builder AS vllm-builder
+WORKDIR /usr/src
+COPY server/Makefile-vllm Makefile
+# Build specific version of vllm
+RUN make build-vllm-rocm
+# Build Flash Attention v2 kernels
+FROM kernel-builder AS flash-att-v2-builder
+WORKDIR /usr/src
+COPY server/Makefile-flash-att-v2 Makefile
+# Build specific version of flash attention v2
+RUN make build-flash-attention-v2-rocm
+# Build Transformers CUDA kernels (gpt-neox and bloom)
+FROM kernel-builder as custom-kernels-builder
+WORKDIR /usr/src
+COPY server/custom_kernels/ .
+RUN python setup.py build
+# Build exllama kernels
+FROM kernel-builder as exllama-kernels-builder
+WORKDIR /usr/src
+COPY server/exllama_kernels/ .
+RUN python setup.py build
+# Build exllama v2 kernels
+FROM kernel-builder as exllamav2-kernels-builder
+WORKDIR /usr/src
+COPY server/exllamav2_kernels/ .
+RUN python setup.py build
+FROM base as base-copy
+# Text Generation Inference base env
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    HF_HUB_ENABLE_HF_TRANSFER=1 \
+    PORT=80
+# Copy builds artifacts from vllm builder
+COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from flash attention v2 builder
+COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from custom kernels builder
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from exllama kernels builder
+COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from exllamav2 kernels builder
+COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Install server
+COPY proto proto
+COPY server server
+COPY server/Makefile server/Makefile
+# pip install -r requirements_rocm.txt && \
+#pip install ".[accelerate, peft, outlines]" --no-cache-dir
+# Install benchmarker
+COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# Install router
+COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+RUN cd server && \
+    make gen-server && \
+    pip install -r requirements_rocm.txt
+# AWS Sagemaker compatible image
+FROM base-copy as sagemaker
+COPY sagemaker-entrypoint.sh entrypoint.sh
+RUN chmod +x entrypoint.sh
+ENTRYPOINT ["./entrypoint.sh"]
+# Final image
+FROM base-copy
+# ENTRYPOINT ["text-generation-launcher"]
+# CMD ["--json-output"]
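Note on the Dockerfile change: the image now builds the patched PyTorch 2.3 branch (fxmarty/pytorch, 2.3-patched), re-enables the previously commented-out kernel-builder and server stages, and sets HIP_FORCE_DEV_KERNARG=1 at the base layer instead of in the runtime ENV block. The Python snippet below is only an illustrative sketch for checking that a container built from this file exposes the environment TunableOp relies on; the variable names come from this diff, and torch.version.hip is the standard way to detect a ROCm build of PyTorch.

    import os

    import torch

    # Illustrative sanity check, not part of TGI: confirm the environment the Dockerfile sets up.
    def check_rocm_tunableop_env() -> None:
        # Set in the Dockerfile, as recommended by the ROCm/triton wiki page linked above.
        if os.environ.get("HIP_FORCE_DEV_KERNARG") != "1":
            print("warning: HIP_FORCE_DEV_KERNARG is not set to 1")
        # flash_causal_lm.py gates the TunableOp warmup on this variable.
        print("PYTORCH_TUNABLEOP_ENABLED:", os.environ.get("PYTORCH_TUNABLEOP_ENABLED", "<unset>"))
        # torch.version.hip is a string on ROCm builds of PyTorch and None on CUDA builds.
        print("ROCm PyTorch build:", torch.version.hip is not None)

    if __name__ == "__main__":
        check_rocm_tunableop_env()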

Changed file 2 of 3: the FlashCausalLM model implementation

@@ -770,7 +770,9 @@ class FlashCausalLM(Model):
             if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
+            logger.info("calling self.generate_token(batch)")
             _, batch, _ = self.generate_token(batch)
+            logger.info("end it")
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
@@ -824,18 +826,20 @@ class FlashCausalLM(Model):
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
-        if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
-            if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
-                torch.cuda.tunable.tuning_enable(True)
-            logger.info("PyTorch TunableOp (https://github.com/pytorch/pytorch/tree/v2.3.0/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
-            total_seqlens = list(range(2))
-            for seqlen in total_seqlens:
-                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
-                self.tunableop_warmup(seqlen, max_s, max_bt)
+        # if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+        # if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
+        # torch.cuda.tunable.tuning_enable(True)
+        # logger.info("enable tuning here")
+        logger.info("PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
+        for seqlen in range(1, 3):
+            logger.info(f"Warming up TunableOp for seqlen={seqlen}")
+            self.tunableop_warmup(seqlen, max_s, max_bt)
+        logger.info("call write file")
         torch.cuda.tunable.write_file()
         torch.cuda.tunable.tuning_enable(False)
+        logger.info("finished tunable op")
         return int(num_blocks * BLOCK_SIZE)
     def tunableop_warmup(self, seqlen: int, max_s: int, max_bt: int):
@@ -843,10 +847,10 @@ class FlashCausalLM(Model):
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
+        # TODO: is this correct?
         input_lengths = (
             torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
         )
+        bs = 1
         block_tables = (
             torch.arange(max_bt, dtype=torch.int32, device=self.device)
             .repeat(bs)
@@ -854,6 +858,7 @@ class FlashCausalLM(Model):
         )
         kv_cache = get_cache_manager().kv_cache
+        logger.info("call self.model.forward")
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
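The warmup logic above follows PyTorch TunableOp's intended life cycle: enable tuning, run representative GEMM shapes once, persist the tuned solutions with write_file(), then turn tuning off so later calls only look up the cached results. A minimal standalone sketch of that cycle is shown below; it uses the public torch.cuda.tunable API from PyTorch 2.3, and the matrix shapes and result file name are illustrative assumptions rather than TGI's actual values.

    import torch

    # Minimal sketch of the TunableOp tune-once-then-reuse cycle (PyTorch 2.3+, ROCm or CUDA).
    torch.cuda.tunable.enable(True)                           # turn TunableOp on for this process
    torch.cuda.tunable.tuning_enable(True)                    # allow searching for new GEMM solutions
    torch.cuda.tunable.set_filename("tunableop_results.csv")  # illustrative output path

    for seqlen in (1, 2):  # the diff above warms up a couple of sequence lengths
        a = torch.randn(seqlen, 4096, device="cuda", dtype=torch.float16)
        b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
        torch.matmul(a, b)  # the first time a shape is seen, TunableOp benchmarks candidate kernels

    torch.cuda.tunable.write_file()          # persist the tuned solutions to the CSV above
    torch.cuda.tunable.tuning_enable(False)  # later GEMMs only read the cached solutions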

Changed file 3 of 3: the vllm paged attention wrapper

@@ -6,6 +6,7 @@ _PARTITION_SIZE = 512
 try:
     from vllm._C import cache_ops
+    from vllm._C import ops
 except Exception as e:
     raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
@@ -61,43 +62,21 @@ def attention(
     # to parallelize.
     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
-        if IS_CUDA_SYSTEM:
-            from vllm._C import ops
-            ops.paged_attention_v1(
-                out,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-                "auto",
-                1.0,
-            )
-        elif IS_ROCM_SYSTEM:
-            from vllm import attention_ops
-            attention_ops.paged_attention_v1(
-                out,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-            )
-        else:
-            raise ValueError("vllm is not supported on your system")
+        ops.paged_attention_v1(
+            out,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+            "auto",
+            1.0,
+        )
     else:
         # Run PagedAttention V2.
         assert _PARTITION_SIZE % block_size == 0
@@ -113,45 +92,21 @@ def attention(
         )
         max_logits = torch.empty_like(exp_sums)
-        if IS_CUDA_SYSTEM:
-            from vllm._C import ops
-            ops.paged_attention_v2(
-                out,
-                exp_sums,
-                max_logits,
-                tmp_output,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-                "auto",
-                1.0,
-            )
-        elif IS_ROCM_SYSTEM:
-            from vllm import attention_ops
-            attention_ops.paged_attention_v2(
-                out,
-                exp_sums,
-                max_logits,
-                tmp_output,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-            )
-        else:
-            raise ValueError("vllm is not supported on your system")
+        ops.paged_attention_v2(
+            out,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+            "auto",
+            1.0,
+        )