diff --git a/Dockerfile_amd b/Dockerfile_amd index dabcb77a..766881a8 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -67,14 +67,11 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins hipsolver-dev \ rccl-dev \ cmake \ - python3.11-dev \ python3.11-venv && \ rm -rf /var/lib/apt/lists/* # Keep in sync with `server/pyproject.toml ARG MAMBA_VERSION=23.1.0-1 -ARG PYTORCH_VERSION='2.3.0' -ARG ROCM_VERSION='6.0.2' ARG PYTHON_VERSION='3.11.10' # Automatically set by buildx ARG TARGETPLATFORM @@ -82,11 +79,6 @@ ENV PATH=/opt/conda/bin:$PATH ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942" -RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.30.2/cmake-3.30.2-linux-x86_64.sh \ - && chmod +x cmake-3.30.2-linux-x86_64.sh \ - && ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \ - && rm cmake-3.30.2-linux-x86_64.sh - # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. # Install mamba # translating Docker's TARGETPLATFORM into mamba arches @@ -111,7 +103,7 @@ RUN case ${TARGETPLATFORM} in \ /opt/conda/bin/conda clean -ya # Install flash-attention, torch dependencies -RUN pip install numpy einops ninja joblib msgpack --no-cache-dir +RUN pip install numpy einops ninja joblib msgpack cmake --no-cache-dir # Install HIPBLASLt ARG HIPBLASLT_BRANCH="6f65c6e" @@ -129,7 +121,8 @@ RUN dpkg -i hipBLASLt/build/release/*.deb \ RUN pip uninstall -y triton && \ git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \ cd triton/python && \ - pip install . + pip install . && \ + rm -r triton ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27" RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \ @@ -153,6 +146,7 @@ ARG BUILD_CAFFE2="0" \ USE_MEM_EFF_ATTENTION="0" RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install +RUN rm -rf pytorch # Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm ENV HIP_FORCE_DEV_KERNARG=1 diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index 56fc5319..2134d857 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -13,9 +13,19 @@ if SYSTEM == "cuda": SUPPORTS_WINDOWING, ) elif SYSTEM == "rocm": - from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING + from .rocm import ( + attention, + paged_attention, + reshape_and_cache, + SUPPORTS_WINDOWING, + ) elif SYSTEM == "ipex": - from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING + from .ipex import ( + attention, + paged_attention, + reshape_and_cache, + SUPPORTS_WINDOWING, + ) else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 4b588b5c..6c645770 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -351,3 +351,9 @@ else: None, ) return out + + +# Prefill in the cache with every kind of attention, unless we +# have a configuration that requires flash-attention v1, which +# does not support 
block tables. +PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2 diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index d0eadc75..657c90af 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -5,6 +5,7 @@ from text_generation_server.layers.attention import Seqlen from typing import Optional SUPPORTS_WINDOWING = False +PREFILL_IN_KV_CACHE = False def attention( diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 0835cb97..be6158c1 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -16,9 +16,18 @@ _PARTITION_SIZE_CUSTOM = 256 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} ENGINE = "triton" if use_triton else "ck" -custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" -if custom_attn_available: - from vllm._custom_C import paged_attention_custom +PREFILL_IN_KV_CACHE = False + +use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" +try: + if use_rocm_custom_paged_attn: + from vllm._custom_C import paged_attention_custom +except ImportError as e: + log_master( + logger.info, + f"Custom Paged Attention not available. Complete error: {e}", + ) + use_rocm_custom_paged_attn = False try: import vllm._custom_ops as ops @@ -71,6 +80,9 @@ def paged_attention( # limitations under the License. # + if softcap is not None: + raise RuntimeError("Paged attention doesn't support softcapping") + # value_cache => [num_blocks, num_heads, head_size, block_size] block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape @@ -78,7 +90,7 @@ def paged_attention( num_kv_heads = key_cache.shape[1] gqa_ratio = num_heads // num_kv_heads use_custom = ( - custom_attn_available + use_rocm_custom_paged_attn and (query.dtype == torch.half or query.dtype == torch.bfloat16) and (head_size == 128 or head_size == 64) and (block_size == 16 or block_size == 32) @@ -224,10 +236,10 @@ if ENGINE == "ck": value_cache: torch.Tensor, seqlen: Seqlen, block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: float = 0.0, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") @@ -268,11 +280,14 @@ elif ENGINE == "triton": value_cache: torch.Tensor, seqlen: Seqlen, block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: float = 0.0, ): + if softcap is not None: + raise NotImplementedError("softcap is only available with CK flash attn") + out = torch.empty_like(q) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. diff --git a/server/text_generation_server/layers/moe/fused_moe_rocm.py b/server/text_generation_server/layers/moe/fused_moe_rocm.py index ab30ff53..68accb99 100644 --- a/server/text_generation_server/layers/moe/fused_moe_rocm.py +++ b/server/text_generation_server/layers/moe/fused_moe_rocm.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
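
The `server/text_generation_server/layers/attention/rocm.py` hunk above stops importing `paged_attention_custom` unconditionally: the import is now wrapped in a try/except, and `use_rocm_custom_paged_attn` is cleared on `ImportError`, so TGI falls back to the regular paged-attention path instead of failing at import time. Below is a self-contained sketch of that pattern; the module path `vllm._custom_C.paged_attention_custom` and the environment variable come from the diff, while the plain `loguru` call and the eligibility helper are illustrative simplifications, not the TGI implementation.

```python
# Sketch of the guarded optional-kernel import from the rocm.py hunk above.
# Only the module path and env var are taken from the diff; the rest is illustrative.
import os

import torch
from loguru import logger

use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"

try:
    if use_rocm_custom_paged_attn:
        from vllm._custom_C import paged_attention_custom  # noqa: F401
except ImportError as e:
    # Degrade gracefully: log once and disable the custom kernel.
    logger.info(f"Custom Paged Attention not available. Complete error: {e}")
    use_rocm_custom_paged_attn = False


def can_use_custom_kernel(query: torch.Tensor, head_size: int, block_size: int) -> bool:
    # Mirrors the constraints visible in the diff: half/bfloat16 inputs,
    # head sizes of 64 or 128, and block sizes of 16 or 32.
    return (
        use_rocm_custom_paged_attn
        and query.dtype in (torch.half, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
    )
```
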
-from typing import Optional, Tuple, Dict, Any +from typing import Tuple import torch import torch.distributed @@ -50,144 +50,3 @@ def grouped_topk( topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids - - -def get_default_config( - M: int, - E: int, - N: int, - K: int, - topk: int, - dtype: Optional[str], -) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - if M <= E: - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - } - return config - - -def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -): - # Check constraints. - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] - - import triton.language as tl - from vllm import _custom_ops as ops - from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_moe_configs, - invoke_fused_moe_kernel, - moe_align_block_size, - ) - - M, _ = hidden_states.shape - E, N, _ = w1.shape - - if override_config: - config = override_config - else: - # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) - - if configs: - # If an optimal configuration map has been found, look up the - # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] - else: - # Else use the default config - config = get_default_config( - M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None - ) - - intermediate_cache1 = torch.empty( - (M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache2 = torch.empty( - (M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache3 = torch.empty( - (M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], E - ) - compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - - invoke_fused_moe_kernel( - hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - a2_scale, - w2_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) - - if inplace: - return torch.sum( - 
intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=hidden_states, - ) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index b0e57d68..44db0290 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -298,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else key, - kv_cache[1] if PAGED_KV else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 8bce4e57..852e52d8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -337,8 +337,8 @@ class DbrxAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 94c7600a..97a26930 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -15,7 +15,7 @@ from typing import List, Optional, Tuple -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.utils.import_utils import SYSTEM if SYSTEM == "rocm": @@ -333,8 +333,8 @@ class DeepseekV2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else key, - kv_cache[1] if PAGED_KV else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 1ad88801..b1f0dba2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -18,7 +18,7 @@ # See the License for the 
specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index a401798a..3ddcba8a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 33f20b9a..d47bb104 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else key, - kv_cache[1] if PAGED_KV else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index f2197069..200735c6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -193,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else key, - kv_cache[1] if PAGED_KV else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 6be89297..a77ec234 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -28,7 +28,7 @@ from torch import nn from transformers.activations import ACT2FN from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import ( paged_attention, attention, @@ -221,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 3b56bbab..d0503277 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -219,8 +219,8 @@ class MistralAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], - kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index abfa737a..3eb81daf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -274,8 +274,8 @@ class MixtralAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], - kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 2d3be430..471abca3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module): # flash attention attn_output = attention( qkv[:, 0], - kv_cache[0] if PAGED_KV else qkv[:, 1], - kv_cache[1] if PAGED_KV else qkv[:, 2], + kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1], + kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 76e406a7..4a18090a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -1,4 +1,4 @@ -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module): if cu_seqlen_prefill is not None: attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 0f0dbf5e..00e63a6c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -1,4 +1,4 @@ -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], - kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 
ba516881..2cf243e8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -1,6 +1,6 @@ from typing import List, Optional, Tuple -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed from torch import nn @@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, 0], - kv_cache[1] if PAGED_KV else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -325,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(), - kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(), + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index fa074606..0c1518e7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -1,4 +1,4 @@ -from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else key_value[:, 0], - kv_cache[1] if PAGED_KV else key_value[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 30d35632..22ac0240 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from text_generation_server.models.globals import PAGED_KV +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE import torch import torch.distributed @@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], - kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index f04c6df5..6c518c2c 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -4,7 +4,6 @@ from loguru import logger from typing import Dict, Optional from text_generation_server.utils.log import log_master -from text_generation_server.utils.import_utils import SYSTEM PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") @@ -53,12 +52,6 @@ CUDA_GRAPHS = cuda_graphs # index in all cases. ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None -PAGED_KV: bool -if SYSTEM in {"rocm", "ipex"}: - PAGED_KV = False -else: - PAGED_KV = True - def set_adapter_to_index(adapter_to_index: Dict[str, int]): global ADAPTER_TO_INDEX
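
The per-model edits above all follow one pattern: the `PAGED_KV` global is removed from `models/globals.py`, and each attention module instead asks the attention backend whether prefill should read keys and values back out of the paged KV cache. The backend owns the decision: `PREFILL_IN_KV_CACHE` is `ATTENTION != "paged" or V2` on CUDA (flash-attention v1 cannot use block tables) and `False` on ROCm and IPEX. A minimal sketch of the resulting prefill call follows; everything except `PREFILL_IN_KV_CACHE` and `attention` is a placeholder standing in for the per-model attention modules in the diff.

```python
# Sketch only: the function and tensor names are illustrative, the import and
# the argument order of attention() follow the hunks above.
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, attention


def prefill_attention(query, key, value, kv_cache, seqlen, block_tables, softmax_scale):
    # On CUDA with flash-attention v2 the prompt is already written to the paged
    # cache, so attention() reads kv_cache directly; the ROCm and IPEX backends
    # set PREFILL_IN_KV_CACHE = False and attend over the freshly projected
    # key/value tensors instead.
    return attention(
        query,
        kv_cache[0] if PREFILL_IN_KV_CACHE else key,
        kv_cache[1] if PREFILL_IN_KV_CACHE else value,
        seqlen,
        block_tables,
        softmax_scale,
    )
```
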