mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-27 21:12:07 +00:00
wip fix tunableop
This commit is contained in:
parent
a509360619
commit
2677bf856a
144
Dockerfile_amd
144
Dockerfile_amd
@ -95,8 +95,8 @@ RUN pip uninstall -y triton && \
|
||||
cd triton/python && \
|
||||
pip install .
|
||||
|
||||
# RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && git checkout d05863883b7b61eb5875abcb6cb6b32fa678beeb && pip install -r requirements.txt --no-cache-dir
|
||||
RUN git clone --depth 1 --recursive --single-branch --branch release/2.3 https://github.com/pytorch/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
|
||||
RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
|
||||
# RUN git clone --depth 1 --recursive --single-branch --branch release/2.3 https://github.com/pytorch/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
|
||||
|
||||
ARG _GLIBCXX_USE_CXX11_ABI="1"
|
||||
ARG CMAKE_PREFIX_PATH="/opt/conda"
|
||||
@ -113,98 +113,102 @@ ARG BUILD_CAFFE2="0" \
|
||||
USE_FLASH_ATTENTION="0" \
|
||||
USE_MEM_EFF_ATTENTION="0"
|
||||
|
||||
# RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
|
||||
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
|
||||
|
||||
# FROM base AS kernel-builder
|
||||
# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
|
||||
# ENV HIP_FORCE_DEV_KERNARG=1
|
||||
|
||||
FROM base AS kernel-builder
|
||||
|
||||
# # Build vllm kernels
|
||||
# FROM kernel-builder AS vllm-builder
|
||||
# WORKDIR /usr/src
|
||||
FROM kernel-builder AS vllm-builder
|
||||
WORKDIR /usr/src
|
||||
|
||||
# COPY server/Makefile-vllm Makefile
|
||||
COPY server/Makefile-vllm Makefile
|
||||
|
||||
# # Build specific version of vllm
|
||||
# RUN make build-vllm-rocm
|
||||
# Build specific version of vllm
|
||||
RUN make build-vllm-rocm
|
||||
|
||||
# # Build Flash Attention v2 kernels
|
||||
# FROM kernel-builder AS flash-att-v2-builder
|
||||
# WORKDIR /usr/src
|
||||
# Build Flash Attention v2 kernels
|
||||
FROM kernel-builder AS flash-att-v2-builder
|
||||
WORKDIR /usr/src
|
||||
|
||||
# COPY server/Makefile-flash-att-v2 Makefile
|
||||
COPY server/Makefile-flash-att-v2 Makefile
|
||||
|
||||
# # Build specific version of flash attention v2
|
||||
# RUN make build-flash-attention-v2-rocm
|
||||
# Build specific version of flash attention v2
|
||||
RUN make build-flash-attention-v2-rocm
|
||||
|
||||
# # Build Transformers CUDA kernels (gpt-neox and bloom)
|
||||
# FROM kernel-builder as custom-kernels-builder
|
||||
# WORKDIR /usr/src
|
||||
# COPY server/custom_kernels/ .
|
||||
# RUN python setup.py build
|
||||
# Build Transformers CUDA kernels (gpt-neox and bloom)
|
||||
FROM kernel-builder as custom-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/custom_kernels/ .
|
||||
RUN python setup.py build
|
||||
|
||||
# # Build exllama kernels
|
||||
# FROM kernel-builder as exllama-kernels-builder
|
||||
# WORKDIR /usr/src
|
||||
# COPY server/exllama_kernels/ .
|
||||
# Build exllama kernels
|
||||
FROM kernel-builder as exllama-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/exllama_kernels/ .
|
||||
|
||||
# RUN python setup.py build
|
||||
RUN python setup.py build
|
||||
|
||||
# # Build exllama v2 kernels
|
||||
# FROM kernel-builder as exllamav2-kernels-builder
|
||||
# WORKDIR /usr/src
|
||||
# COPY server/exllamav2_kernels/ .
|
||||
# Build exllama v2 kernels
|
||||
FROM kernel-builder as exllamav2-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/exllamav2_kernels/ .
|
||||
|
||||
# RUN python setup.py build
|
||||
RUN python setup.py build
|
||||
|
||||
# FROM base as base-copy
|
||||
FROM base as base-copy
|
||||
|
||||
# # Text Generation Inference base env
|
||||
# ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
# HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
# PORT=80 \
|
||||
# HIP_FORCE_DEV_KERNARG=1
|
||||
# Text Generation Inference base env
|
||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
PORT=80
|
||||
|
||||
# # Copy builds artifacts from vllm builder
|
||||
# COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy builds artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# # Copy build artifacts from flash attention v2 builder
|
||||
# COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from flash attention v2 builder
|
||||
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# # Copy build artifacts from custom kernels builder
|
||||
# COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from custom kernels builder
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# # Copy build artifacts from exllama kernels builder
|
||||
# COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from exllama kernels builder
|
||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# # Copy build artifacts from exllamav2 kernels builder
|
||||
# COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from exllamav2 kernels builder
|
||||
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# # Install server
|
||||
# COPY proto proto
|
||||
# COPY server server
|
||||
# COPY server/Makefile server/Makefile
|
||||
# # pip install -r requirements_rocm.txt && \
|
||||
# #pip install ".[accelerate, peft, outlines]" --no-cache-dir
|
||||
# Install server
|
||||
COPY proto proto
|
||||
COPY server server
|
||||
COPY server/Makefile server/Makefile
|
||||
# pip install -r requirements_rocm.txt && \
|
||||
#pip install ".[accelerate, peft, outlines]" --no-cache-dir
|
||||
|
||||
# # Install benchmarker
|
||||
# COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# # Install router
|
||||
# COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||
# # Install launcher
|
||||
# COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
# RUN cd server && \
|
||||
# make gen-server && \
|
||||
# pip install -r requirements_rocm.txt
|
||||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install -r requirements_rocm.txt
|
||||
|
||||
# # AWS Sagemaker compatible image
|
||||
# FROM base-copy as sagemaker
|
||||
# COPY sagemaker-entrypoint.sh entrypoint.sh
|
||||
# RUN chmod +x entrypoint.sh
|
||||
# AWS Sagemaker compatible image
|
||||
FROM base-copy as sagemaker
|
||||
COPY sagemaker-entrypoint.sh entrypoint.sh
|
||||
RUN chmod +x entrypoint.sh
|
||||
|
||||
# ENTRYPOINT ["./entrypoint.sh"]
|
||||
ENTRYPOINT ["./entrypoint.sh"]
|
||||
|
||||
# # Final image
|
||||
# FROM base-copy
|
||||
# Final image
|
||||
FROM base-copy
|
||||
|
||||
# # ENTRYPOINT ["text-generation-launcher"]
|
||||
# # CMD ["--json-output"]
|
||||
# ENTRYPOINT ["text-generation-launcher"]
|
||||
# CMD ["--json-output"]
|
||||
|
||||
# NOTE: Temporarily, for TGI, please mount a volume and install locally the server with `cd /tgi/server && pip install ".[accelerate, peft, outlines]" --no-cache-dir`
|
@ -770,7 +770,9 @@ class FlashCausalLM(Model):
|
||||
if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
|
||||
torch.cuda.tunable.tuning_enable(False)
|
||||
|
||||
logger.info("calling self.generate_token(batch)")
|
||||
_, batch, _ = self.generate_token(batch)
|
||||
logger.info("end it")
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
raise RuntimeError(
|
||||
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
|
||||
@ -824,18 +826,20 @@ class FlashCausalLM(Model):
|
||||
else:
|
||||
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
|
||||
|
||||
if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
|
||||
if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
|
||||
torch.cuda.tunable.tuning_enable(True)
|
||||
# if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
|
||||
# if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
|
||||
# torch.cuda.tunable.tuning_enable(True)
|
||||
# logger.info("enable tuning here")
|
||||
|
||||
logger.info("PyTorch TunableOp (https://github.com/pytorch/pytorch/tree/v2.3.0/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
|
||||
total_seqlens = list(range(2))
|
||||
for seqlen in total_seqlens:
|
||||
logger.info("PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
|
||||
for seqlen in range(1, 3):
|
||||
logger.info(f"Warming up TunableOp for seqlen={seqlen}")
|
||||
self.tunableop_warmup(seqlen, max_s, max_bt)
|
||||
logger.info("call write file")
|
||||
torch.cuda.tunable.write_file()
|
||||
torch.cuda.tunable.tuning_enable(False)
|
||||
|
||||
logger.info("finished tunable op")
|
||||
return int(num_blocks * BLOCK_SIZE)
|
||||
|
||||
def tunableop_warmup(self, seqlen: int, max_s: int, max_bt: int):
|
||||
@ -843,10 +847,10 @@ class FlashCausalLM(Model):
|
||||
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
|
||||
|
||||
# TODO: is this correct?
|
||||
input_lengths = (
|
||||
torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
|
||||
)
|
||||
bs = 1
|
||||
block_tables = (
|
||||
torch.arange(max_bt, dtype=torch.int32, device=self.device)
|
||||
.repeat(bs)
|
||||
@ -854,6 +858,7 @@ class FlashCausalLM(Model):
|
||||
)
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
|
||||
logger.info("call self.model.forward")
|
||||
self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
|
@ -6,6 +6,7 @@ _PARTITION_SIZE = 512
|
||||
|
||||
try:
|
||||
from vllm._C import cache_ops
|
||||
from vllm._C import ops
|
||||
except Exception as e:
|
||||
raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
|
||||
|
||||
@ -61,9 +62,6 @@ def attention(
|
||||
# to parallelize.
|
||||
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||
if use_v1:
|
||||
if IS_CUDA_SYSTEM:
|
||||
from vllm._C import ops
|
||||
|
||||
ops.paged_attention_v1(
|
||||
out,
|
||||
query,
|
||||
@ -79,25 +77,6 @@ def attention(
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import attention_ops
|
||||
|
||||
attention_ops.paged_attention_v1(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
raise ValueError("vllm is not supported on your system")
|
||||
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
@ -113,9 +92,6 @@ def attention(
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
if IS_CUDA_SYSTEM:
|
||||
from vllm._C import ops
|
||||
|
||||
ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
@ -134,24 +110,3 @@ def attention(
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import attention_ops
|
||||
|
||||
attention_ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
raise ValueError("vllm is not supported on your system")
|
||||
|
Loading…
Reference in New Issue
Block a user