Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Commit 52bdcf797d (parent f9910d13e2): working rocm build
Dockerfile_amd (new file, 135 lines)
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef as planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release

# Text Generation Inference base image
FROM rocm/dev-ubuntu-20.04:5.7 as base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    ca-certificates \
    ccache \
    curl \
    git \
    make \
    libssl-dev \
    g++ \
    wget \
    # Needed to build VLLM.
    rocthrust-dev \
    hipsparse-dev \
    hipblas-dev && \
    rm -rf /var/lib/apt/lists/*

# TGI seems to require libssl.so.1.1 instead of libssl.so.3, so we can't use Ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for Miniconda.

RUN wget \
    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir .conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh

ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PATH="/root/user/miniconda3/bin:${PATH}"
RUN conda init bash

ARG PYTORCH_VERSION='2.2.0.dev0'
ARG ROCM_VERSION='5.7'
ARG PYTHON_VERSION='3.11.5'

RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7
RUN pip install -U ninja

WORKDIR /usr/src

# Install VLLM.
RUN git clone https://github.com/fxmarty/vllm-public.git && cd vllm-public && git checkout --track origin/port-to-rocm
WORKDIR /usr/src/vllm-public
RUN pip install -r requirements.txt
RUN python setup.py install

# Install Flash Attention v1.
WORKDIR /usr/src
RUN git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git && cd flash-attention && git submodule init && git submodule update && python setup.py install

# Not working for RoCm
# RUN cd flash-attention/csrc/rotary && python setup.py build && cd flash-attention/csrc/layer_norm && python setup.py build

# COPY server/Makefile-flash-att Makefile

# Build specific version of flash attention
# RUN make build-flash-attention

# Build Transformers CUDA kernels
# NOTE: gpt-neox and bloom fused kernels

# FROM kernel-builder as custom-kernels-builder
# WORKDIR /usr/src
# COPY server/custom_kernels/ .
# Build specific version of transformers
# RUN python setup.py build

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

# Copy build artifacts from flash attention builder
# COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Copy build artifacts from custom kernels builder
# COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && pip3 install -r requirements.txt
RUN cd server && \
    make gen-server && \
    pip3 install ".[accelerate]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

# ENTRYPOINT ["text-generation-launcher"]
# CMD ["--json-output"]
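A quick way to sanity-check the resulting image is to confirm that the nightly ROCm build of PyTorch is the one actually installed. This is a hypothetical check, not part of the commit; run it with python inside a container started from the image (GPU visibility additionally requires passing the host's /dev/kfd and /dev/dri devices to the container):

import torch

print(torch.__version__)          # a 2.2.0.dev nightly built against ROCm 5.7 is expected
print(torch.version.hip)          # non-None only on ROCm builds of PyTorch
print(torch.cuda.is_available())  # True only when the container can see the AMD GPUs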
@@ -4,7 +4,8 @@ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
 async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
 attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
 backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13"
+# bitsandbytes is broken on RoCm systems
+# bitsandbytes==0.41.1 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -62,7 +63,8 @@ six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
 texttable==1.6.7 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
-torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13"
+# We use nightly
+torch>2.1.0 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
 transformers==4.33.2 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
@@ -26,8 +26,10 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
+from loguru import logger
+
 # Flash attention imports
-import dropout_layer_norm
+# import dropout_layer_norm
 
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
@@ -39,6 +41,9 @@ from text_generation_server.utils.layers import (
     get_linear,
 )
 
+from vllm import layernorm_ops
+
+torch.set_printoptions(threshold=10000000, sci_mode=True)
 
 class LlamaConfig(PretrainedConfig):
     def __init__(
@@ -121,28 +126,43 @@ class LlamaRMSNorm(nn.Module):
 
             return self.weight * hidden_states, residual
         else:
-            # faster post attention rms norm
-            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                None,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.variance_epsilon,
-                1.0,
-                0,
-                None,
-                False,
-                True,  # Activate RMSNorm
-            )
-            if res is None:
-                res = hidden_states
-
-            return normed_hidden_states, res
+            # We use VLLM kernels that are compiled for RoCm instead of Flash Attention ones that can't be used.
+            if residual is not None:
+                hidden_states += residual
+                residual = hidden_states
+
+            out = torch.empty_like(hidden_states)
+            layernorm_ops.rms_norm(
+                out,
+                hidden_states,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return out, residual
+
+            # else:
+            #     # faster post attention rms norm
+            #     normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
+            #         hidden_states,
+            #         residual,
+            #         self.weight,
+            #         None,
+            #         None,
+            #         None,
+            #         None,
+            #         None,
+            #         0.0,
+            #         self.variance_epsilon,
+            #         1.0,
+            #         0,
+            #         None,
+            #         False,
+            #         True,  # Activate RMSNorm
+            #     )
+            #     if res is None:
+            #         res = hidden_states
 
+            #     return normed_hidden_states, res
 
 def load_attention(config, prefix, weights):
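For reference, the vLLM rms_norm kernel used above is expected to compute a standard RMSNorm. A minimal pure-PyTorch sketch of that math (illustrative only, with assumed shapes; it mirrors the eager branch of the module rather than the vLLM kernel source):

import torch

def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then rescale by the learned weight.
    dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(dtype)

# Assumed shapes: [num_tokens, hidden_size].
out = rms_norm_reference(torch.randn(4, 4096), torch.ones(4096), eps=1e-6)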
@@ -262,6 +282,11 @@ class FlashLlamaAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
+        # logger.info(f"query before rotary {query[:10, ..., :8]}")
+        # logger.info(f"cos before rotary {cos[:10]}")
+        # logger.info(f"sin before rotary {sin[:10]}")
+        # TODO: maybe we can use VLLM rotary here, which would require position_ids? Probably too big of a change...
+        # Flash Attention kernel may be usable since it is Triton-based
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
@@ -272,6 +297,9 @@ class FlashLlamaAttention(torch.nn.Module):
         # output tensor
         attn_output = torch.empty_like(query)
 
+
+        # logger.info(f"query {query.shape}")
+        # logger.info(f"query piece {query[:10, ..., :8]}")
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
@@ -298,6 +326,9 @@ class FlashLlamaAttention(torch.nn.Module):
             max_s,
         )
 
+        # logger.info(f"attn_output {attn_output.shape}")
+        # logger.info(f"attn_output piece {attn_output[:10, ..., :8]}")
+
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
 
@@ -27,7 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import dropout_layer_norm
+# import dropout_layer_norm
 
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
@@ -110,7 +110,7 @@ class MistralRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
+        # if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
                 residual = hidden_states
@@ -126,29 +126,29 @@ class MistralRMSNorm(nn.Module):
             hidden_states = hidden_states.to(self.weight.dtype)
 
             return self.weight * hidden_states, residual
-        else:
-            # faster post attention rms norm
-            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                None,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.variance_epsilon,
-                1.0,
-                0,
-                None,
-                False,
-                True,  # Activate RMSNorm
-            )
-            if res is None:
-                res = hidden_states
+        # else:
+        #     # faster post attention rms norm
+        #     normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
+        #         hidden_states,
+        #         residual,
+        #         self.weight,
+        #         None,
+        #         None,
+        #         None,
+        #         None,
+        #         None,
+        #         0.0,
+        #         self.variance_epsilon,
+        #         1.0,
+        #         0,
+        #         None,
+        #         False,
+        #         True,  # Activate RMSNorm
+        #     )
+        #     if res is None:
+        #         res = hidden_states
 
-            return normed_hidden_states, res
+        #     return normed_hidden_states, res
 
 
 def load_attention(config, prefix, weights):
@@ -55,7 +55,7 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     FastLinear,
 )
-import dropout_layer_norm
+# import dropout_layer_norm
 
 
 @dataclass
@@ -354,7 +354,7 @@ class IdeficsRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
    def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
+        # if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
                 residual = hidden_states
@@ -370,38 +370,38 @@ class IdeficsRMSNorm(nn.Module):
             hidden_states = hidden_states.to(self.weight.dtype)
 
             return self.weight * hidden_states
-        else:
-            # faster post attention rms norm
-            unwrap = False
-            if len(hidden_states.shape) > 2:
-                unwrap = True
-                shape = hidden_states.shape
-                hidden_states = hidden_states.reshape(-1, shape[-1])
+        # else:
+        #     # faster post attention rms norm
+        #     unwrap = False
+        #     if len(hidden_states.shape) > 2:
+        #         unwrap = True
+        #         shape = hidden_states.shape
+        #         hidden_states = hidden_states.reshape(-1, shape[-1])
 
-            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                None,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.variance_epsilon,
-                1.0,
-                0,
-                None,
-                False,
-                True,  # Activate RMSNorm
-            )
-            if res is None:
-                res = hidden_states
+        #     normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
+        #         hidden_states,
+        #         residual,
+        #         self.weight,
+        #         None,
+        #         None,
+        #         None,
+        #         None,
+        #         None,
+        #         0.0,
+        #         self.variance_epsilon,
+        #         1.0,
+        #         0,
+        #         None,
+        #         False,
+        #         True,  # Activate RMSNorm
+        #     )
+        #     if res is None:
+        #         res = hidden_states
 
-            if unwrap:
-                normed_hidden_states = normed_hidden_states.view(*shape)
+        #     if unwrap:
+        #         normed_hidden_states = normed_hidden_states.view(*shape)
 
-            return normed_hidden_states
+        #     return normed_hidden_states
 
 
 # this was adapted from LlamaMLP
@@ -3,6 +3,8 @@ import torch
 
 from loguru import logger
 
+from .import_utils import is_cuda_system, is_rocm_system
+
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 
@@ -41,10 +43,17 @@ except ImportError as e:
         "or install flash attention with `cd server && make install install-flash-attention`"
     ) from e
 
-    if not (is_sm75 or is_sm8x or is_sm90):
+    if is_cuda_system() and not (is_sm75 or is_sm8x or is_sm90):
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported"
         ) from e
+    elif is_rocm_system():
+        for idx in range(torch.cuda.device_count()):
+            if "MI210" not in torch.cuda.get_device_name(idx) and "MI250" not in torch.cuda.get_device_name(idx):
+                raise ImportError(
+                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                )
+
     logger.warning(f"Unable to use Flash Attention V2: {e}")
     HAS_FLASH_ATTN = True
 
@@ -59,6 +68,7 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
+    # logger.info(f"HAS_FLASH_ATTN_V2 {HAS_FLASH_ATTN_V2}")
     if HAS_FLASH_ATTN_V2:
         return flash_attn_2_cuda.varlen_fwd(
             q,
@@ -79,6 +89,7 @@ def attention(
             None,
         )
 
+    # logger.info(f"HAS_FLASH_ATTN {HAS_FLASH_ATTN}")
     if HAS_FLASH_ATTN:
         if window_size_left != -1:
             raise NotImplementedError(
@@ -124,7 +135,8 @@ def attention(
             softmax_scale,
             False,
             True,
-            False,
+            False,  # is_deterministic => rocm specific argument
+            False,  # return_softmax
             0,
             None,
         )
server/text_generation_server/utils/import_utils.py (new file, 15 lines)

import subprocess

def is_cuda_system():
    try:
        subprocess.check_output("nvidia-smi")
        return True
    except Exception:
        return False

def is_rocm_system():
    try:
        subprocess.check_output("rocm-smi")
        return True
    except Exception:
        return False
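As a usage sketch (hypothetical snippet, not part of the commit), these helpers let device-specific branches be chosen at import time, in the same spirit as the flash_attn change above; on ROCm builds of PyTorch the torch.cuda namespace maps to HIP devices, so device-name checks work for AMD GPUs:

import torch
from text_generation_server.utils.import_utils import is_cuda_system, is_rocm_system

if is_rocm_system():
    names = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
    print("ROCm devices:", names)
elif is_cuda_system():
    print("CUDA capability:", torch.cuda.get_device_capability(0))
else:
    print("No supported GPU runtime detected")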
@@ -509,50 +509,50 @@ class TensorParallelEmbedding(nn.Module):
 
 
 try:
-    import dropout_layer_norm
+    # import dropout_layer_norm
 
     class FastLayerNorm(nn.LayerNorm):
         def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192:
+            # if hidden_states.shape[-1] > 8192:
                 if residual is not None:
                     hidden_states += residual
                     residual = hidden_states
 
                 return super(FastLayerNorm, self).forward(hidden_states), residual
-            else:
-                (
-                    normed_hidden_states,
-                    residual,
-                    *rest,
-                ) = dropout_layer_norm.dropout_add_ln_fwd(
-                    hidden_states,
-                    residual,
-                    self.weight,
-                    self.bias,
-                    None,
-                    None,
-                    None,
-                    None,
-                    0.0,
-                    self.eps,
-                    1.0,
-                    0,
-                    None,
-                    False,
-                    False,
-                )
-                if residual is None:
-                    residual = hidden_states
+            # else:
+            #     (
+            #         normed_hidden_states,
+            #         residual,
+            #         *rest,
+            #     ) = dropout_layer_norm.dropout_add_ln_fwd(
+            #         hidden_states,
+            #         residual,
+            #         self.weight,
+            #         self.bias,
+            #         None,
+            #         None,
+            #         None,
+            #         None,
+            #         0.0,
+            #         self.eps,
+            #         1.0,
+            #         0,
+            #         None,
+            #         False,
+            #         False,
+            #     )
+            #     if residual is None:
+            #         residual = hidden_states
 
-                return normed_hidden_states, residual
+            #     return normed_hidden_states, residual
 
 except ImportError:
     pass
 
 
 try:
-    from flash_attn.layers.rotary import RotaryEmbedding
-    import rotary_emb
+    # from flash_attn.layers.rotary import RotaryEmbedding
+    # import rotary_emb
 
     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
@@ -692,11 +692,19 @@ try:
 
         def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
             rotary_dim = cos.shape[-1]
-            x1 = x[..., :rotary_dim]
-            x2 = x[..., rotary_dim : 2 * rotary_dim]
 
-            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
-            return x
+            dtype = x.dtype
+            x_upcast = x.to(torch.float32)
+            cos = cos.to(torch.float32)
+            sin = sin.to(torch.float32)
+
+            x1 = x_upcast[..., :rotary_dim]
+            x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
+
+            # rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
+            # Flash Attention kernel casts everything to float, not sure why. In place op here
+            x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
+            x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
 
     class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
         def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
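The in-place replacement above applies the usual rotary formula (x1*cos - x2*sin and x1*sin + x2*cos) in float32 and writes the result back in the original dtype. An out-of-place reference with assumed shapes, useful for sanity-checking that behaviour (hypothetical, not part of the commit):

import torch

def apply_rotary_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_size]; cos/sin broadcast over the heads dimension.
    rotary_dim = cos.shape[-1]
    x1 = x[..., :rotary_dim].to(torch.float32)
    x2 = x[..., rotary_dim : 2 * rotary_dim].to(torch.float32)
    cos32, sin32 = cos.to(torch.float32), sin.to(torch.float32)
    out = x.clone()
    out[..., :rotary_dim] = (x1 * cos32 - x2 * sin32).to(x.dtype)
    out[..., rotary_dim : 2 * rotary_dim] = (x1 * sin32 + x2 * cos32).to(x.dtype)
    return out

# Assumed shapes: 3 tokens, 8 heads, head_size 64, rotary_dim 32.
y = apply_rotary_reference(torch.randn(3, 8, 64, dtype=torch.float16), torch.randn(3, 1, 32), torch.randn(3, 1, 32))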
@@ -4,6 +4,8 @@ import torch
 from vllm import cache_ops
 from vllm import attention_ops
 
+from loguru import logger
+
 _PARTITION_SIZE = 512
 
 
@@ -54,6 +56,7 @@ def attention(
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
     use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+    logger.info(f"paged attention use_v1 {use_v1}")
     if use_v1:
         attention_ops.paged_attention_v1(
             out,