From 52bdcf797d95f733a02273ee7b4738309f067b85 Mon Sep 17 00:00:00 2001 From: Felix Marty Date: Mon, 30 Oct 2023 10:42:34 +0000 Subject: [PATCH] working rocm build --- Dockerfile_amd | 135 ++++++++++++++++++ server/requirements.txt | 6 +- .../custom_modeling/flash_llama_modeling.py | 75 +++++++--- .../custom_modeling/flash_mistral_modeling.py | 72 +++++----- .../custom_modeling/idefics_modeling.py | 86 +++++------ .../utils/flash_attn.py | 18 ++- .../utils/import_utils.py | 15 ++ server/text_generation_server/utils/layers.py | 82 ++++++----- .../utils/paged_attention.py | 3 + 9 files changed, 349 insertions(+), 143 deletions(-) create mode 100644 Dockerfile_amd create mode 100644 server/text_generation_server/utils/import_utils.py diff --git a/Dockerfile_amd b/Dockerfile_amd new file mode 100644 index 00000000..5ed8ec8a --- /dev/null +++ b/Dockerfile_amd @@ -0,0 +1,135 @@ +# Rust builder +FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef +WORKDIR /usr/src + +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +FROM chef as planner +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY --from=planner /usr/src/recipe.json recipe.json +RUN cargo chef cook --release --recipe-path recipe.json + +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo build --release + +# Text Generation Inference base image +FROM rocm/dev-ubuntu-20.04:5.7 as base + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + build-essential \ + ca-certificates \ + ccache \ + curl \ + git \ + make \ + libssl-dev \ + g++ \ + wget \ + # Needed to build VLLM. + rocthrust-dev \ + hipsparse-dev \ + hipblas-dev && \ + rm -rf /var/lib/apt/lists/* + +# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. + +RUN wget \ + https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ + && mkdir .conda \ + && bash Miniconda3-latest-Linux-x86_64.sh -b \ + && rm -f Miniconda3-latest-Linux-x86_64.sh + +ENV PATH="/root/miniconda3/bin:${PATH}" +ARG PATH="/root/user/miniconda3/bin:${PATH}" +RUN conda init bash + +ARG PYTORCH_VERSION='2.2.0.dev0' +ARG ROCM_VERSION='5.7' +ARG PYTHON_VERSION='3.11.5' + +RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7 +RUN pip install -U ninja + +WORKDIR /usr/src + +# Install VLLM. +RUN git clone https://github.com/fxmarty/vllm-public.git && cd vllm-public && git checkout --track origin/port-to-rocm +WORKDIR /usr/src/vllm-public +RUN pip install -r requirements.txt +RUN python setup.py install + +# Install Flash Attention v1. 
+WORKDIR /usr/src +RUN git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git && cd flash-attention && git submodule init && git submodule update && python setup.py install + +# Not working for RoCm +# RUN cd flash-attention/csrc/rotary && python setup.py build && cd flash-attention/csrc/layer_norm && python setup.py build + +# COPY server/Makefile-flash-att Makefile + +# Build specific version of flash attention +# RUN make build-flash-attention + +# Build Transformers CUDA kernels +# NOTE: gpt-neox and bloom fused kernels + +# FROM kernel-builder as custom-kernels-builder +# WORKDIR /usr/src +# COPY server/custom_kernels/ . +# Build specific version of transformers +# RUN python setup.py build + +# Text Generation Inference base env +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + +# Copy build artifacts from flash attention builder +# COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages +# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages +# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages + +# Copy build artifacts from custom kernels builder +# COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && pip3 install -r requirements.txt +RUN cd server && \ + make gen-server && \ + pip3 install ".[accelerate]" --no-cache-dir + +# Install benchmarker +COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +# Install launcherg +COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher + +# ENTRYPOINT ["text-generation-launcher"] +# CMD ["--json-output"] diff --git a/server/requirements.txt b/server/requirements.txt index 7c81c5f9..34159f89 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -4,7 +4,8 @@ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13" async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13" attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13" backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" -bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13" +# bitsandbytes is broken on RoCm systems +# bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13" certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" @@ -62,7 +63,8 @@ six==1.16.0 ; python_version >= "3.9" and python_version < "3.13" sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" texttable==1.6.7 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13" -torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13" +# We use nightly +torch>2.1.0 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.1 ; python_version >= "3.9" and 
python_version < "3.13" transformers==4.33.2 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 69608e1c..3d21a4ba 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -26,8 +26,10 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple +from loguru import logger + # Flash attention imports -import dropout_layer_norm +# import dropout_layer_norm from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.layers import ( @@ -39,6 +41,9 @@ from text_generation_server.utils.layers import ( get_linear, ) +from vllm import layernorm_ops + +torch.set_printoptions(threshold=10000000, sci_mode=True) class LlamaConfig(PretrainedConfig): def __init__( @@ -121,28 +126,43 @@ class LlamaRMSNorm(nn.Module): return self.weight * hidden_states, residual else: - # faster post attention rms norm - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states + # We use VLLM kernels that are compiled for RoCm instead of Flash Attention ones that can't be used. + if residual is not None: + hidden_states += residual + residual = hidden_states - return normed_hidden_states, res + out = torch.empty_like(hidden_states) + layernorm_ops.rms_norm( + out, + hidden_states, + self.weight.data, + self.variance_epsilon, + ) + return out, residual + + # else: + # # faster post attention rms norm + # normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + # hidden_states, + # residual, + # self.weight, + # None, + # None, + # None, + # None, + # None, + # 0.0, + # self.variance_epsilon, + # 1.0, + # 0, + # None, + # False, + # True, # Activate RMSNorm + # ) + # if res is None: + # res = hidden_states + + # return normed_hidden_states, res def load_attention(config, prefix, weights): @@ -262,6 +282,11 @@ class FlashLlamaAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + # logger.info(f"query before rotary {query[:10, ..., :8]}") + # logger.info(f"cos before rotary {cos[:10]}") + # logger.info(f"sin before rotary {sin[:10]}") + # TODO: maybe we can use VLLM rotary here, which would require position_ids? Probably too big of a change... 
+ # Flash Attention kernel may be usable since it is Triton-based self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) @@ -272,6 +297,9 @@ class FlashLlamaAttention(torch.nn.Module): # output tensor attn_output = torch.empty_like(query) + + # logger.info(f"query {query.shape}") + # logger.info(f"query piece {query[:10, ..., :8]}") # Prefill if cu_seqlen_prefill is not None: # flash attention @@ -297,6 +325,9 @@ class FlashLlamaAttention(torch.nn.Module): input_lengths, max_s, ) + + # logger.info(f"attn_output {attn_output.shape}") + # logger.info(f"attn_output piece {attn_output[:10, ..., :8]}") return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 2d731406..cf0ddb00 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple # Flash attention imports -import dropout_layer_norm +# import dropout_layer_norm from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2 @@ -110,45 +110,45 @@ class MistralRMSNorm(nn.Module): self.variance_epsilon = eps def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states + # if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) - return self.weight * hidden_states, residual - else: - # faster post attention rms norm - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states + return self.weight * hidden_states, residual + # else: + # # faster post attention rms norm + # normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + # hidden_states, + # residual, + # self.weight, + # None, + # None, + # None, + # None, + # None, + # 0.0, + # self.variance_epsilon, + # 1.0, + # 0, + # None, + # False, + # True, # Activate RMSNorm + # ) + # if res is None: + # res = hidden_states - return normed_hidden_states, res + # return normed_hidden_states, res def load_attention(config, prefix, weights): diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py 
b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index 1ffe6276..4a082f95 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -55,7 +55,7 @@ from text_generation_server.utils.layers import ( PositionRotaryEmbedding, FastLinear, ) -import dropout_layer_norm +# import dropout_layer_norm @dataclass @@ -354,54 +354,54 @@ class IdeficsRMSNorm(nn.Module): self.variance_epsilon = eps def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states + # if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) - return self.weight * hidden_states - else: - # faster post attention rms norm - unwrap = False - if len(hidden_states.shape) > 2: - unwrap = True - shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, shape[-1]) + return self.weight * hidden_states + # else: + # # faster post attention rms norm + # unwrap = False + # if len(hidden_states.shape) > 2: + # unwrap = True + # shape = hidden_states.shape + # hidden_states = hidden_states.reshape(-1, shape[-1]) - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states + # normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + # hidden_states, + # residual, + # self.weight, + # None, + # None, + # None, + # None, + # None, + # 0.0, + # self.variance_epsilon, + # 1.0, + # 0, + # None, + # False, + # True, # Activate RMSNorm + # ) + # if res is None: + # res = hidden_states - if unwrap: - normed_hidden_states = normed_hidden_states.view(*shape) + # if unwrap: + # normed_hidden_states = normed_hidden_states.view(*shape) - return normed_hidden_states + # return normed_hidden_states # this was adapted from LlamaMLP diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index 8f0fcee6..a24b3ae0 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -3,6 +3,8 @@ import torch from loguru import logger +from .import_utils import is_cuda_system, is_rocm_system + if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") @@ -41,10 +43,17 @@ except ImportError as e: "or install flash attention with `cd server && make install install-flash-attention`" ) from e - if not (is_sm75 or is_sm8x or is_sm90): + if is_cuda_system() and 
not (is_sm75 or is_sm8x or is_sm90): raise ImportError( f"GPU with CUDA capability {major} {minor} is not supported" ) from e + elif is_rocm_system(): + for idx in range(torch.cuda.device_count()): + if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx): + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + logger.warning(f"Unable to use Flash Attention V2: {e}") HAS_FLASH_ATTN = True @@ -59,6 +68,7 @@ def attention( softmax_scale, window_size_left=-1, ): + # logger.info(f"HAS_FLASH_ATTN_V2 {HAS_FLASH_ATTN_V2}") if HAS_FLASH_ATTN_V2: return flash_attn_2_cuda.varlen_fwd( q, @@ -78,7 +88,8 @@ def attention( False, None, ) - + + # logger.info(f"HAS_FLASH_ATTN {HAS_FLASH_ATTN}") if HAS_FLASH_ATTN: if window_size_left != -1: raise NotImplementedError( @@ -124,7 +135,8 @@ def attention( softmax_scale, False, True, - False, + False, # is_deterministic => rocm specific argument + False, # return_softmax 0, None, ) diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py new file mode 100644 index 00000000..98b8daba --- /dev/null +++ b/server/text_generation_server/utils/import_utils.py @@ -0,0 +1,15 @@ +import subprocess + +def is_cuda_system(): + try: + subprocess.check_output("nvidia-smi") + return True + except Exception: + return False + +def is_rocm_system(): + try: + subprocess.check_output("rocm-smi") + return True + except Exception: + return False \ No newline at end of file diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 7bb95dd2..c6a099da 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -509,50 +509,50 @@ class TensorParallelEmbedding(nn.Module): try: - import dropout_layer_norm + # import dropout_layer_norm class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states + # if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states - return super(FastLayerNorm, self).forward(hidden_states), residual - else: - ( - normed_hidden_states, - residual, - *rest, - ) = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - self.bias, - None, - None, - None, - None, - 0.0, - self.eps, - 1.0, - 0, - None, - False, - False, - ) - if residual is None: - residual = hidden_states + return super(FastLayerNorm, self).forward(hidden_states), residual + # else: + # ( + # normed_hidden_states, + # residual, + # *rest, + # ) = dropout_layer_norm.dropout_add_ln_fwd( + # hidden_states, + # residual, + # self.weight, + # self.bias, + # None, + # None, + # None, + # None, + # 0.0, + # self.eps, + # 1.0, + # 0, + # None, + # False, + # False, + # ) + # if residual is None: + # residual = hidden_states - return normed_hidden_states, residual + # return normed_hidden_states, residual except ImportError: pass try: - from flash_attn.layers.rotary import RotaryEmbedding - import rotary_emb + # from flash_attn.layers.rotary import RotaryEmbedding + # import rotary_emb def _create_inv_freq(dim, base, device): inv_freq = 1.0 / ( @@ -692,11 +692,19 @@ try: def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): rotary_dim = cos.shape[-1] - x1 = x[..., :rotary_dim] - x2 = x[..., rotary_dim : 2 
* rotary_dim] - rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False) - return x + dtype = x.dtype + x_upcast = x.to(torch.float32) + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + x1 = x_upcast[..., :rotary_dim] + x2 = x_upcast[..., rotary_dim : 2 * rotary_dim] + + # rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False) + # Flash Attention kernel casts everything to float, not sure why. In place op here + x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype) + x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype) class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py index 57a59599..2b8e3509 100644 --- a/server/text_generation_server/utils/paged_attention.py +++ b/server/text_generation_server/utils/paged_attention.py @@ -4,6 +4,8 @@ import torch from vllm import cache_ops from vllm import attention_ops +from loguru import logger + _PARTITION_SIZE = 512 @@ -54,6 +56,7 @@ def attention( # sequences or heads is large, we use V1 since there is enough work # to parallelize. use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 + logger.info(f"paged attention use_v1 {use_v1}") if use_v1: attention_ops.paged_attention_v1( out,
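
Note on the RMSNorm hunks above: the patch swaps the fused `dropout_layer_norm` CUDA extension for either vLLM's `layernorm_ops.rms_norm` kernel (llama) or a plain eager path (mistral, idefics, `FastLayerNorm`). Below is a minimal, self-contained sketch of that eager residual-add + RMSNorm fallback; the `rms_norm_eager` name, shapes, and dtypes are illustrative only and not part of the patch.

```python
# Minimal sketch (illustrative only) of the eager residual-add + RMSNorm
# fallback used on ROCm, where dropout_layer_norm is unavailable.
from typing import Optional, Tuple

import torch


def rms_norm_eager(
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    residual: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Fold the residual in before normalization, as the patched forward() does.
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states

    # Compute the norm in float32 for stability, then cast back to the weight dtype.
    hs = hidden_states.to(torch.float32)
    variance = hs.pow(2).mean(-1, keepdim=True)
    hs = hs * torch.rsqrt(variance + eps)
    if weight.dtype in (torch.float16, torch.bfloat16):
        hs = hs.to(weight.dtype)
    return weight * hs, residual


if __name__ == "__main__":
    x = torch.randn(4, 4096, dtype=torch.float16)
    res = torch.randn(4, 4096, dtype=torch.float16)
    w = torch.ones(4096, dtype=torch.float16)
    out, new_res = rms_norm_eager(x, w, 1e-6, res)
    print(out.shape, out.dtype)  # torch.Size([4, 4096]) torch.float16
```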
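
Note on platform gating: `flash_attn.py` now restricts Flash Attention per platform via the new `import_utils` helpers, which shell out to `nvidia-smi` / `rocm-smi`. The sketch below mirrors that check and adds, purely as an assumed alternative not used by this patch, a probe of `torch.version.hip`, which is a version string on ROCm builds of PyTorch and `None` otherwise.

```python
# Sketch of platform detection. is_rocm_system() mirrors the new helper in
# server/text_generation_server/utils/import_utils.py; is_rocm_torch() is an
# assumed alternative based on torch.version.hip, not part of this patch.
import subprocess

import torch


def is_rocm_system() -> bool:
    # Same approach as the patch: probe for the rocm-smi binary.
    try:
        subprocess.check_output("rocm-smi")
        return True
    except Exception:
        return False


def is_rocm_torch() -> bool:
    # torch.version.hip is set on ROCm builds of PyTorch and None on CUDA builds.
    return getattr(torch.version, "hip", None) is not None


if __name__ == "__main__":
    print("rocm-smi found:", is_rocm_system())
    print("ROCm PyTorch build:", is_rocm_torch())
```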