diff --git a/Dockerfile_amd b/Dockerfile_amd
index a261ae2a..8bbcbbe8 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -95,8 +95,8 @@ RUN pip uninstall -y triton && \
     cd triton/python && \
     pip install .
 
-# RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && git checkout d05863883b7b61eb5875abcb6cb6b32fa678beeb && pip install -r requirements.txt --no-cache-dir
-RUN git clone --depth 1 --recursive --single-branch --branch release/2.3 https://github.com/pytorch/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
+RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
+# RUN git clone --depth 1 --recursive --single-branch --branch release/2.3 https://github.com/pytorch/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
 
 ARG _GLIBCXX_USE_CXX11_ABI="1"
 ARG CMAKE_PREFIX_PATH="/opt/conda"
@@ -113,98 +113,102 @@ ARG BUILD_CAFFE2="0" \
     USE_FLASH_ATTENTION="0" \
     USE_MEM_EFF_ATTENTION="0"
 
-# RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
+RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
 
-# FROM base AS kernel-builder
+# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
+# ENV HIP_FORCE_DEV_KERNARG=1
+
+FROM base AS kernel-builder
 
 #
 # Build vllm kernels
-# FROM kernel-builder AS vllm-builder
-# WORKDIR /usr/src
+FROM kernel-builder AS vllm-builder
+WORKDIR /usr/src
 
-# COPY server/Makefile-vllm Makefile
+COPY server/Makefile-vllm Makefile
 
-# # Build specific version of vllm
-# RUN make build-vllm-rocm
+# Build specific version of vllm
+RUN make build-vllm-rocm
 
-# # Build Flash Attention v2 kernels
-# FROM kernel-builder AS flash-att-v2-builder
-# WORKDIR /usr/src
+# Build Flash Attention v2 kernels
+FROM kernel-builder AS flash-att-v2-builder
+WORKDIR /usr/src
 
-# COPY server/Makefile-flash-att-v2 Makefile
+COPY server/Makefile-flash-att-v2 Makefile
 
-# # Build specific version of flash attention v2
-# RUN make build-flash-attention-v2-rocm
+# Build specific version of flash attention v2
+RUN make build-flash-attention-v2-rocm
 
-# # Build Transformers CUDA kernels (gpt-neox and bloom)
-# FROM kernel-builder as custom-kernels-builder
-# WORKDIR /usr/src
-# COPY server/custom_kernels/ .
-# RUN python setup.py build
+# Build Transformers CUDA kernels (gpt-neox and bloom)
+FROM kernel-builder as custom-kernels-builder
+WORKDIR /usr/src
+COPY server/custom_kernels/ .
+RUN python setup.py build
 
-# # Build exllama kernels
-# FROM kernel-builder as exllama-kernels-builder
-# WORKDIR /usr/src
-# COPY server/exllama_kernels/ .
+# Build exllama kernels
+FROM kernel-builder as exllama-kernels-builder
+WORKDIR /usr/src
+COPY server/exllama_kernels/ .
 
-# RUN python setup.py build
+RUN python setup.py build
 
-# # Build exllama v2 kernels
-# FROM kernel-builder as exllamav2-kernels-builder
-# WORKDIR /usr/src
-# COPY server/exllamav2_kernels/ .
+# Build exllama v2 kernels
+FROM kernel-builder as exllamav2-kernels-builder
+WORKDIR /usr/src
+COPY server/exllamav2_kernels/ .
 
-# RUN python setup.py build
+RUN python setup.py build
 
-# FROM base as base-copy
+FROM base as base-copy
 
-# # Text Generation Inference base env
-# ENV HUGGINGFACE_HUB_CACHE=/data \
-#     HF_HUB_ENABLE_HF_TRANSFER=1 \
-#     PORT=80 \
-#     HIP_FORCE_DEV_KERNARG=1
+# Text Generation Inference base env
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    HF_HUB_ENABLE_HF_TRANSFER=1 \
+    PORT=80
 
-# # Copy builds artifacts from vllm builder
-# COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy builds artifacts from vllm builder
+COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
-# # Copy build artifacts from flash attention v2 builder
-# COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from flash attention v2 builder
+COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
-# # Copy build artifacts from custom kernels builder
-# COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from custom kernels builder
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
-# # Copy build artifacts from exllama kernels builder
-# COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from exllama kernels builder
+COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
-# # Copy build artifacts from exllamav2 kernels builder
-# COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from exllamav2 kernels builder
+COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
-# # Install server
-# COPY proto proto
-# COPY server server
-# COPY server/Makefile server/Makefile
-# # pip install -r requirements_rocm.txt && \
-# #pip install ".[accelerate, peft, outlines]" --no-cache-dir
+# Install server
+COPY proto proto
+COPY server server
+COPY server/Makefile server/Makefile
+    # pip install -r requirements_rocm.txt && \
+    #pip install ".[accelerate, peft, outlines]" --no-cache-dir
 
-# # Install benchmarker
-# COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
-# # Install router
-# COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
-# # Install launcher
-# COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+# Install benchmarker
+COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# Install router
+COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
 
-# RUN cd server && \
-#     make gen-server && \
-#     pip install -r requirements_rocm.txt
+RUN cd server && \
+    make gen-server && \
+    pip install -r requirements_rocm.txt
 
-# # AWS Sagemaker compatible image
-# FROM base-copy as sagemaker
-# COPY sagemaker-entrypoint.sh entrypoint.sh
-# RUN chmod +x entrypoint.sh
+# AWS Sagemaker compatible image
+FROM base-copy as sagemaker
+COPY sagemaker-entrypoint.sh entrypoint.sh
+RUN chmod +x entrypoint.sh
 
-# ENTRYPOINT ["./entrypoint.sh"]
+ENTRYPOINT ["./entrypoint.sh"]
 
-# # Final image
-# FROM base-copy
+# Final image
+FROM base-copy
 
-# # ENTRYPOINT ["text-generation-launcher"]
-# # CMD ["--json-output"]
+# ENTRYPOINT ["text-generation-launcher"]
+# CMD ["--json-output"]
+
+# NOTE: Temporarily, for TGI, please mount a volume and install locally the server with `cd /tgi/server && pip install ".[accelerate, peft, outlines]" --no-cache-dir`
\ No newline at end of file
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 443feb37..3aca1042 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -770,7 +770,9 @@ class FlashCausalLM(Model):
             if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
 
+            logger.info("calling self.generate_token(batch)")
             _, batch, _ = self.generate_token(batch)
+            logger.info("end it")
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
@@ -824,18 +826,20 @@ class FlashCausalLM(Model):
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
-        if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
-            if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
-                torch.cuda.tunable.tuning_enable(True)
+        # if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+        #     if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
+        #         torch.cuda.tunable.tuning_enable(True)
+        #         logger.info("enable tuning here")
 
-            logger.info("PyTorch TunableOp (https://github.com/pytorch/pytorch/tree/v2.3.0/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
-            total_seqlens = list(range(2))
-            for seqlen in total_seqlens:
-                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
-                self.tunableop_warmup(seqlen, max_s, max_bt)
-            torch.cuda.tunable.write_file()
-            torch.cuda.tunable.tuning_enable(False)
+        logger.info("PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
+        for seqlen in range(1, 3):
+            logger.info(f"Warming up TunableOp for seqlen={seqlen}")
+            self.tunableop_warmup(seqlen, max_s, max_bt)
+        logger.info("call write file")
+        torch.cuda.tunable.write_file()
+        torch.cuda.tunable.tuning_enable(False)
+        logger.info("finished tunable op")
 
         return int(num_blocks * BLOCK_SIZE)
 
     def tunableop_warmup(self, seqlen: int, max_s: int, max_bt: int):
@@ -843,10 +847,10 @@ class FlashCausalLM(Model):
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
 
-        # TODO: is this correct?
         input_lengths = (
             torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
         )
+        bs = 1
         block_tables = (
             torch.arange(max_bt, dtype=torch.int32, device=self.device)
             .repeat(bs)
@@ -854,6 +858,7 @@
         )
         kv_cache = get_cache_manager().kv_cache
 
+        logger.info("call self.model.forward")
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index d47e0821..dda15fa6 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -6,6 +6,7 @@ _PARTITION_SIZE = 512
 
 try:
     from vllm._C import cache_ops
+    from vllm._C import ops
 except Exception as e:
     raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
 
@@ -61,43 +62,21 @@ def attention(
     # to parallelize.
     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
-        if IS_CUDA_SYSTEM:
-            from vllm._C import ops
-
-            ops.paged_attention_v1(
-                out,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-                "auto",
-                1.0,
-            )
-        elif IS_ROCM_SYSTEM:
-            from vllm import attention_ops
-
-            attention_ops.paged_attention_v1(
-                out,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-            )
-        else:
-            raise ValueError("vllm is not supported on your system")
-
+        ops.paged_attention_v1(
+            out,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+            "auto",
+            1.0,
+        )
     else:
         # Run PagedAttention V2.
         assert _PARTITION_SIZE % block_size == 0
@@ -113,45 +92,21 @@ def attention(
         )
         max_logits = torch.empty_like(exp_sums)
 
-        if IS_CUDA_SYSTEM:
-            from vllm._C import ops
-
-            ops.paged_attention_v2(
-                out,
-                exp_sums,
-                max_logits,
-                tmp_output,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-                "auto",
-                1.0,
-            )
-        elif IS_ROCM_SYSTEM:
-            from vllm import attention_ops
-
-            attention_ops.paged_attention_v2(
-                out,
-                exp_sums,
-                max_logits,
-                tmp_output,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-            )
-        else:
-            raise ValueError("vllm is not supported on your system")
+        ops.paged_attention_v2(
+            out,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+            "auto",
+            1.0,
+        )
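For reference, the warmup path touched in flash_causal_lm.py drives PyTorch TunableOp through the torch.cuda.tunable module (tuning_enable, write_file). Below is a minimal standalone sketch of the same enable -> tune -> write_file -> disable sequence, assuming a PyTorch build that ships TunableOp (for example the patched 2.3 branch pinned in Dockerfile_amd); the tensor shapes are placeholders, not taken from the patch:

import torch

torch.cuda.tunable.enable(True)          # switch TunableOp on for this process
torch.cuda.tunable.tuning_enable(True)   # allow new GEMM solutions to be tuned

# Run the GEMM shapes that should be tuned (placeholder sizes, fp16 on the GPU).
a = torch.randn(1, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
for _ in range(3):
    torch.matmul(a, b)

torch.cuda.tunable.write_file()          # persist the tuned results to a csv file
torch.cuda.tunable.tuning_enable(False)  # subsequent runs only reuse the stored solutions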