Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-08 10:52:14 +00:00)

Commit 829144d15a: addressed review comments
Parent commit: 64e981fdcf
@@ -67,14 +67,11 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
    hipsolver-dev \
    rccl-dev \
    cmake \
    python3.11-dev \
    python3.11-venv && \
    rm -rf /var/lib/apt/lists/*

# Keep in sync with `server/pyproject.toml
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.3.0'
ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
@@ -82,11 +79,6 @@ ENV PATH=/opt/conda/bin:$PATH

ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"

RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.30.2/cmake-3.30.2-linux-x86_64.sh \
    && chmod +x cmake-3.30.2-linux-x86_64.sh \
    && ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \
    && rm cmake-3.30.2-linux-x86_64.sh

# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
@@ -111,7 +103,7 @@ RUN case ${TARGETPLATFORM} in \
    /opt/conda/bin/conda clean -ya

# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja joblib msgpack --no-cache-dir
RUN pip install numpy einops ninja joblib msgpack cmake --no-cache-dir

# Install HIPBLASLt
ARG HIPBLASLT_BRANCH="6f65c6e"
@@ -129,7 +121,8 @@ RUN dpkg -i hipBLASLt/build/release/*.deb \
RUN pip uninstall -y triton && \
    git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
    cd triton/python && \
    pip install .
    pip install . && \
    rm -r triton

ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27"
RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \
@@ -153,6 +146,7 @@ ARG BUILD_CAFFE2="0" \
    USE_MEM_EFF_ATTENTION="0"

RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
RUN rm -rf pytorch

# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1
@@ -13,9 +13,19 @@ if SYSTEM == "cuda":
        SUPPORTS_WINDOWING,
    )
elif SYSTEM == "rocm":
    from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
    from .rocm import (
        attention,
        paged_attention,
        reshape_and_cache,
        SUPPORTS_WINDOWING,
    )
elif SYSTEM == "ipex":
    from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
    from .ipex import (
        attention,
        paged_attention,
        reshape_and_cache,
        SUPPORTS_WINDOWING,
    )
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@@ -351,3 +351,9 @@ else:
            None,
        )
        return out


# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
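Note on the assignment just above: by Python operator precedence it parses as (ATTENTION != "paged") or V2, which is exactly what the comment describes; prefill writes through the KV cache unless paged attention is selected without flash-attention v2, whose v1 fallback has no block-table support. A minimal, hedged restatement with illustrative names:

def prefill_in_kv_cache(attention_impl: str, has_flash_v2: bool) -> bool:
    # True for every attention implementation except paged attention
    # running on the flash-attention v1 fallback.
    return attention_impl != "paged" or has_flash_v2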
@@ -5,6 +5,7 @@ from text_generation_server.layers.attention import Seqlen
from typing import Optional

SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False


def attention(
@@ -16,9 +16,18 @@ _PARTITION_SIZE_CUSTOM = 256
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"

custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
if custom_attn_available:
PREFILL_IN_KV_CACHE = False

use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
    if use_rocm_custom_paged_attn:
        from vllm._custom_C import paged_attention_custom
except ImportError as e:
    log_master(
        logger.info,
        f"Custom Paged Attention not available. Complete error: {e}",
    )
    use_rocm_custom_paged_attn = False

try:
    import vllm._custom_ops as ops
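The rewritten block above follows a common opt-in plus graceful-fallback pattern: an environment variable requests the custom ROCm paged-attention kernel, and a failed import at module load only logs the error and clears the flag instead of taking the server down. A self-contained sketch of the same pattern, assuming nothing about TGI internals (the logger and helper names below are illustrative, not TGI code):

import logging
import os

logger = logging.getLogger(__name__)

use_custom_kernel = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"


def load_custom_kernel():
    # Stand-in for `from vllm._custom_C import paged_attention_custom`.
    raise ImportError("custom kernel not built in this environment")


try:
    if use_custom_kernel:
        load_custom_kernel()
except ImportError as e:
    logger.info("Custom Paged Attention not available. Complete error: %s", e)
    use_custom_kernel = False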
@@ -71,6 +80,9 @@ def paged_attention(
    # limitations under the License.
    #

    if softcap is not None:
        raise RuntimeError("Paged attention doesn't support softcapping")

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
@@ -78,7 +90,7 @@
    num_kv_heads = key_cache.shape[1]
    gqa_ratio = num_heads // num_kv_heads
    use_custom = (
        custom_attn_available
        use_rocm_custom_paged_attn
        and (query.dtype == torch.half or query.dtype == torch.bfloat16)
        and (head_size == 128 or head_size == 64)
        and (block_size == 16 or block_size == 32)
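For orientation, a hedged sketch of the eligibility test this hunk renames: the custom paged-attention kernel is only taken when it is enabled and the query dtype, head size and block size fall inside the supported set. The real `use_custom` expression in rocm.py continues past the lines shown here, so the helper below is illustrative only:

import torch


def can_use_custom_paged_attn(
    enabled: bool,
    query: torch.Tensor,
    head_size: int,
    block_size: int,
) -> bool:
    # Mirrors only the terms of the `use_custom` expression visible above.
    return (
        enabled
        and query.dtype in (torch.half, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
    )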
@@ -224,10 +236,10 @@ if ENGINE == "ck":
        value_cache: torch.Tensor,
        seqlen: Seqlen,
        block_tables: torch.Tensor,
        softmax_scale,
        window_size_left=-1,
        causal=True,
        softcap=0.0,
        softmax_scale: float,
        window_size_left: int = -1,
        causal: bool = True,
        softcap: float = 0.0,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
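The typed signature above feeds straight into the guard on its last two lines: window_size_left is either -1 (no sliding window) or a strictly positive width, and anything else is rejected at call time. A hedged, standalone restatement (the helper name is illustrative):

def check_window_size_left(window_size_left: int) -> None:
    # -1 disables windowing; any other non-positive value is rejected.
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")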
@@ -268,11 +280,14 @@ elif ENGINE == "triton":
        value_cache: torch.Tensor,
        seqlen: Seqlen,
        block_tables: torch.Tensor,
        softmax_scale,
        window_size_left=-1,
        causal=True,
        softcap=0.0,
        softmax_scale: float,
        window_size_left: int = -1,
        causal: bool = True,
        softcap: float = 0.0,
    ):
        if softcap is not None:
            raise NotImplementedError("softcap is only available with CK flash attn")

        out = torch.empty_like(q)

        # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Dict, Any
from typing import Tuple

import torch
import torch.distributed
@@ -50,144 +50,3 @@ def grouped_topk(
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids


def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
) -> Dict[str, int]:
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
    }
    if M <= E:
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    return config


def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    import triton.language as tl
    from vllm import _custom_ops as ops
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        get_moe_configs,
        invoke_fused_moe_kernel,
        moe_align_block_size,
    )

    M, _ = hidden_states.shape
    E, N, _ = w1.shape

    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)

        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = get_default_config(
                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
            )

    intermediate_cache1 = torch.empty(
        (M, topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, topk_ids.shape[1], w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], E
    )
    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

    invoke_fused_moe_kernel(
        hidden_states,
        w1,
        intermediate_cache1,
        a1_scale,
        w1_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        False,
        topk_ids.shape[1],
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

    invoke_fused_moe_kernel(
        intermediate_cache2,
        w2,
        intermediate_cache3,
        a2_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        True,
        1,
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    if inplace:
        return torch.sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            dim=1,
            out=hidden_states,
        )
    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
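The bulk of this hunk deletes the local get_default_config and fused_experts helpers; only the tail of grouped_topk (top-k expert selection with renormalized routing weights) stays in the file. A small, self-contained illustration of that surviving renormalization step follows; it is a simplified stand-in, not the removed fused_experts kernel path:

import torch


def simple_topk_routing(router_logits: torch.Tensor, topk: int):
    # Softmax over experts, keep the top-k per token, then renormalize the
    # kept weights so they sum to one (same final step as grouped_topk above).
    scores = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


# Example: route 4 tokens over 8 experts, keeping the top 2 experts per token.
weights, ids = simple_topk_routing(torch.randn(4, 8), topk=2)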
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -298,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else key,
                kv_cache[1] if PAGED_KV else value,
                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
                seqlen,
                block_tables,
                self.softmax_scale,
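The same substitution repeats in every model file below: the PAGED_KV flag from models.globals is replaced by PREFILL_IN_KV_CACHE from layers.attention, and the flag decides whether the prefill attention call reads K/V from the paged cache or from the freshly computed tensors. A minimal sketch of that selection (all names other than the flag are placeholders for the surrounding model code):

def select_prefill_kv(prefill_in_kv_cache: bool, kv_cache, key, value):
    # kv_cache is the (key_cache, value_cache) pair that prefill just wrote to.
    k = kv_cache[0] if prefill_in_kv_cache else key
    v = kv_cache[1] if prefill_in_kv_cache else value
    return k, v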
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -337,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else kv[:, 0],
                kv_cache[1] if PAGED_KV else kv[:, 1],
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -15,7 +15,7 @@

from typing import List, Optional, Tuple

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "rocm":
@@ -333,8 +333,8 @@ class DeepseekV2Attention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else key,
                kv_cache[1] if PAGED_KV else value,
                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else kv[:, 0],
                kv_cache[1] if PAGED_KV else kv[:, 1],
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else kv[:, 0],
                kv_cache[1] if PAGED_KV else kv[:, 1],
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else key,
                kv_cache[1] if PAGED_KV else value,
                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -193,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else key,
                kv_cache[1] if PAGED_KV else value,
                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -28,7 +28,7 @@ from torch import nn
from transformers.activations import ACT2FN

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.attention import (
    paged_attention,
    attention,
@@ -221,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else kv[:, 0],
                kv_cache[1] if PAGED_KV else kv[:, 1],
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -219,8 +219,8 @@ class MistralAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -274,8 +274,8 @@ class MixtralAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                qkv[:, 0],
                kv_cache[0] if PAGED_KV else qkv[:, 1],
                kv_cache[1] if PAGED_KV else qkv[:, 2],
                kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
                kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -1,4 +1,4 @@
from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
        if cu_seqlen_prefill is not None:
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else kv[:, 0],
                kv_cache[1] if PAGED_KV else kv[:, 1],
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -1,4 +1,4 @@
from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -1,6 +1,6 @@
from typing import List, Optional, Tuple

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed
from torch import nn
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else kv[:, 0],
                kv_cache[1] if PAGED_KV else kv[:, 1],
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -325,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(),
                kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(),
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -1,4 +1,4 @@
from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else key_value[:, 0],
                kv_cache[1] if PAGED_KV else key_value[:, 1],
                kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from text_generation_server.models.globals import PAGED_KV
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
import torch
import torch.distributed

@@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -4,7 +4,6 @@ from loguru import logger
from typing import Dict, Optional

from text_generation_server.utils.log import log_master
from text_generation_server.utils.import_utils import SYSTEM

PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
@@ -53,12 +52,6 @@ CUDA_GRAPHS = cuda_graphs
# index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None

PAGED_KV: bool
if SYSTEM in {"rocm", "ipex"}:
    PAGED_KV = False
else:
    PAGED_KV = True


def set_adapter_to_index(adapter_to_index: Dict[str, int]):
    global ADAPTER_TO_INDEX