working rocm build

Felix Marty 2023-10-30 10:42:34 +00:00
parent f9910d13e2
commit 52bdcf797d
9 changed files with 349 additions and 143 deletions

Dockerfile_amd (new file, 135 lines)

@@ -0,0 +1,135 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef as planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
ARG GIT_SHA
ARG DOCKER_LABEL
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release
# Text Generation Inference base image
FROM rocm/dev-ubuntu-20.04:5.7 as base
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
ccache \
curl \
git \
make \
libssl-dev \
g++ \
wget \
# Needed to build VLLM.
rocthrust-dev \
hipsparse-dev \
hipblas-dev && \
rm -rf /var/lib/apt/lists/*
# TGI seems to require libssl.so.1.1 instead of libssl.so.3, so we can't use Ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for Miniconda.
RUN wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir .conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PATH="/root/user/miniconda3/bin:${PATH}"
RUN conda init bash
ARG PYTORCH_VERSION='2.2.0.dev0'
ARG ROCM_VERSION='5.7'
ARG PYTHON_VERSION='3.11.5'
RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7
RUN pip install -U ninja
WORKDIR /usr/src
# Install VLLM.
RUN git clone https://github.com/fxmarty/vllm-public.git && cd vllm-public && git checkout --track origin/port-to-rocm
WORKDIR /usr/src/vllm-public
RUN pip install -r requirements.txt
RUN python setup.py install
# Install Flash Attention v1.
WORKDIR /usr/src
RUN git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git && cd flash-attention && git submodule init && git submodule update && python setup.py install
# Not working for RoCm
# RUN cd flash-attention/csrc/rotary && python setup.py build && cd flash-attention/csrc/layer_norm && python setup.py build
# COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention
# RUN make build-flash-attention
# Build Transformers CUDA kernels
# NOTE: gpt-neox and bloom fused kernels
# FROM kernel-builder as custom-kernels-builder
# WORKDIR /usr/src
# COPY server/custom_kernels/ .
# Build specific version of transformers
# RUN python setup.py build
# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
# Copy build artifacts from flash attention builder
# COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from custom kernels builder
# COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && pip3 install -r requirements.txt
RUN cd server && \
make gen-server && \
pip3 install ".[accelerate]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
# ENTRYPOINT ["text-generation-launcher"]
# CMD ["--json-output"]
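
The entrypoint is left commented out, so the container is started with an explicit command. Before launching TGI it can help to confirm that the nightly wheel installed above is really a ROCm build of PyTorch and that the GPUs are visible; a minimal sanity check (my sketch, not part of the commit):

# sanity_check_rocm.py -- illustrative only, not part of this commit
import torch

print(torch.__version__)              # nightly build, e.g. 2.2.0.dev...
print(torch.version.hip)              # non-None only on a ROCm build of PyTorch
print(torch.cuda.is_available())      # HIP devices are exposed through the torch.cuda API
for idx in range(torch.cuda.device_count()):
    print(idx, torch.cuda.get_device_name(idx))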

server/requirements.txt

@@ -4,7 +4,8 @@ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
 async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
 attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
 backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13"
+# bitsandbytes is broken on RoCm systems
+# bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -62,7 +63,8 @@ six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
 texttable==1.6.7 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
-torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
+# We use nightly
+torch>2.1.0 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
 transformers==4.33.2 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -26,8 +26,10 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
+from loguru import logger
+
 # Flash attention imports
-import dropout_layer_norm
+# import dropout_layer_norm
 
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
@@ -39,6 +41,9 @@ from text_generation_server.utils.layers import (
     get_linear,
 )
 
+from vllm import layernorm_ops
+
+torch.set_printoptions(threshold=10000000, sci_mode=True)
 
 class LlamaConfig(PretrainedConfig):
     def __init__(
@@ -121,28 +126,43 @@ class LlamaRMSNorm(nn.Module):
             return self.weight * hidden_states, residual
         else:
-            # faster post attention rms norm
-            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                None,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.variance_epsilon,
-                1.0,
-                0,
-                None,
-                False,
-                True, # Activate RMSNorm
-            )
-            if res is None:
-                res = hidden_states
-
-            return normed_hidden_states, res
+            # We use VLLM kernels that are compiled for RoCm instead of Flash Attention ones that can't be used.
+            if residual is not None:
+                hidden_states += residual
+                residual = hidden_states
+
+            out = torch.empty_like(hidden_states)
+            layernorm_ops.rms_norm(
+                out,
+                hidden_states,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return out, residual
+
+            # else:
+            #     # faster post attention rms norm
+            #     normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
+            #         hidden_states,
+            #         residual,
+            #         self.weight,
+            #         None,
+            #         None,
+            #         None,
+            #         None,
+            #         None,
+            #         0.0,
+            #         self.variance_epsilon,
+            #         1.0,
+            #         0,
+            #         None,
+            #         False,
+            #         True, # Activate RMSNorm
+            #     )
+            #     if res is None:
+            #         res = hidden_states
+
+            #     return normed_hidden_states, res
 
 
 def load_attention(config, prefix, weights):
@@ -262,6 +282,11 @@ class FlashLlamaAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
+        # logger.info(f"query before rotary {query[:10, ..., :8]}")
+        # logger.info(f"cos before rotary {cos[:10]}")
+        # logger.info(f"sin before rotary {sin[:10]}")
+        # TODO: maybe we can use VLLM rotary here, which would require position_ids? Probably too big of a change...
+        # Flash Attention kernel may be usable since it is Triton-based
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
@@ -272,6 +297,9 @@ class FlashLlamaAttention(torch.nn.Module):
         # output tensor
         attn_output = torch.empty_like(query)
 
+        # logger.info(f"query {query.shape}")
+        # logger.info(f"query piece {query[:10, ..., :8]}")
+
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
@@ -297,6 +325,9 @@ class FlashLlamaAttention(torch.nn.Module):
                 input_lengths,
                 max_s,
             )
 
+        # logger.info(f"attn_output {attn_output.shape}")
+        # logger.info(f"attn_output piece {attn_output[:10, ..., :8]}")
+
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
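
The LlamaRMSNorm change above swaps the Flash Attention dropout_add_ln_fwd kernel for vLLM's layernorm_ops.rms_norm, keeping the residual addition in plain PyTorch. As a minimal sketch (assuming only the out/input/weight/eps call shape visible in the diff), the kernel can be checked against the pure-PyTorch RMSNorm that the shape[-1] > 8192 branch already implements:

import torch

def rms_norm_ref(hidden_states, weight, eps):
    # Pure-PyTorch RMSNorm, mirroring the existing large-hidden-size branch.
    hs = hidden_states.to(torch.float32)
    variance = hs.pow(2).mean(-1, keepdim=True)
    hs = hs * torch.rsqrt(variance + eps)
    return weight * hs.to(weight.dtype)

hidden = torch.randn(4, 4096, dtype=torch.float16)
weight = torch.randn(4096, dtype=torch.float16)
expected = rms_norm_ref(hidden, weight, eps=1e-5)

# On a ROCm machine with the patched vLLM installed, the kernel writes into a
# preallocated buffer (as in the diff) and should agree with the reference:
# from vllm import layernorm_ops
# out = torch.empty_like(hidden)
# layernorm_ops.rms_norm(out, hidden, weight, 1e-5)
# torch.testing.assert_close(out, expected, rtol=1e-2, atol=1e-2)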

server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

@@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import dropout_layer_norm
+# import dropout_layer_norm
 
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
@@ -110,45 +110,45 @@ class MistralRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
+        # if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
                 residual = hidden_states
 
             hidden_states = hidden_states.to(torch.float32)
             variance = hidden_states.pow(2).mean(-1, keepdim=True)
             hidden_states = hidden_states * torch.rsqrt(
                 variance + self.variance_epsilon
             )
 
             # convert into half-precision if necessary
             if self.weight.dtype in [torch.float16, torch.bfloat16]:
                 hidden_states = hidden_states.to(self.weight.dtype)
 
             return self.weight * hidden_states, residual
-        else:
-            # faster post attention rms norm
-            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                None,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.variance_epsilon,
-                1.0,
-                0,
-                None,
-                False,
-                True, # Activate RMSNorm
-            )
-            if res is None:
-                res = hidden_states
-
-            return normed_hidden_states, res
+        # else:
+        #     # faster post attention rms norm
+        #     normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
+        #         hidden_states,
+        #         residual,
+        #         self.weight,
+        #         None,
+        #         None,
+        #         None,
+        #         None,
+        #         None,
+        #         0.0,
+        #         self.variance_epsilon,
+        #         1.0,
+        #         0,
+        #         None,
+        #         False,
+        #         True, # Activate RMSNorm
+        #     )
+        #     if res is None:
+        #         res = hidden_states
+
+        #     return normed_hidden_states, res
 
 
 def load_attention(config, prefix, weights):

server/text_generation_server/models/custom_modeling/idefics_modeling.py

@@ -55,7 +55,7 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     FastLinear,
 )
-import dropout_layer_norm
+# import dropout_layer_norm
 
 
 @dataclass
@@ -354,54 +354,54 @@ class IdeficsRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
+        # if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
                 residual = hidden_states
 
             hidden_states = hidden_states.to(torch.float32)
             variance = hidden_states.pow(2).mean(-1, keepdim=True)
             hidden_states = hidden_states * torch.rsqrt(
                 variance + self.variance_epsilon
             )
 
             # convert into half-precision if necessary
             if self.weight.dtype in [torch.float16, torch.bfloat16]:
                 hidden_states = hidden_states.to(self.weight.dtype)
 
             return self.weight * hidden_states
-        else:
-            # faster post attention rms norm
-            unwrap = False
-            if len(hidden_states.shape) > 2:
-                unwrap = True
-                shape = hidden_states.shape
-                hidden_states = hidden_states.reshape(-1, shape[-1])
-
-            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                None,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.variance_epsilon,
-                1.0,
-                0,
-                None,
-                False,
-                True, # Activate RMSNorm
-            )
-            if res is None:
-                res = hidden_states
-
-            if unwrap:
-                normed_hidden_states = normed_hidden_states.view(*shape)
-
-            return normed_hidden_states
+        # else:
+        #     # faster post attention rms norm
+        #     unwrap = False
+        #     if len(hidden_states.shape) > 2:
+        #         unwrap = True
+        #         shape = hidden_states.shape
+        #         hidden_states = hidden_states.reshape(-1, shape[-1])
+
+        #     normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
+        #         hidden_states,
+        #         residual,
+        #         self.weight,
+        #         None,
+        #         None,
+        #         None,
+        #         None,
+        #         None,
+        #         0.0,
+        #         self.variance_epsilon,
+        #         1.0,
+        #         0,
+        #         None,
+        #         False,
+        #         True, # Activate RMSNorm
+        #     )
+        #     if res is None:
+        #         res = hidden_states
+
+        #     if unwrap:
+        #         normed_hidden_states = normed_hidden_states.view(*shape)
+
+        #     return normed_hidden_states
 
 
 # this was adapted from LlamaMLP

server/text_generation_server/utils/flash_attn.py

@@ -3,6 +3,8 @@ import torch
 
 from loguru import logger
 
+from .import_utils import is_cuda_system, is_rocm_system
+
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
@@ -41,10 +43,17 @@
            "or install flash attention with `cd server && make install install-flash-attention`"
        ) from e
 
-    if not (is_sm75 or is_sm8x or is_sm90):
+    if is_cuda_system() and not (is_sm75 or is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported"
         ) from e
+    elif is_rocm_system():
+        for idx in range(torch.cuda.device_count()):
+            if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx):
+                raise ImportError(
+                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                )
+
     logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True
@@ -59,6 +68,7 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
+    # logger.info(f"HAS_FLASH_ATTN_V2 {HAS_FLASH_ATTN_V2}")
     if HAS_FLASH_ATTN_V2:
         return flash_attn_2_cuda.varlen_fwd(
             q,
@@ -78,7 +88,8 @@ def attention(
             False,
             None,
         )
 
+    # logger.info(f"HAS_FLASH_ATTN {HAS_FLASH_ATTN}")
     if HAS_FLASH_ATTN:
         if window_size_left != -1:
             raise NotImplementedError(
@@ -124,7 +135,8 @@ def attention(
             softmax_scale,
             False,
             True,
-            False,
+            False, # is_deterministic => rocm specific argument
+            False, # return_softmax
             0,
             None,
         )
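
The gate above only accepts MI210/MI250 device names on ROCm systems. A standalone restatement of that check, with a hypothetical helper name and using only the torch.cuda queries already present in the diff:

import torch

SUPPORTED_ROCM_GPUS = ("MI210", "MI250")

def rocm_gpus_support_flash_attention():
    # Mirrors the loop in the diff: every visible device must be an MI210 or MI250.
    for idx in range(torch.cuda.device_count()):
        name = torch.cuda.get_device_name(idx)
        if not any(gpu in name for gpu in SUPPORTED_ROCM_GPUS):
            return False
    return True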

server/text_generation_server/utils/import_utils.py (new file)

@@ -0,0 +1,15 @@
import subprocess

def is_cuda_system():
    try:
        subprocess.check_output("nvidia-smi")
        return True
    except Exception:
        return False

def is_rocm_system():
    try:
        subprocess.check_output("rocm-smi")
        return True
    except Exception:
        return False
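
These helpers detect the platform by shelling out to nvidia-smi / rocm-smi. A quick cross-check against the installed PyTorch wheel (my sketch, not part of the commit; torch.version.hip is only set on ROCm builds):

import torch
from text_generation_server.utils.import_utils import is_cuda_system, is_rocm_system

if is_rocm_system():
    assert torch.version.hip is not None, "rocm-smi found, but torch is not a ROCm build"
elif is_cuda_system():
    assert torch.version.cuda is not None, "nvidia-smi found, but torch is not a CUDA build"
else:
    print("Neither nvidia-smi nor rocm-smi found on PATH")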

server/text_generation_server/utils/layers.py

@@ -509,50 +509,50 @@ class TensorParallelEmbedding(nn.Module):
 try:
-    import dropout_layer_norm
+    # import dropout_layer_norm
 
     class FastLayerNorm(nn.LayerNorm):
         def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192:
+            # if hidden_states.shape[-1] > 8192:
                 if residual is not None:
                     hidden_states += residual
                     residual = hidden_states
 
                 return super(FastLayerNorm, self).forward(hidden_states), residual
-            else:
-                (
-                    normed_hidden_states,
-                    residual,
-                    *rest,
-                ) = dropout_layer_norm.dropout_add_ln_fwd(
-                    hidden_states,
-                    residual,
-                    self.weight,
-                    self.bias,
-                    None,
-                    None,
-                    None,
-                    None,
-                    0.0,
-                    self.eps,
-                    1.0,
-                    0,
-                    None,
-                    False,
-                    False,
-                )
-                if residual is None:
-                    residual = hidden_states
-
-                return normed_hidden_states, residual
+            # else:
+            #     (
+            #         normed_hidden_states,
+            #         residual,
+            #         *rest,
+            #     ) = dropout_layer_norm.dropout_add_ln_fwd(
+            #         hidden_states,
+            #         residual,
+            #         self.weight,
+            #         self.bias,
+            #         None,
+            #         None,
+            #         None,
+            #         None,
+            #         0.0,
+            #         self.eps,
+            #         1.0,
+            #         0,
+            #         None,
+            #         False,
+            #         False,
+            #     )
+            #     if residual is None:
+            #         residual = hidden_states
+
+            #     return normed_hidden_states, residual
 
 except ImportError:
     pass
 
 
 try:
-    from flash_attn.layers.rotary import RotaryEmbedding
-    import rotary_emb
+    # from flash_attn.layers.rotary import RotaryEmbedding
+    # import rotary_emb
 
     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
@@ -692,11 +692,19 @@ try:
         def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
             rotary_dim = cos.shape[-1]
-            x1 = x[..., :rotary_dim]
-            x2 = x[..., rotary_dim : 2 * rotary_dim]
 
-            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
-            return x
+            dtype = x.dtype
+            x_upcast = x.to(torch.float32)
+            cos = cos.to(torch.float32)
+            sin = sin.to(torch.float32)
+
+            x1 = x_upcast[..., :rotary_dim]
+            x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
+
+            # rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
+            # Flash Attention kernel casts everything to float, not sure why. In place op here
+            x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
+            x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
 
     class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
         def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
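
Because the rotary update above now runs in plain PyTorch (upcast to float32, rotate, cast back in place), it can be property-tested on CPU: a rotation must preserve the magnitude of each (x1, x2) pair. A minimal sketch with illustrative shapes, not part of the commit:

import math
import torch

def apply_rotary_inplace(x, cos, sin, rotary_dim):
    # Same upcast-and-rotate-in-place scheme as the diff above.
    dtype = x.dtype
    x_upcast = x.to(torch.float32)   # a real copy when x is half precision
    x1 = x_upcast[..., :rotary_dim]
    x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
    x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
    x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)

tokens, heads, head_size, rotary_dim = 4, 8, 128, 32
x = torch.randn(tokens, heads, head_size, dtype=torch.float16)
theta = torch.rand(tokens, 1, rotary_dim) * 2 * math.pi
cos, sin = torch.cos(theta), torch.sin(theta)

# A rotation preserves the magnitude of each (x1, x2) pair.
before = x[..., :rotary_dim].float() ** 2 + x[..., rotary_dim : 2 * rotary_dim].float() ** 2
apply_rotary_inplace(x, cos, sin, rotary_dim)
after = x[..., :rotary_dim].float() ** 2 + x[..., rotary_dim : 2 * rotary_dim].float() ** 2
torch.testing.assert_close(before, after, rtol=1e-2, atol=1e-2)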

server/text_generation_server/utils/paged_attention.py

@@ -4,6 +4,8 @@ import torch
 from vllm import cache_ops
 from vllm import attention_ops
 
+from loguru import logger
+
 _PARTITION_SIZE = 512
 
@@ -54,6 +56,7 @@ def attention(
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
     use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+    logger.info(f"paged attention use_v1 {use_v1}")
     if use_v1:
         attention_ops.paged_attention_v1(
             out,