Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-06 17:52:07 +00:00)

Commit 3016e1595f (parent 06c3d4b1ec): at last working!
@@ -36,7 +36,7 @@ COPY launcher launcher
 RUN cargo build --release

 # Text Generation Inference base image for RoCm
-FROM rocm/dev-ubuntu-22.04:5.7 as base
+FROM rocm/dev-ubuntu-22.04:6.0.2 as base

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
@@ -50,13 +50,24 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     # Needed to build VLLM & flash.
     rocthrust-dev \
     hipsparse-dev \
-    hipblas-dev && \
+    hipblas-dev \
+    hipblaslt-dev \
+    rocblas-dev \
+    hiprand-dev \
+    rocrand-dev \
+    miopen-hip-dev \
+    hipfft-dev \
+    hipcub-dev \
+    hipsolver-dev \
+    rccl-dev \
+    cmake \
+    python3-dev && \
     rm -rf /var/lib/apt/lists/*

 # Keep in sync with `server/pyproject.toml
 ARG MAMBA_VERSION=23.1.0-1
-ARG PYTORCH_VERSION='2.2.0.dev0'
-ARG ROCM_VERSION='5.7'
+ARG PYTORCH_VERSION='2.3.0'
+ARG ROCM_VERSION='6.0.2'
 ARG PYTHON_VERSION='3.10.10'
 # Automatically set by buildx
 ARG TARGETPLATFORM
@@ -75,8 +86,8 @@ RUN chmod +x ~/mambaforge.sh && \
     mamba init && \
     rm ~/mambaforge.sh

-# Install PyTorch 2.2 RC compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
-RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/
+# Install PyTorch 2.3 RC compiled against RoCm 6.0
+RUN pip install torch numpy --index-url https://download.pytorch.org/whl/test/rocm6.0

 FROM base AS kernel-builder

@@ -102,21 +113,21 @@ RUN make build-flash-attention-v2-rocm
 FROM kernel-builder as custom-kernels-builder
 WORKDIR /usr/src
 COPY server/custom_kernels/ .
-RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
+RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

 # Build exllama kernels
 FROM kernel-builder as exllama-kernels-builder
 WORKDIR /usr/src
 COPY server/exllama_kernels/ .

-RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

 # Build exllama v2 kernels
 FROM kernel-builder as exllamav2-kernels-builder
 WORKDIR /usr/src
 COPY server/exllamav2_kernels/ .

-RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
+RUN PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

 FROM base as base-copy

@@ -147,10 +158,8 @@ RUN pip install einops --no-cache-dir
 COPY proto proto
 COPY server server
 COPY server/Makefile server/Makefile
-RUN cd server && \
-    make gen-server && \
-    pip install -r requirements_rocm.txt && \
-    pip install ".[accelerate, peft, outlines]" --no-cache-dir
+# pip install -r requirements_rocm.txt && \
+#pip install ".[accelerate, peft, outlines]" --no-cache-dir

 # Install benchmarker
 COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -159,6 +168,10 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi
 # Install launcher
 COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

+RUN cd server && \
+    make gen-server && \
+    pip install -r requirements_rocm.txt
+
 # AWS Sagemaker compatible image
 FROM base-copy as sagemaker
 COPY sagemaker-entrypoint.sh entrypoint.sh
@@ -169,5 +182,5 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base-copy

-ENTRYPOINT ["text-generation-launcher"]
-CMD ["--json-output"]
+# ENTRYPOINT ["text-generation-launcher"]
+# CMD ["--json-output"]
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
-flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69
+flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6


 flash-attention-v2-cuda:
@@ -18,12 +18,12 @@ install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
 flash-attention-v2-rocm:
 	# Clone flash attention
 	pip install -U packaging ninja --no-cache-dir
-	git clone https://github.com/fxmarty/flash-attention-rocm flash-attention-v2
+	git clone https://github.com/ROCm/flash-attention.git flash-attention-v2

 build-flash-attention-v2-rocm: flash-attention-v2-rocm
 	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm)
 	cd flash-attention-v2 && git submodule update --init --recursive
-	cd flash-attention-v2 && PYTORCH_ROCM_ARCH=gfx90a python setup.py build
+	cd flash-attention-v2 && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

 install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
 	cd flash-attention-v2 && git submodule update --init --recursive && python setup.py install
@@ -14,11 +14,12 @@ install-vllm-cuda: build-vllm-cuda
 vllm-rocm:
 	# Clone vllm
 	pip install -U ninja packaging --no-cache-dir
-	git clone https://github.com/fxmarty/vllm-public.git vllm
+	git clone https://github.com/fxmarty/rocm-vllm.git vllm

 build-vllm-rocm: vllm-rocm
-	cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
-	cd vllm && python setup.py build
+	cd vllm && git fetch && git checkout ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
+	cd vllm && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch
+	cd vllm && PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install

 install-vllm-rocm: build-vllm-rocm
 	pip uninstall vllm -y || true
@@ -10,8 +10,9 @@ __device__ __forceinline__ __half __compat_hrcp(__half x) {
 }

 __device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
-    return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
-                      static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
+    return _Float16_2{
+        _Float16_2{static_cast<_Float16>(1.0f),
+                   static_cast<_Float16>(1.0f)} / x.data};
 }

 #define hrcp __compat_hrcp
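
The rewritten __compat_h2rcp above replaces the per-lane __builtin_amdgcn_rcph calls with a packed division of {1.0, 1.0} by x.data, i.e. a plain elementwise reciprocal of the two halves. A minimal PyTorch check of that intended semantics (the values below are illustrative only, not taken from this commit):

import torch

# Elementwise fp16 reciprocal -- the behaviour the __compat_h2rcp shim emulates
# for a packed pair of half-precision values.
x = torch.tensor([2.0, 0.5], dtype=torch.float16)
print(torch.ones_like(x) / x)  # tensor([0.5000, 2.0000], dtype=torch.float16)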
@@ -65,7 +65,7 @@ class CohereRotary(PositionRotaryEmbedding):

             rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
         elif IS_ROCM_SYSTEM:
-            from vllm import pos_encoding_ops
+            from vllm._C import ops

             # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
             # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
@@ -73,7 +73,7 @@ class CohereRotary(PositionRotaryEmbedding):
             head_size = query.shape[-1]

             # Inplace operation, updating query and key.
-            pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
+            ops.rotary_embedding(query, key, head_size, cos, sin, False)
         else:
             raise ValueError(
                 "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
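
The ROCm path above now imports ops from vllm._C and applies the rotation in place to query and key with a single fused kernel, as the NOTE in the hunk explains. For orientation, a rough PyTorch sketch of the rotation such a call performs, assuming the trailing False selects GPT-J-style interleaved pairs; the shapes, names and broadcasting below are illustrative assumptions, not part of this commit:

import torch

def rotary_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> None:
    # x: [num_tokens, num_heads, head_size]; cos/sin: [num_tokens, head_size // 2].
    # Rotate each interleaved (even, odd) pair of channels by the per-position angle,
    # updating x in place -- the fused kernel does this for query and key at once.
    x1 = x[..., 0::2].clone()
    x2 = x[..., 1::2].clone()
    c = cos.unsqueeze(1)  # broadcast over the head dimension
    s = sin.unsqueeze(1)
    x[..., 0::2] = x1 * c - x2 * s
    x[..., 1::2] = x1 * s + x2 * c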
@@ -60,7 +60,7 @@ from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SY
 if IS_CUDA_SYSTEM:
     import dropout_layer_norm
 elif IS_ROCM_SYSTEM:
-    from vllm import layernorm_ops
+    from vllm._C import ops


 @dataclass
@@ -418,7 +418,7 @@ class IdeficsRMSNorm(nn.Module):
             hidden_states = hidden_states.reshape(-1, shape[-1])

             out = torch.empty_like(hidden_states)
-            layernorm_ops.rms_norm(
+            ops.rms_norm(
                 out,
                 hidden_states,
                 self.weight.data,
@@ -15,6 +15,7 @@ major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 is_sm8x = major == 8 and minor >= 0
 is_sm90 = major == 9 and minor == 0
+is_sm94 = major == 9 and minor == 4

 HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2_CUDA = False
@@ -33,11 +34,16 @@ try:
             "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
             f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
         )
-    if not (is_sm8x or is_sm90):
+    if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported for "
             "Flash Attention V2"
         )
+    elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
+        raise ImportError(
+            f"AMD GPU with compute capability {major} {minor} is not supported for "
+            "Flash Attention V2"
+        )
     HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
     HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
 except ImportError as e:
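
The new gate reads the same (major, minor) capability pair on both backends but interprets it per vendor: on CUDA the existing sm8x/sm90 check is kept, while on ROCm the new is_sm94 case is also accepted. A minimal sketch of that logic, assuming (as this commit appears to) that ROCm builds of PyTorch report gfx90a as (9, 0) and gfx942/MI300 as (9, 4), and using torch.version as a stand-in for TGI's IS_CUDA_SYSTEM/IS_ROCM_SYSTEM flags:

import torch

major, minor = torch.cuda.get_device_capability()
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
is_sm94 = major == 9 and minor == 4  # assumed mapping for gfx942 on ROCm builds

IS_CUDA_SYSTEM = torch.version.cuda is not None  # stand-in for import_utils flags
IS_ROCM_SYSTEM = torch.version.hip is not None

if IS_CUDA_SYSTEM and not (is_sm8x or is_sm90):
    raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported for Flash Attention V2")
elif IS_ROCM_SYSTEM and not (is_sm8x or is_sm90 or is_sm94):
    raise ImportError(f"AMD GPU with compute capability {major} {minor} is not supported for Flash Attention V2")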
@@ -793,7 +793,7 @@ try:
     if IS_CUDA_SYSTEM:
         import dropout_layer_norm
     elif IS_ROCM_SYSTEM:
-        from vllm import layernorm_ops
+        from vllm._C import ops
     else:
         dropout_layer_norm = None

@@ -895,7 +895,7 @@ try:
                 residual = hidden_states

             out = torch.empty_like(hidden_states)
-            layernorm_ops.rms_norm(
+            ops.rms_norm(
                 out,
                 hidden_states,
                 self.weight.data,
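
Here rms_norm comes from the vllm._C extension imported above and fills the preallocated out tensor. As a reference for what that kernel computes, a short PyTorch equivalent of RMS normalization (the epsilon name and default below are illustrative assumptions, not read from this diff):

import torch

def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Scale by the reciprocal root-mean-square over the last dimension, then apply
    # the learned per-channel weight; unlike LayerNorm there is no mean subtraction.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * normed.to(hidden_states.dtype)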
@@ -915,7 +915,7 @@ try:
         from flash_attn.layers.rotary import RotaryEmbedding
         import rotary_emb
     elif IS_ROCM_SYSTEM:
-        from vllm import pos_encoding_ops
+        from vllm._C import ops

     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
@@ -970,7 +970,7 @@ try:
             head_size = query.shape[-1]

             # Inplace operation, updating query and key.
-            pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
+            ops.rotary_embedding(query, key, head_size, cos, sin, True)
         else:
             raise ValueError(
                 "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
@@ -1231,6 +1231,5 @@ try:
         freqs = torch.outer(t, self.inv_freq.to(device=t.device))
         self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
         self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
-
-except ImportError:
-    pass
+except ImportError as e:
+    raise e