at last working!

Author: fxmarty
Date: 2024-04-18 23:31:28 +00:00
Parent: 06c3d4b1ec
Commit: 3016e1595f
8 changed files with 55 additions and 35 deletions


@@ -36,7 +36,7 @@ COPY launcher launcher
RUN cargo build --release
# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:5.7 as base
FROM rocm/dev-ubuntu-22.04:6.0.2 as base
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
@@ -50,13 +50,24 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
# Needed to build VLLM & flash.
rocthrust-dev \
hipsparse-dev \
hipblas-dev && \
hipblas-dev \
hipblaslt-dev \
rocblas-dev \
hiprand-dev \
rocrand-dev \
miopen-hip-dev \
hipfft-dev \
hipcub-dev \
hipsolver-dev \
rccl-dev \
cmake \
python3-dev && \
rm -rf /var/lib/apt/lists/*
# Keep in sync with `server/pyproject.toml
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.2.0.dev0'
ARG ROCM_VERSION='5.7'
ARG PYTORCH_VERSION='2.3.0'
ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.10.10'
# Automatically set by buildx
ARG TARGETPLATFORM
@@ -75,8 +86,8 @@ RUN chmod +x ~/mambaforge.sh && \
mamba init && \
rm ~/mambaforge.sh
# Install PyTorch 2.2 RC compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/
# Install PyTorch 2.3 RC compiled against RoCm 6.0
RUN pip install torch numpy --index-url https://download.pytorch.org/whl/test/rocm6.0
FROM base AS kernel-builder
@@ -102,21 +113,21 @@ RUN make build-flash-attention-v2-rocm
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
# Build exllama kernels
FROM kernel-builder as exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
# Build exllama v2 kernels
FROM kernel-builder as exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
FROM base as base-copy
@@ -147,10 +158,8 @@ RUN pip install einops --no-cache-dir
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
make gen-server && \
pip install -r requirements_rocm.txt && \
pip install ".[accelerate, peft, outlines]" --no-cache-dir
# pip install -r requirements_rocm.txt && \
#pip install ".[accelerate, peft, outlines]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -159,6 +168,10 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
RUN cd server && \
make gen-server && \
pip install -r requirements_rocm.txt
# AWS Sagemaker compatible image
FROM base-copy as sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh
@@ -169,5 +182,5 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
# ENTRYPOINT ["text-generation-launcher"]
# CMD ["--json-output"]
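
Below, a minimal sanity check (not part of the diff) for the resulting ROCm image: it confirms torch is a ROCm build, as pulled from the rocm6.0 test index above, and reports the GPU architecture the kernels are compiled for (gfx90a, gfx942). The gcnArchName attribute is assumed to be available, as on recent ROCm builds of PyTorch.

import torch

# Expect a ROCm (HIP) build of torch, e.g. the 2.3.0 wheel from the rocm6.0 index.
assert torch.version.hip is not None, "expected a ROCm build of torch"
print("HIP runtime:", torch.version.hip)
print("GPU visible:", torch.cuda.is_available())

props = torch.cuda.get_device_properties(0)
# gcnArchName is exposed by ROCm builds of recent PyTorch (assumption for this sketch).
print("Device arch:", getattr(props, "gcnArchName", "unknown"))   # expect gfx90a or gfx942
print("Capability:", torch.cuda.get_device_capability(0))         # gfx942 reports (9, 4)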


@@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
flash-attention-v2-cuda:
@@ -18,12 +18,12 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
flash-attention-v2-rocm:
# Clone flash attention
pip install -U packaging ninja --no-cache-dir
git clone https://github.com/fxmarty/flash-attention-rocm flash-attention-v2
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2
build-flash-attention-v2-rocm: flash-attention-v2-rocm
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm)
cd flash-attention-v2 && git submodule update --init --recursive
cd flash-attention-v2 && PYTORCH_ROCM_ARCH=gfx90a python setup.py build
cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
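
A short post-install smoke test (illustrative, not part of the Makefile) for the flash-attention v2 ROCm build above; it assumes an fp16-capable AMD GPU is visible and uses arbitrary shapes.

import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim) in fp16, as flash-attention expects.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal self-attention; the output has the same shape as q.
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)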


@@ -14,11 +14,12 @@ install-vllm-cuda: build-vllm-cuda
vllm-rocm:
# Clone vllm
pip install -U ninja packaging --no-cache-dir
git clone https://github.com/fxmarty/vllm-public.git vllm
git clone https://github.com/fxmarty/rocm-vllm.git vllm
build-vllm-rocm: vllm-rocm
cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
cd vllm && python setup.py build
cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
cd vllm && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch
cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
install-vllm-rocm: build-vllm-rocm
pip uninstall vllm -y || true
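
The build now applies rocm_patch/rocm_bf16.patch to amd_hip_bf16.h before compiling vLLM for gfx90a/gfx942. A small illustrative check (plain PyTorch, unrelated to vLLM's own kernels) that bfloat16 is usable on the ROCm device:

import torch

assert torch.cuda.is_available()
print("bf16 supported:", torch.cuda.is_bf16_supported())

a = torch.randn(64, 64, device="cuda", dtype=torch.bfloat16)
b = torch.randn(64, 64, device="cuda", dtype=torch.bfloat16)
print((a @ b).dtype)  # torch.bfloat16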


@@ -10,8 +10,9 @@ __device__ __forceinline__ __half __compat_hrcp(__half x) {
}
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
return _Float16_2{
_Float16_2{static_cast<_Float16>(1.0f),
static_cast<_Float16>(1.0f)} / x.data};
}
#define hrcp __compat_hrcp


@@ -65,7 +65,7 @@ class CohereRotary(PositionRotaryEmbedding):
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM:
from vllm import pos_encoding_ops
from vllm._C import ops
# NOTE: On RoCm systems, we use a ROPE implementation adapted from VLLM which launches a single kernel for both query/key, contrary to the flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
@@ -73,7 +73,7 @@ class CohereRotary(PositionRotaryEmbedding):
head_size = query.shape[-1]
# Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
ops.rotary_embedding(query, key, head_size, cos, sin, False)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."


@@ -60,7 +60,7 @@ from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SY
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
from vllm._C import ops
@dataclass
@@ -418,7 +418,7 @@ class IdeficsRMSNorm(nn.Module):
hidden_states = hidden_states.reshape(-1, shape[-1])
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
ops.rms_norm(
out,
hidden_states,
self.weight.data,
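
IdeficsRMSNorm now reaches the fused kernel as ops.rms_norm(out, hidden_states, self.weight.data, ...) from vllm._C. As a reference for what that fused call computes, a pure-PyTorch sketch (illustrative, not the repository's code):

import torch

def rms_norm_ref(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm: scale by the inverse root-mean-square of the last dimension,
    # then apply the learned per-channel weight (no mean subtraction, no bias).
    dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(dtype)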


@@ -15,6 +15,7 @@ major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
is_sm94 = major == 9 and minor == 4
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
@@ -33,11 +34,16 @@ try:
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
if not (is_sm8x or is_sm90):
if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
raise ImportError(
f"AMD GPU with compute capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e:
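
The capability gate now distinguishes CUDA from ROCm: gfx90a already passes the (9, 0) check, while the new is_sm94 admits gfx942 (MI300), which reports capability (9, 4). A condensed sketch of the same gating; IS_CUDA_SYSTEM / IS_ROCM_SYSTEM stand in for the repository's import_utils flags (assumption for this sketch).

import torch

major, minor = torch.cuda.get_device_capability()
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0   # e.g. H100 on CUDA, gfx90a on ROCm
is_sm94 = major == 9 and minor == 4   # gfx942 (MI300) on ROCm

def flash_attention_v2_supported(is_cuda_system: bool, is_rocm_system: bool) -> bool:
    # Mirrors the ImportError conditions above: V2 is refused on unsupported GPUs.
    if is_cuda_system:
        return is_sm8x or is_sm90
    if is_rocm_system:
        return is_sm8x or is_sm90 or is_sm94
    return False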


@@ -793,7 +793,7 @@ try:
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
from vllm._C import ops
else:
dropout_layer_norm = None
@@ -895,7 +895,7 @@ try:
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
ops.rms_norm(
out,
hidden_states,
self.weight.data,
@@ -915,7 +915,7 @@ try:
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
elif IS_ROCM_SYSTEM:
from vllm import pos_encoding_ops
from vllm._C import ops
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
@@ -970,7 +970,7 @@ try:
head_size = query.shape[-1]
# Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
ops.rotary_embedding(query, key, head_size, cos, sin, True)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
@@ -1231,6 +1231,5 @@ try:
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
except ImportError:
pass
except ImportError as e:
raise e
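
The same migration runs through layers.py: layernorm_ops and pos_encoding_ops are replaced by vllm._C's ops module, and import failures are now re-raised instead of silently swallowed. A small check (illustrative, assuming the ROCm vLLM build from this commit is installed) that the renamed entry points exist:

# Surface a missing or broken vLLM build instead of silently continuing,
# matching the new `except ImportError as e: raise e` behaviour.
from vllm._C import ops

for name in ("rms_norm", "rotary_embedding"):
    assert hasattr(ops, name), f"vllm._C.ops is missing {name}"
print("vllm._C.ops exposes rms_norm and rotary_embedding")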