From 59fd0cbdff68c20d954aafdecd419b4152f34e5e Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Thu, 12 Sep 2024 13:16:13 +0000 Subject: [PATCH] add skinny kernel and merge fixes --- Dockerfile_amd | 10 ++-- .../layers/attention/cuda.py | 1 - .../layers/attention/ipex.py | 1 - .../layers/attention/rocm.py | 47 ++++++++++--------- .../text_generation_server/layers/linear.py | 44 ++++++++++++----- .../custom_modeling/flash_cohere_modeling.py | 6 +-- .../custom_modeling/flash_dbrx_modeling.py | 6 +-- .../flash_deepseek_v2_modeling.py | 6 +-- .../custom_modeling/flash_gemma2_modeling.py | 7 ++- .../custom_modeling/flash_gemma_modeling.py | 7 ++- .../custom_modeling/flash_gpt2_modeling.py | 7 ++- .../custom_modeling/flash_gptj_modeling.py | 5 +- .../custom_modeling/flash_llama_modeling.py | 6 +-- .../custom_modeling/flash_mistral_modeling.py | 6 +-- .../custom_modeling/flash_mixtral_modeling.py | 6 +-- .../custom_modeling/flash_neox_modeling.py | 7 ++- .../custom_modeling/flash_phi_modeling.py | 7 ++- .../custom_modeling/flash_qwen2_modeling.py | 7 ++- .../custom_modeling/flash_rw_modeling.py | 12 ++--- .../flash_santacoder_modeling.py | 7 ++- .../flash_starcoder2_modeling.py | 7 ++- .../models/flash_causal_lm.py | 3 +- .../text_generation_server/models/globals.py | 7 +++ 23 files changed, 121 insertions(+), 101 deletions(-) diff --git a/Dockerfile_amd b/Dockerfile_amd index 1940b9851..2aa2a6bc4 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -152,9 +152,6 @@ ENV HIP_FORCE_DEV_KERNARG=1 # On MI250 and MI300, performances for flash with Triton FA are slightly better than CK. # However, Triton requires a tunning for each prompt length, which is prohibitive. ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0 -ENV ROCM_USE_CUSTOM_PAGED_ATTN=1 -ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 -ENV VLLM_MOE_PADDING=0 FROM base AS kernel-builder @@ -245,6 +242,13 @@ ENTRYPOINT ["./entrypoint.sh"] # Final image FROM base-copy +ENV ROCM_USE_CUSTOM_PAGED_ATTN=1 +ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 +ENV VLLM_MOE_PADDING=0 +ENV ATTENTION=paged +ENV USE_PREFIX_CACHING=0 +ENV ROCM_USE_SKINNY_GEMM=1 + COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 592350f4c..4b588b5cf 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -45,7 +45,6 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, - num_kv_heads: int, softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 83254598a..2d1427ae6 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -62,7 +62,6 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, - num_kv_heads: int, softcap: Optional[float] = None, ): out = torch.empty_like(query) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 3e003acbd..0835cb972 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -50,9 +50,8 @@ def paged_attention( 
kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: Seqlen, + seqlen: Seqlen, max_s: int, - num_kv_heads: int, softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -76,6 +75,7 @@ def paged_attention( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape + num_kv_heads = key_cache.shape[1] gqa_ratio = num_heads // num_kv_heads use_custom = ( custom_attn_available @@ -92,7 +92,7 @@ def paged_attention( _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE - input_lengths = input_lengths.input_lengths + input_lengths = seqlen.input_lengths out = torch.empty_like(query) @@ -220,10 +220,10 @@ if ENGINE == "ck": def attention( q, - k, - v, - cu_seqlens, - max_s, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, window_size_left=-1, causal=True, @@ -237,17 +237,17 @@ if ENGINE == "ck": # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return flash_attn_2_cuda.varlen_fwd( q, - k, - v, + key_cache, + value_cache, out, - cu_seqlens, - cu_seqlens, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, None, None, None, None, - max_s, - max_s, + seqlen.max_q, + seqlen.max_k, 0.0, softmax_scale, False, @@ -264,26 +264,27 @@ elif ENGINE == "triton": def attention( q, - k, - v, - cu_seqlens, - max_s, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, window_size_left=-1, causal=True, + softcap=0.0, ): out = torch.empty_like(q) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. output, _ = triton_attention( q, - k, - v, + key_cache, + value_cache, out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_k, causal, softmax_scale, ) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index 78815d744..69b6294bb 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,12 +1,19 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from torch.nn import functional as F +import os if SYSTEM == "rocm": - try: - from vllm import _custom_C - except Exception as e: - raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") + ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in ( + "true", + "1", + ) + + if ROCM_USE_SKINNY_GEMM: + try: + from vllm import _custom_C + except Exception as e: + raise ImportError(f"Could not load `vllm._custom_C`. 
Full error: {e}") class FastLinear(torch.nn.Module): @@ -48,6 +55,14 @@ class FastLinearROCm(torch.nn.Module): else: self.bias = None + self.cu_count = torch.cuda.get_device_properties( + device="cuda" + ).multi_processor_count + self.use_skinny_gemm = ( + ROCM_USE_SKINNY_GEMM + and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName + ) + @classmethod def load(cls, config, prefix: str, weights, bias: bool): weight = weights.get_tensor(f"{prefix}.weight") @@ -62,9 +77,9 @@ class FastLinearROCm(torch.nn.Module): bias = self.bias if ( - SYSTEM == "rocm" - and inp.numel() // inp.shape[-1] == 1 + self.use_skinny_gemm and inp.dtype == torch.float16 + and inp.shape[-1] % 8 == 0 ): batched = False inp_shape = inp.shape @@ -73,13 +88,16 @@ class FastLinearROCm(torch.nn.Module): inp = inp.view(-1, inp_shape[-1]) batched = True - m, k = weight.shape[0], inp_shape[1] - out = torch.empty( - inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda" - ) - if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): - _custom_C.LLMM1(weight, inp, out, 8) - elif k <= 8192 and k % 8 == 0 and m % 4 == 0: + m, n, k = weight.shape[0], inp_shape[0], inp_shape[1] + if m > 8 and n <= 4: + out = torch.empty( + inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device + ) + _custom_C.wvSpltK(weight, inp, out, n, self.cu_count) + elif m % 4 == 0 and n == 1 and k <= 8192: + out = torch.empty( + inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device + ) _custom_C.LLMM1(weight, inp, out, 4) else: out = F.linear(inp, weight) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 8f6cba350..b0e57d686 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PAGED_KV else key, + kv_cache[1] if PAGED_KV else value, seqlen, block_tables, self.softmax_scale, @@ -314,7 +315,6 @@ class FlashCohereAttention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 478b5b162..8bce4e573 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -353,7 +354,6 @@ class DbrxAttention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 0aa948e75..561363816 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed from text_generation_server.layers import ( @@ -363,8 +364,8 @@ class DeepseekV2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PAGED_KV else key, + kv_cache[1] if PAGED_KV else value, seqlen, block_tables, self.softmax_scale, @@ -380,7 +381,6 @@ class DeepseekV2Attention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) # Remove padding. diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 6bd4aac5c..1ad88801b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -25,7 +26,6 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -257,7 +257,6 @@ class FlashGemma2Attention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, softcap=self.softcap, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index c253e6abe..a401798a6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -25,7 +26,6 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -249,7 +249,6 @@ class FlashGemmaAttention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 90382583e..33f20b9a3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -18,13 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PAGED_KV else key, + kv_cache[1] if PAGED_KV else value, seqlen, block_tables, self.softmax_scale, @@ -248,7 +248,6 @@ class FlashGPT2Attention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index ef071d46d..f21970692 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PAGED_KV else key, + kv_cache[1] if PAGED_KV else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 39218531f..6be892970 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -28,6 +28,7 @@ from torch import nn from transformers.activations import ACT2FN from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import PAGED_KV from text_generation_server.layers.attention import ( paged_attention, attention, @@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -237,7 +238,6 @@ class FlashLlamaAttention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index dacda1018..3b56bbab0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], + kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -236,7 +237,6 @@ class MistralAttention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index b35688ec8..3451158bf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -275,8 +276,8 @@ class MixtralAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], + kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -293,7 +294,6 @@ class MixtralAttention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 698f0343e..2d3be430b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -26,7 +27,6 @@ from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module): # flash attention attn_output = attention( qkv[:, 0], - kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1], - kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2], + kv_cache[0] if PAGED_KV else qkv[:, 1], + kv_cache[1] if PAGED_KV else qkv[:, 2], seqlen, block_tables, self.softmax_scale, @@ -189,7 +189,6 @@ class FlashNeoxAttention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 38e8c8841..76e406a74 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -1,3 +1,4 @@ +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -25,7 +26,6 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) -from text_generation_server.utils.import_utils import SYSTEM class PhiConfig(PretrainedConfig): @@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module): if cu_seqlen_prefill is not None: attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -211,7 +211,6 @@ class FlashPhiAttention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index d43401bad..0f0dbf5ec 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -1,3 +1,4 @@ +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -21,7 +22,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) -from text_generation_server.utils.import_utils import SYSTEM def load_attention(config, prefix, weights): @@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], + kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -155,7 +155,6 @@ class Qwen2Attention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return 
self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 765cf39eb..ba5168810 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -1,11 +1,11 @@ from typing import List, Optional, Tuple +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed from torch import nn from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( SpeculativeHead, TensorParallelColumnLinear, @@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PAGED_KV else kv[:, 0], + kv_cache[1] if PAGED_KV else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -224,7 +224,6 @@ class FlashRWAttention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -326,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(), - kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(), + kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(), + kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(), seqlen, block_tables, self.softmax_scale, @@ -343,7 +342,6 @@ class FlashRWLargeAttention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.dense( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 80c280c87..fa0746066 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -1,3 +1,4 @@ +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -22,7 +23,6 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, ) -from text_generation_server.utils.import_utils import SYSTEM def load_multi_mqa( @@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0], - kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1], + kv_cache[0] if PAGED_KV else key_value[:, 0], + kv_cache[1] if PAGED_KV else key_value[:, 1], seqlen, block_tables, self.softmax_scale, @@ -310,7 +310,6 @@ class FlashMQAttention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 0c4ce05ae..30d356324 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from text_generation_server.models.globals import PAGED_KV import torch import torch.distributed @@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.utils.weights import UnquantizedWeight -from text_generation_server.utils.import_utils import SYSTEM class Starcoder2Config(PretrainedConfig): @@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PAGED_KV else kv_to_cache[:, 0], + kv_cache[1] if PAGED_KV else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -260,7 +260,6 @@ class Starcoder2Attention(torch.nn.Module): block_tables, seqlen, max_s, - self.num_key_value_heads, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 6f21e8464..c4bf8a57e 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1379,6 +1379,7 @@ class FlashCausalLM(Model): cu_seqlen_prefill = torch.tensor( [0, seqlen], device=self.device, dtype=torch.int32 ) + max_s = seqlen seqlen = Seqlen( input_lengths=input_lengths, prefix_lengths=prefix_lens_tensor, @@ -1396,7 +1397,7 @@ class FlashCausalLM(Model): block_tables=None, seqlen=seqlen, slots=slots, - max_s=seqlen, + max_s=max_s, lm_head_indices=None, prefill_cache_indices=None, ) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 6c518c2ca..f04c6df52 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -4,6 +4,7 @@ from loguru import logger from typing import Dict, Optional from text_generation_server.utils.log import log_master +from text_generation_server.utils.import_utils import SYSTEM PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") @@ -52,6 +53,12 @@ CUDA_GRAPHS = cuda_graphs # index in all cases. ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None +PAGED_KV: bool +if SYSTEM in {"rocm", "ipex"}: + PAGED_KV = False +else: + PAGED_KV = True + def set_adapter_to_index(adapter_to_index: Dict[str, int]): global ADAPTER_TO_INDEX
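
The core functional change in `layers/linear.py` is the dispatch between the two vLLM ROCm skinny kernels (`wvSpltK` and `LLMM1`) and the dense `F.linear` fallback. Below is a minimal standalone sketch of that dispatch, not the literal patch code: the kernel names, argument order, and the m/n/k thresholds are taken from the patch, while the `rocm_skinny_linear` wrapper is illustrative, `cu_count` is passed in rather than read from the device properties, and bias handling plus the gfx1* (RDNA) exclusion are omitted.

import os

import torch
import torch.nn.functional as F

try:
    from vllm import _custom_C  # ROCm-only custom kernels; optional in this sketch
except ImportError:
    _custom_C = None

ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")


def rocm_skinny_linear(inp: torch.Tensor, weight: torch.Tensor, cu_count: int) -> torch.Tensor:
    # m = output features, n = flattened token count, k = input features.
    use_skinny = (
        _custom_C is not None
        and ROCM_USE_SKINNY_GEMM
        and inp.dtype == torch.float16
        and inp.shape[-1] % 8 == 0
    )
    if not use_skinny:
        return F.linear(inp, weight)

    inp_2d = inp.view(-1, inp.shape[-1])
    m, n, k = weight.shape[0], inp_2d.shape[0], inp_2d.shape[1]
    if m > 8 and n <= 4:
        # Skinny GEMM: wide weight matrix, very few tokens (decode-time batches).
        out = torch.empty(n, m, dtype=inp.dtype, device=weight.device)
        _custom_C.wvSpltK(weight, inp_2d, out, n, cu_count)
    elif m % 4 == 0 and n == 1 and k <= 8192:
        # Single-token GEMV kernel.
        out = torch.empty(n, m, dtype=inp.dtype, device=weight.device)
        _custom_C.LLMM1(weight, inp_2d, out, 4)
    else:
        out = F.linear(inp_2d, weight)
    return out.view(*inp.shape[:-1], m)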
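
`paged_attention` no longer takes `num_kv_heads`; each backend recovers it from the cache tensors, which is why every call site in the modeling files simply drops the argument. A short sketch of the ROCm-side derivation, assuming the usual vLLM-style cache layout (key cache `[num_blocks, num_kv_heads, head_size // x, block_size, x]`, value cache `[num_blocks, num_kv_heads, head_size, block_size]`); the `kv_geometry` helper is illustrative and only shows the shape bookkeeping that feeds the custom paged-attention eligibility check.

import torch


def kv_geometry(query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor):
    # query: [num_seqs, num_heads, head_size]
    num_seqs, num_heads, head_size = query.shape
    num_kv_heads = key_cache.shape[1]      # read from the cache instead of a parameter
    block_size = value_cache.shape[3]
    gqa_ratio = num_heads // num_kv_heads  # used to decide whether the custom GQA kernel applies
    return num_kv_heads, block_size, gqa_ratio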
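
The repeated `SYSTEM != "ipex"` checks at the attention call sites are replaced by a single `PAGED_KV` flag in `models/globals.py`. A condensed view of the flag and the call-site pattern it standardizes; `key`/`value` stand in for whichever freshly computed tensors each model passes (`kv[:, 0]`, `kv_to_cache[:, 1]`, etc. in the individual files).

from text_generation_server.utils.import_utils import SYSTEM

# True on CUDA, where flash attention can read K/V straight from the paged
# cache during prefill; False on ROCm and IPEX, which still pass the freshly
# computed key/value tensors.
PAGED_KV: bool = SYSTEM not in {"rocm", "ipex"}

# Call-site pattern now shared by all flash_*_modeling files:
# attn_output = attention(
#     query,
#     kv_cache[0] if PAGED_KV else key,
#     kv_cache[1] if PAGED_KV else value,
#     seqlen,
#     block_tables,
#     self.softmax_scale,
# )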