add skinny kernel and merge fixes

Mohit Sharma 2024-09-12 13:16:13 +00:00
parent 058162685f
commit 59fd0cbdff
23 changed files with 121 additions and 101 deletions

View File

@@ -152,9 +152,6 @@ ENV HIP_FORCE_DEV_KERNARG=1
 # On MI250 and MI300, performance for flash with Triton FA is slightly better than CK.
 # However, Triton requires a tuning for each prompt length, which is prohibitive.
 ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
-ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
-ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
-ENV VLLM_MOE_PADDING=0

 FROM base AS kernel-builder
@@ -245,6 +242,13 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base-copy

+ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
+ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
+ENV VLLM_MOE_PADDING=0
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV ROCM_USE_SKINNY_GEMM=1
+
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
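Note: these container defaults are consumed by the Python server through plain environment lookups (see the globals and linear-layer diffs further down). A minimal sketch of that parsing pattern; the `env_flag` helper and the defaults passed to it are illustrative, not code from the repository:

    import os

    def env_flag(name: str, default: str) -> bool:
        # Same truthy-string convention the server uses: "1" / "true" enable a feature.
        return os.getenv(name, default).lower() in {"1", "true"}

    ROCM_USE_SKINNY_GEMM = env_flag("ROCM_USE_SKINNY_GEMM", "True")  # on by default in this image
    USE_PREFIX_CACHING = env_flag("USE_PREFIX_CACHING", "0")         # disabled here
    ATTENTION = os.getenv("ATTENTION", "paged")                      # paged attention selected

    print(ROCM_USE_SKINNY_GEMM, USE_PREFIX_CACHING, ATTENTION)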

View File

@@ -45,7 +45,6 @@ def paged_attention(
     block_tables: torch.Tensor,
     seqlen: Seqlen,
     max_s: int,
-    num_kv_heads: int,
     softcap: Optional[float] = None,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py

View File

@@ -62,7 +62,6 @@ def paged_attention(
     block_tables: torch.Tensor,
     seqlen: Seqlen,
     max_s: int,
-    num_kv_heads: int,
     softcap: Optional[float] = None,
 ):
     out = torch.empty_like(query)
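The `num_kv_heads` parameter is dropped from the CUDA and IPEX `paged_attention` signatures here, and from the ROCm one below, where the head count is instead read off the paged key cache. A small sketch of that lookup with a toy tensor; the full 5-D cache layout in the comment is an assumption, only dimension 1 matters for this change:

    import torch

    def kv_heads_from_cache(key_cache: torch.Tensor) -> int:
        # The paged key cache is assumed to be indexed as
        # [num_blocks, num_kv_heads, head_size // x, block_size, x],
        # so dimension 1 carries the number of KV heads.
        return key_cache.shape[1]

    key_cache = torch.empty(16, 8, 16, 32, 8)  # toy shapes, not a real cache
    print(kv_heads_from_cache(key_cache))      # -> 8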

View File

@@ -50,9 +50,8 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: Seqlen,
+    seqlen: Seqlen,
     max_s: int,
-    num_kv_heads: int,
     softcap: Optional[float] = None,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -76,6 +75,7 @@ def paged_attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
+    num_kv_heads = key_cache.shape[1]
     gqa_ratio = num_heads // num_kv_heads
     use_custom = (
         custom_attn_available
@@ -92,7 +92,7 @@ def paged_attention(
         _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM

     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = input_lengths.input_lengths
+    input_lengths = seqlen.input_lengths

     out = torch.empty_like(query)
@@ -220,10 +220,10 @@ if ENGINE == "ck":
     def attention(
         q,
-        k,
-        v,
-        cu_seqlens,
-        max_s,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
@@ -237,17 +237,17 @@ if ENGINE == "ck":
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
-            k,
-            v,
+            key_cache,
+            value_cache,
             out,
-            cu_seqlens,
-            cu_seqlens,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
             None,
             None,
             None,
             None,
-            max_s,
-            max_s,
+            seqlen.max_q,
+            seqlen.max_k,
             0.0,
             softmax_scale,
             False,
@@ -264,26 +264,27 @@ elif ENGINE == "triton":
     def attention(
         q,
-        k,
-        v,
-        cu_seqlens,
-        max_s,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
+        softcap=0.0,
     ):
         out = torch.empty_like(q)

         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
-            k,
-            v,
+            key_cache,
+            value_cache,
             out,
-            cu_seqlens,
-            cu_seqlens,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
-            max_s,
-            max_s,
+            seqlen.max_q,
+            seqlen.max_k,
             causal,
             softmax_scale,
         )
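Both ROCm kernels now take the `Seqlen` container and a block table instead of raw `cu_seqlens`/`max_s`, matching the other backends. A stand-in listing only the fields these updated paths read; the real `Seqlen` lives in `text_generation_server` and carries more state:

    from dataclasses import dataclass
    import torch

    @dataclass
    class SeqlenSketch:
        # Stand-in for Seqlen; only the attributes read by the updated ROCm code.
        input_lengths: torch.Tensor  # per-sequence lengths, used by paged_attention
        cu_seqlen_q: torch.Tensor    # cumulative query offsets, passed twice to varlen_fwd
        max_q: int                   # replaces the old max_s on the query side
        max_k: int                   # replaces the old max_s on the key side

    seqlen = SeqlenSketch(
        input_lengths=torch.tensor([3, 5], dtype=torch.int32),
        cu_seqlen_q=torch.tensor([0, 3, 8], dtype=torch.int32),
        max_q=5,
        max_k=5,
    )
    print(seqlen.max_q, seqlen.cu_seqlen_q.tolist())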

View File

@@ -1,12 +1,19 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from torch.nn import functional as F
+import os

 if SYSTEM == "rocm":
-    try:
-        from vllm import _custom_C
-    except Exception as e:
-        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+    ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
+        "true",
+        "1",
+    )
+
+    if ROCM_USE_SKINNY_GEMM:
+        try:
+            from vllm import _custom_C
+        except Exception as e:
+            raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


 class FastLinear(torch.nn.Module):
@@ -48,6 +55,14 @@ class FastLinearROCm(torch.nn.Module):
         else:
             self.bias = None

+        self.cu_count = torch.cuda.get_device_properties(
+            device="cuda"
+        ).multi_processor_count
+        self.use_skinny_gemm = (
+            ROCM_USE_SKINNY_GEMM
+            and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
+        )
+
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_tensor(f"{prefix}.weight")
@@ -62,9 +77,9 @@ class FastLinearROCm(torch.nn.Module):
         bias = self.bias

         if (
-            SYSTEM == "rocm"
-            and inp.numel() // inp.shape[-1] == 1
+            self.use_skinny_gemm
             and inp.dtype == torch.float16
+            and inp.shape[-1] % 8 == 0
         ):
             batched = False
             inp_shape = inp.shape
@@ -73,13 +88,16 @@ class FastLinearROCm(torch.nn.Module):
                 inp = inp.view(-1, inp_shape[-1])
                 batched = True

-            m, k = weight.shape[0], inp_shape[1]
-            out = torch.empty(
-                inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
-            )
-            if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
-                _custom_C.LLMM1(weight, inp, out, 8)
-            elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
+            m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
+            if m > 8 and n <= 4:
+                out = torch.empty(
+                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
+                )
+                _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
+            elif m % 4 == 0 and n == 1 and k <= 8192:
+                out = torch.empty(
+                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
+                )
                 _custom_C.LLMM1(weight, inp, out, 4)
             else:
                 out = F.linear(inp, weight)
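The forward path now dispatches on (m, n, k) = (output features, input rows, input features): skinny activations go to the wave-split `wvSpltK` kernel, a remaining n == 1 case keeps `LLMM1`, and everything else falls back to `torch.nn.functional.linear`. A CPU-only sketch of that selection logic, with the ROCm kernel calls replaced by labels so the heuristic can be checked anywhere:

    def skinny_gemm_kernel(m: int, n: int, k: int) -> str:
        # Mirrors the branch order added to FastLinearROCm.forward; the real code
        # additionally requires float16 inputs and k % 8 == 0 before reaching here.
        if m > 8 and n <= 4:
            return "wvSpltK"
        if m % 4 == 0 and n == 1 and k <= 8192:
            return "LLMM1"
        return "F.linear"

    print(skinny_gemm_kernel(m=4096, n=1, k=4096))   # -> wvSpltK (single-token decode)
    print(skinny_gemm_kernel(m=4096, n=16, k=4096))  # -> F.linear (regular batch)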

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -314,7 +315,6 @@ class FlashCohereAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(
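`PAGED_KV` (defined in `models/globals.py` at the end of this diff; false on ROCm and IPEX, true otherwise) replaces the per-call `SYSTEM != "ipex"` test, and the now-redundant `self.num_key_value_heads` argument is dropped from `paged_attention`. The same two edits repeat in every modeling file below. A toy illustration of the selection, with random tensors standing in for the real projections and cache blocks:

    import torch

    PAGED_KV = True  # globals.py sets this to False when SYSTEM is "rocm" or "ipex"

    key = torch.randn(8, 4, 64)    # freshly projected K for the prefill tokens
    value = torch.randn(8, 4, 64)  # freshly projected V
    kv_cache = (torch.randn(2, 4, 8, 16, 8), torch.randn(2, 4, 64, 16))  # toy paged blocks

    k_arg = kv_cache[0] if PAGED_KV else key
    v_arg = kv_cache[1] if PAGED_KV else value
    print(k_arg.shape, v_arg.shape)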

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -353,7 +354,6 @@ class DbrxAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@@ -15,6 +15,7 @@
 from typing import Any, Dict, List, Optional, Tuple

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 from text_generation_server.layers import (
@@ -363,8 +364,8 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -380,7 +381,6 @@ class DeepseekV2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         # Remove padding.

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -25,7 +26,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -257,7 +257,6 @@ class FlashGemma2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
                 softcap=self.softcap,
             )
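Gemma 2 is the one model in this diff whose attention call forwards a soft-cap (`softcap=self.softcap`), which is also why the Triton `attention` wrapper above gains a `softcap=0.0` parameter. For context, attention logit soft-capping is the usual tanh squash; this helper is illustrative and not code from the repository:

    import torch

    def softcap_logits(scores: torch.Tensor, softcap: float) -> torch.Tensor:
        # Squashes attention scores into (-softcap, softcap) before the softmax,
        # as done for Gemma 2-style models.
        return softcap * torch.tanh(scores / softcap)

    print(softcap_logits(torch.tensor([10.0, 100.0]), softcap=30.0))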

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -25,7 +26,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -249,7 +249,6 @@ class FlashGemmaAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@@ -18,13 +18,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -248,7 +248,6 @@ class FlashGPT2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,

View File

@@ -28,6 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import PAGED_KV
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -237,7 +238,6 @@ class FlashLlamaAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -236,7 +237,6 @@ class MistralAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -275,8 +276,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -293,7 +294,6 @@ class MixtralAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -26,7 +27,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
-                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
+                kv_cache[0] if PAGED_KV else qkv[:, 1],
+                kv_cache[1] if PAGED_KV else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -189,7 +189,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

View File

@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -25,7 +26,6 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 class PhiConfig(PretrainedConfig):
@@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -211,7 +211,6 @@ class FlashPhiAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

View File

@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -21,7 +22,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_attention(config, prefix, weights):
@@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -155,7 +155,6 @@ class Qwen2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@@ -1,11 +1,11 @@
 from typing import List, Optional, Tuple

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -224,7 +224,6 @@ class FlashRWAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -326,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(),
+                kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(),
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -343,7 +342,6 @@ class FlashRWLargeAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.dense(

View File

@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -22,7 +23,6 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_multi_mqa(
@@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
+                kv_cache[0] if PAGED_KV else key_value[:, 0],
+                kv_cache[1] if PAGED_KV else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -310,7 +310,6 @@ class FlashMQAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
-from text_generation_server.utils.import_utils import SYSTEM


 class Starcoder2Config(PretrainedConfig):
@@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -260,7 +260,6 @@ class Starcoder2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
-                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@@ -1379,6 +1379,7 @@ class FlashCausalLM(Model):
         cu_seqlen_prefill = torch.tensor(
             [0, seqlen], device=self.device, dtype=torch.int32
         )
+        max_s = seqlen
         seqlen = Seqlen(
             input_lengths=input_lengths,
             prefix_lengths=prefix_lens_tensor,
@@ -1396,7 +1397,7 @@ class FlashCausalLM(Model):
             block_tables=None,
             seqlen=seqlen,
             slots=slots,
-            max_s=seqlen,
+            max_s=max_s,
             lm_head_indices=None,
             prefill_cache_indices=None,
         )
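The warmup path rebinds the name `seqlen` from a scalar prompt length to a `Seqlen` object a few lines before the forward call, so the old `max_s=seqlen` silently passed the container instead of the integer; the fix captures the scalar first. A minimal reproduction of the ordering issue, with a stand-in for `Seqlen`:

    from dataclasses import dataclass

    @dataclass
    class SeqlenStub:
        input_lengths: list  # stand-in for the real Seqlen container

    seqlen = 128                                 # scalar length used during warmup
    max_s = seqlen                               # the fix: keep the scalar around...
    seqlen = SeqlenStub(input_lengths=[seqlen])  # ...before the name is reused
    assert isinstance(max_s, int) and isinstance(seqlen, SeqlenStub)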

View File

@@ -4,6 +4,7 @@ from loguru import logger
 from typing import Dict, Optional

 from text_generation_server.utils.log import log_master
+from text_generation_server.utils.import_utils import SYSTEM

 PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
@@ -52,6 +53,12 @@ CUDA_GRAPHS = cuda_graphs
 # index in all cases.
 ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None

+PAGED_KV: bool
+if SYSTEM in {"rocm", "ipex"}:
+    PAGED_KV = False
+else:
+    PAGED_KV = True
+

 def set_adapter_to_index(adapter_to_index: Dict[str, int]):
     global ADAPTER_TO_INDEX