addressed review comments

This commit is contained in:
Mohit Sharma 2024-09-27 10:28:37 +00:00
parent 64e981fdcf
commit 829144d15a
23 changed files with 101 additions and 223 deletions

View File

@ -67,14 +67,11 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
hipsolver-dev \ hipsolver-dev \
rccl-dev \ rccl-dev \
cmake \ cmake \
python3.11-dev \
python3.11-venv && \ python3.11-venv && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
# Keep in sync with `server/pyproject.toml # Keep in sync with `server/pyproject.toml
ARG MAMBA_VERSION=23.1.0-1 ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.3.0'
ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.11.10' ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx # Automatically set by buildx
ARG TARGETPLATFORM ARG TARGETPLATFORM
@ -82,11 +79,6 @@ ENV PATH=/opt/conda/bin:$PATH
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.30.2/cmake-3.30.2-linux-x86_64.sh \
&& chmod +x cmake-3.30.2-linux-x86_64.sh \
&& ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \
&& rm cmake-3.30.2-linux-x86_64.sh
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba # Install mamba
# translating Docker's TARGETPLATFORM into mamba arches # translating Docker's TARGETPLATFORM into mamba arches
@ -111,7 +103,7 @@ RUN case ${TARGETPLATFORM} in \
/opt/conda/bin/conda clean -ya /opt/conda/bin/conda clean -ya
# Install flash-attention, torch dependencies # Install flash-attention, torch dependencies
RUN pip install numpy einops ninja joblib msgpack --no-cache-dir RUN pip install numpy einops ninja joblib msgpack cmake --no-cache-dir
# Install HIPBLASLt # Install HIPBLASLt
ARG HIPBLASLT_BRANCH="6f65c6e" ARG HIPBLASLT_BRANCH="6f65c6e"
@ -129,7 +121,8 @@ RUN dpkg -i hipBLASLt/build/release/*.deb \
RUN pip uninstall -y triton && \ RUN pip uninstall -y triton && \
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
cd triton/python && \ cd triton/python && \
pip install . pip install . && \
rm -r triton
ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27" ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27"
RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \ RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \
@ -153,6 +146,7 @@ ARG BUILD_CAFFE2="0" \
USE_MEM_EFF_ATTENTION="0" USE_MEM_EFF_ATTENTION="0"
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
RUN rm -rf pytorch
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm # Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1 ENV HIP_FORCE_DEV_KERNARG=1

View File

@ -13,9 +13,19 @@ if SYSTEM == "cuda":
SUPPORTS_WINDOWING, SUPPORTS_WINDOWING,
) )
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING from .rocm import (
attention,
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
)
elif SYSTEM == "ipex": elif SYSTEM == "ipex":
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING from .ipex import (
attention,
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
)
else: else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

View File

@ -351,3 +351,9 @@ else:
None, None,
) )
return out return out
# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2

View File

@ -5,6 +5,7 @@ from text_generation_server.layers.attention import Seqlen
from typing import Optional from typing import Optional
SUPPORTS_WINDOWING = False SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False
def attention( def attention(

View File

@ -16,9 +16,18 @@ _PARTITION_SIZE_CUSTOM = 256
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck" ENGINE = "triton" if use_triton else "ck"
custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" PREFILL_IN_KV_CACHE = False
if custom_attn_available:
from vllm._custom_C import paged_attention_custom use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
if use_rocm_custom_paged_attn:
from vllm._custom_C import paged_attention_custom
except ImportError as e:
log_master(
logger.info,
f"Custom Paged Attention not available. Complete error: {e}",
)
use_rocm_custom_paged_attn = False
try: try:
import vllm._custom_ops as ops import vllm._custom_ops as ops
@ -71,6 +80,9 @@ def paged_attention(
# limitations under the License. # limitations under the License.
# #
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
# value_cache => [num_blocks, num_heads, head_size, block_size] # value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape num_seqs, num_heads, head_size = query.shape
@ -78,7 +90,7 @@ def paged_attention(
num_kv_heads = key_cache.shape[1] num_kv_heads = key_cache.shape[1]
gqa_ratio = num_heads // num_kv_heads gqa_ratio = num_heads // num_kv_heads
use_custom = ( use_custom = (
custom_attn_available use_rocm_custom_paged_attn
and (query.dtype == torch.half or query.dtype == torch.bfloat16) and (query.dtype == torch.half or query.dtype == torch.bfloat16)
and (head_size == 128 or head_size == 64) and (head_size == 128 or head_size == 64)
and (block_size == 16 or block_size == 32) and (block_size == 16 or block_size == 32)
@ -224,10 +236,10 @@ if ENGINE == "ck":
value_cache: torch.Tensor, value_cache: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
block_tables: torch.Tensor, block_tables: torch.Tensor,
softmax_scale, softmax_scale: float,
window_size_left=-1, window_size_left: int = -1,
causal=True, causal: bool = True,
softcap=0.0, softcap: float = 0.0,
): ):
if window_size_left <= 0 and window_size_left != -1: if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1") raise ValueError("`window_size_left` must be > 0 or -1")
@ -268,11 +280,14 @@ elif ENGINE == "triton":
value_cache: torch.Tensor, value_cache: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
block_tables: torch.Tensor, block_tables: torch.Tensor,
softmax_scale, softmax_scale: float,
window_size_left=-1, window_size_left: int = -1,
causal=True, causal: bool = True,
softcap=0.0, softcap: float = 0.0,
): ):
if softcap is not None:
raise NotImplementedError("softcap is only available with CK flash attn")
out = torch.empty_like(q) out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional, Tuple, Dict, Any from typing import Tuple
import torch import torch
import torch.distributed import torch.distributed
@ -50,144 +50,3 @@ def grouped_topk(
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids return topk_weights, topk_ids
def get_default_config(
M: int,
E: int,
N: int,
K: int,
topk: int,
dtype: Optional[str],
) -> Dict[str, int]:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
if M <= E:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
return config
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
):
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_moe_configs,
invoke_fused_moe_kernel,
moe_align_block_size,
)
M, _ = hidden_states.shape
E, N, _ = w1.shape
if override_config:
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(
M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config["BLOCK_SIZE_M"], E
)
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
if inplace:
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)

View File

@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -298,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else key, kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PAGED_KV else value, kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -337,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else kv[:, 0], kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1], kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -15,7 +15,7 @@
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "rocm": if SYSTEM == "rocm":
@ -333,8 +333,8 @@ class DeepseekV2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else key, kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PAGED_KV else value, kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else kv[:, 0], kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1], kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else kv[:, 0], kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1], kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else key, kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PAGED_KV else value, kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -193,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else key, kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PAGED_KV else value, kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -28,7 +28,7 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
@ -221,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else kv[:, 0], kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1], kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -219,8 +219,8 @@ class MistralAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -274,8 +274,8 @@ class MixtralAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
qkv[:, 0], qkv[:, 0],
kv_cache[0] if PAGED_KV else qkv[:, 1], kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
kv_cache[1] if PAGED_KV else qkv[:, 2], kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -1,4 +1,4 @@
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else kv[:, 0], kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1], kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -1,4 +1,4 @@
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -1,6 +1,6 @@
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
from torch import nn from torch import nn
@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else kv[:, 0], kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PAGED_KV else kv[:, 1], kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -325,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(), kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(), kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -1,4 +1,4 @@
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else key_value[:, 0], kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
kv_cache[1] if PAGED_KV else key_value[:, 1], kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch import torch
import torch.distributed import torch.distributed
@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,

View File

@ -4,7 +4,6 @@ from loguru import logger
from typing import Dict, Optional from typing import Dict, Optional
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
from text_generation_server.utils.import_utils import SYSTEM
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
@ -53,12 +52,6 @@ CUDA_GRAPHS = cuda_graphs
# index in all cases. # index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
PAGED_KV: bool
if SYSTEM in {"rocm", "ipex"}:
PAGED_KV = False
else:
PAGED_KV = True
def set_adapter_to_index(adapter_to_index: Dict[str, int]): def set_adapter_to_index(adapter_to_index: Dict[str, int]):
global ADAPTER_TO_INDEX global ADAPTER_TO_INDEX