mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-06 09:42:09 +00:00
at last working!
This commit is contained in:
parent
06c3d4b1ec
commit
3016e1595f
@ -36,7 +36,7 @@ COPY launcher launcher
|
||||
RUN cargo build --release
|
||||
|
||||
# Text Generation Inference base image for RoCm
|
||||
FROM rocm/dev-ubuntu-22.04:5.7 as base
|
||||
FROM rocm/dev-ubuntu-22.04:6.0.2 as base
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
@ -50,13 +50,24 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
||||
# Needed to build VLLM & flash.
|
||||
rocthrust-dev \
|
||||
hipsparse-dev \
|
||||
hipblas-dev && \
|
||||
hipblas-dev \
|
||||
hipblaslt-dev \
|
||||
rocblas-dev \
|
||||
hiprand-dev \
|
||||
rocrand-dev \
|
||||
miopen-hip-dev \
|
||||
hipfft-dev \
|
||||
hipcub-dev \
|
||||
hipsolver-dev \
|
||||
rccl-dev \
|
||||
cmake \
|
||||
python3-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Keep in sync with `server/pyproject.toml
|
||||
ARG MAMBA_VERSION=23.1.0-1
|
||||
ARG PYTORCH_VERSION='2.2.0.dev0'
|
||||
ARG ROCM_VERSION='5.7'
|
||||
ARG PYTORCH_VERSION='2.3.0'
|
||||
ARG ROCM_VERSION='6.0.2'
|
||||
ARG PYTHON_VERSION='3.10.10'
|
||||
# Automatically set by buildx
|
||||
ARG TARGETPLATFORM
|
||||
@ -75,8 +86,8 @@ RUN chmod +x ~/mambaforge.sh && \
|
||||
mamba init && \
|
||||
rm ~/mambaforge.sh
|
||||
|
||||
# Install PyTorch 2.2 RC compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
|
||||
RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/
|
||||
# Install PyTorch 2.3 RC compiled against RoCm 6.0
|
||||
RUN pip install torch numpy --index-url https://download.pytorch.org/whl/test/rocm6.0
|
||||
|
||||
FROM base AS kernel-builder
|
||||
|
||||
@ -102,21 +113,21 @@ RUN make build-flash-attention-v2-rocm
|
||||
FROM kernel-builder as custom-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/custom_kernels/ .
|
||||
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
|
||||
RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
|
||||
|
||||
# Build exllama kernels
|
||||
FROM kernel-builder as exllama-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/exllama_kernels/ .
|
||||
|
||||
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
|
||||
RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
|
||||
|
||||
# Build exllama v2 kernels
|
||||
FROM kernel-builder as exllamav2-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/exllamav2_kernels/ .
|
||||
|
||||
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
|
||||
RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
|
||||
|
||||
FROM base as base-copy
|
||||
|
||||
@ -147,10 +158,8 @@ RUN pip install einops --no-cache-dir
|
||||
COPY proto proto
|
||||
COPY server server
|
||||
COPY server/Makefile server/Makefile
|
||||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install -r requirements_rocm.txt && \
|
||||
pip install ".[accelerate, peft, outlines]" --no-cache-dir
|
||||
# pip install -r requirements_rocm.txt && \
|
||||
#pip install ".[accelerate, peft, outlines]" --no-cache-dir
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
@ -159,6 +168,10 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi
|
||||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
RUN cd server && \
|
||||
make gen-server && \
|
||||
pip install -r requirements_rocm.txt
|
||||
|
||||
# AWS Sagemaker compatible image
|
||||
FROM base-copy as sagemaker
|
||||
COPY sagemaker-entrypoint.sh entrypoint.sh
|
||||
@ -169,5 +182,5 @@ ENTRYPOINT ["./entrypoint.sh"]
|
||||
# Final image
|
||||
FROM base-copy
|
||||
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
||||
CMD ["--json-output"]
|
||||
# ENTRYPOINT ["text-generation-launcher"]
|
||||
# CMD ["--json-output"]
|
||||
|
@ -1,5 +1,5 @@
|
||||
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
|
||||
flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69
|
||||
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
|
||||
|
||||
|
||||
flash-attention-v2-cuda:
|
||||
@ -18,12 +18,12 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
|
||||
flash-attention-v2-rocm:
|
||||
# Clone flash attention
|
||||
pip install -U packaging ninja --no-cache-dir
|
||||
git clone https://github.com/fxmarty/flash-attention-rocm flash-attention-v2
|
||||
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2
|
||||
|
||||
build-flash-attention-v2-rocm: flash-attention-v2-rocm
|
||||
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm)
|
||||
cd flash-attention-v2 && git submodule update --init --recursive
|
||||
cd flash-attention-v2 && PYTORCH_ROCM_ARCH=gfx90a python setup.py build
|
||||
cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
|
||||
|
||||
install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
|
||||
cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
|
||||
|
@ -14,11 +14,12 @@ install-vllm-cuda: build-vllm-cuda
|
||||
vllm-rocm:
|
||||
# Clone vllm
|
||||
pip install -U ninja packaging --no-cache-dir
|
||||
git clone https://github.com/fxmarty/vllm-public.git vllm
|
||||
git clone https://github.com/fxmarty/rocm-vllm.git vllm
|
||||
|
||||
build-vllm-rocm: vllm-rocm
|
||||
cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
|
||||
cd vllm && python setup.py build
|
||||
cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
|
||||
cd vllm && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch
|
||||
cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
|
||||
|
||||
install-vllm-rocm: build-vllm-rocm
|
||||
pip uninstall vllm -y || true
|
||||
|
@ -10,8 +10,9 @@ __device__ __forceinline__ __half __compat_hrcp(__half x) {
|
||||
}
|
||||
|
||||
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
|
||||
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
|
||||
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
|
||||
return _Float16_2{
|
||||
_Float16_2{static_cast<_Float16>(1.0f),
|
||||
static_cast<_Float16>(1.0f)} / x.data};
|
||||
}
|
||||
|
||||
#define hrcp __compat_hrcp
|
||||
|
@ -65,7 +65,7 @@ class CohereRotary(PositionRotaryEmbedding):
|
||||
|
||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import pos_encoding_ops
|
||||
from vllm._C import ops
|
||||
|
||||
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
|
||||
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
|
||||
@ -73,7 +73,7 @@ class CohereRotary(PositionRotaryEmbedding):
|
||||
head_size = query.shape[-1]
|
||||
|
||||
# Inplace operation, updating query and key.
|
||||
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
|
||||
ops.rotary_embedding(query, key, head_size, cos, sin, False)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
|
@ -60,7 +60,7 @@ from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SY
|
||||
if IS_CUDA_SYSTEM:
|
||||
import dropout_layer_norm
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import layernorm_ops
|
||||
from vllm._C import ops
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -418,7 +418,7 @@ class IdeficsRMSNorm(nn.Module):
|
||||
hidden_states = hidden_states.reshape(-1, shape[-1])
|
||||
|
||||
out = torch.empty_like(hidden_states)
|
||||
layernorm_ops.rms_norm(
|
||||
ops.rms_norm(
|
||||
out,
|
||||
hidden_states,
|
||||
self.weight.data,
|
||||
|
@ -15,6 +15,7 @@ major, minor = torch.cuda.get_device_capability()
|
||||
is_sm75 = major == 7 and minor == 5
|
||||
is_sm8x = major == 8 and minor >= 0
|
||||
is_sm90 = major == 9 and minor == 0
|
||||
is_sm94 = major == 9 and minor == 4
|
||||
|
||||
HAS_FLASH_ATTN = False
|
||||
HAS_FLASH_ATTN_V2_CUDA = False
|
||||
@ -33,11 +34,16 @@ try:
|
||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
||||
)
|
||||
if not (is_sm8x or is_sm90):
|
||||
if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
|
||||
raise ImportError(
|
||||
f"GPU with CUDA capability {major} {minor} is not supported for "
|
||||
"Flash Attention V2"
|
||||
)
|
||||
elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
|
||||
raise ImportError(
|
||||
f"AMD GPU with compute capability {major} {minor} is not supported for "
|
||||
"Flash Attention V2"
|
||||
)
|
||||
HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
|
||||
HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
|
||||
except ImportError as e:
|
||||
|
@ -793,7 +793,7 @@ try:
|
||||
if IS_CUDA_SYSTEM:
|
||||
import dropout_layer_norm
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import layernorm_ops
|
||||
from vllm._C import ops
|
||||
else:
|
||||
dropout_layer_norm = None
|
||||
|
||||
@ -895,7 +895,7 @@ try:
|
||||
residual = hidden_states
|
||||
|
||||
out = torch.empty_like(hidden_states)
|
||||
layernorm_ops.rms_norm(
|
||||
ops.rms_norm(
|
||||
out,
|
||||
hidden_states,
|
||||
self.weight.data,
|
||||
@ -915,7 +915,7 @@ try:
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
import rotary_emb
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import pos_encoding_ops
|
||||
from vllm._C import ops
|
||||
|
||||
def _create_inv_freq(dim, base, device):
|
||||
inv_freq = 1.0 / (
|
||||
@ -970,7 +970,7 @@ try:
|
||||
head_size = query.shape[-1]
|
||||
|
||||
# Inplace operation, updating query and key.
|
||||
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
||||
ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
@ -1231,6 +1231,5 @@ try:
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
except ImportError as e:
|
||||
raise e
|
||||
|
Loading…
Reference in New Issue
Block a user