mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Revamped all this architecture.
This commit is contained in:
parent
7890cd66f7
commit
daddd2e90b
13
server/text_generation_server/layers/attention/__init__.py
Normal file
13
server/text_generation_server/layers/attention/__init__.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||||
|
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||||
|
if SYSTEM == "cuda":
|
||||||
|
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||||
|
elif SYSTEM == "rocm":
|
||||||
|
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||||
|
elif SYSTEM == "xpu":
|
||||||
|
from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||||
|
else:
|
||||||
|
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
@ -2,18 +2,17 @@ import torch
|
|||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.models.globals import FLASH_DECODING
|
from text_generation_server.models.globals import FLASH_DECODING
|
||||||
|
|
||||||
|
major, minor = torch.cuda.get_device_capability()
|
||||||
|
is_sm75 = major == 7 and minor == 5
|
||||||
_PARTITION_SIZE = 512
|
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
if SYSTEM == "xpu":
|
try:
|
||||||
import intel_extension_for_pytorch as ipex
|
from vllm._C import cache_ops
|
||||||
else:
|
from vllm._C import ops
|
||||||
try:
|
except Exception as e:
|
||||||
from vllm._C import cache_ops
|
raise ImportError(
|
||||||
from vllm._C import ops
|
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||||
except Exception as e:
|
)
|
||||||
raise ImportError(
|
|
||||||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(
|
def reshape_and_cache(
|
||||||
@ -23,22 +22,17 @@ def reshape_and_cache(
|
|||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
):
|
):
|
||||||
if SYSTEM == "xpu":
|
if FLASH_DECODING:
|
||||||
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
shape = key_cache.shape
|
||||||
key, value, key_cache, value_cache, slots
|
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||||
)
|
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||||
else:
|
else:
|
||||||
if FLASH_DECODING:
|
cache_ops.reshape_and_cache(
|
||||||
shape = key_cache.shape
|
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
)
|
||||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
|
||||||
else:
|
|
||||||
cache_ops.reshape_and_cache(
|
|
||||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def attention(
|
def paged_attention(
|
||||||
out: torch.Tensor,
|
out: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
key_cache: torch.Tensor,
|
||||||
@ -72,21 +66,6 @@ def attention(
|
|||||||
num_seqs, num_heads, head_size = query.shape
|
num_seqs, num_heads, head_size = query.shape
|
||||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
input_lengths = cu_seqlen_k
|
input_lengths = cu_seqlen_k
|
||||||
if SYSTEM == "xpu":
|
|
||||||
query = query.contiguous()
|
|
||||||
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
|
||||||
out,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
kv_head_mapping,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
input_lengths,
|
|
||||||
block_size,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||||
@ -174,3 +153,132 @@ def attention(
|
|||||||
"auto",
|
"auto",
|
||||||
1.0,
|
1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flash_attn_2_cuda
|
||||||
|
|
||||||
|
V2 = True
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
import flash_attn_cuda
|
||||||
|
|
||||||
|
V2 = False
|
||||||
|
except ImportError as e:
|
||||||
|
if major >= 8:
|
||||||
|
architecture_suffix = f"-{SYSTEM}"
|
||||||
|
raise ImportError(
|
||||||
|
"Flash Attention V2 is not installed.\n"
|
||||||
|
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||||
|
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
||||||
|
)
|
||||||
|
elif is_sm75:
|
||||||
|
raise ImportError(
|
||||||
|
"Flash Attention is not installed.\n"
|
||||||
|
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||||
|
"or install flash attention with `cd server && make install install-flash-attention`"
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
raise ImportError(
|
||||||
|
f"GPU with CUDA capability {major} {minor} is not supported"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
SUPPORTS_WINDOWING = V2
|
||||||
|
if V2:
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
causal=True,
|
||||||
|
):
|
||||||
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
causal,
|
||||||
|
window_size_left,
|
||||||
|
0,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
):
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"window_size_left is only available with flash attn v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
|
if k.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if k.shape[1] == 1:
|
||||||
|
k = k.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = k.shape
|
||||||
|
k = (
|
||||||
|
k.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
if v.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if v.shape[1] == 1:
|
||||||
|
v = v.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = v.shape
|
||||||
|
v = (
|
||||||
|
v.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
return flash_attn_cuda.fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
336
server/text_generation_server/layers/attention/rocm.py
Normal file
336
server/text_generation_server/layers/attention/rocm.py
Normal file
@ -0,0 +1,336 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
from text_generation_server.models.globals import FLASH_DECODING
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
major, minor = torch.cuda.get_device_capability()
|
||||||
|
is_sm75 = major == 7 and minor == 5
|
||||||
|
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
|
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
||||||
|
ENGINE = "triton" if use_triton else "ck"
|
||||||
|
from .flash_attn_triton import triton_attention
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm._C import cache_ops
|
||||||
|
from vllm._C import ops
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(
|
||||||
|
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_and_cache(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
):
|
||||||
|
if FLASH_DECODING:
|
||||||
|
shape = key_cache.shape
|
||||||
|
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||||
|
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||||
|
else:
|
||||||
|
cache_ops.reshape_and_cache(
|
||||||
|
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def paged_attention(
|
||||||
|
out: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
kv_head_mapping: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
cu_seqlen_q: torch.Tensor,
|
||||||
|
cu_seqlen_k: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
):
|
||||||
|
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||||
|
# Copyright 2023 The vLLM team. All rights
|
||||||
|
# reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
num_seqs, num_heads, head_size = query.shape
|
||||||
|
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
|
input_lengths = cu_seqlen_k
|
||||||
|
|
||||||
|
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||||
|
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||||
|
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||||
|
# sequences or heads is large, we use V1 since there is enough work
|
||||||
|
# to parallelize.
|
||||||
|
if FLASH_DECODING:
|
||||||
|
max_q = 1
|
||||||
|
max_k = max_s
|
||||||
|
import flash_attn_2_cuda
|
||||||
|
|
||||||
|
flash_attn_2_cuda.varlen_fwd(
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
out,
|
||||||
|
cu_seqlen_q,
|
||||||
|
cu_seqlen_k,
|
||||||
|
None,
|
||||||
|
block_tables,
|
||||||
|
None,
|
||||||
|
max_q,
|
||||||
|
max_k,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
-1,
|
||||||
|
0,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from vllm._C import ops
|
||||||
|
|
||||||
|
use_v1 = max_s <= 8192 and (
|
||||||
|
max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||||
|
)
|
||||||
|
if use_v1:
|
||||||
|
ops.paged_attention_v1(
|
||||||
|
out,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
kv_head_mapping,
|
||||||
|
softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
block_size,
|
||||||
|
max_s,
|
||||||
|
None,
|
||||||
|
"auto",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Run PagedAttention V2.
|
||||||
|
assert _PARTITION_SIZE % block_size == 0
|
||||||
|
tmp_output = torch.empty(
|
||||||
|
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||||
|
dtype=out.dtype,
|
||||||
|
device=out.device,
|
||||||
|
)
|
||||||
|
exp_sums = torch.empty(
|
||||||
|
size=(num_seqs, num_heads, max_num_partitions),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=out.device,
|
||||||
|
)
|
||||||
|
max_logits = torch.empty_like(exp_sums)
|
||||||
|
|
||||||
|
ops.paged_attention_v2(
|
||||||
|
out,
|
||||||
|
exp_sums,
|
||||||
|
max_logits,
|
||||||
|
tmp_output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
kv_head_mapping,
|
||||||
|
softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
block_size,
|
||||||
|
max_s,
|
||||||
|
None,
|
||||||
|
"auto",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flash_attn_2_cuda
|
||||||
|
|
||||||
|
if ENGINE == "triton":
|
||||||
|
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
|
||||||
|
elif ENGINE == "ck":
|
||||||
|
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
import flash_attn_cuda
|
||||||
|
|
||||||
|
ENGINE = "v1"
|
||||||
|
logger.info("ROCm: using Flash Attention 1")
|
||||||
|
except ImportError as e:
|
||||||
|
if major >= 8:
|
||||||
|
architecture_suffix = f"-{SYSTEM}"
|
||||||
|
raise ImportError(
|
||||||
|
"Flash Attention V2 is not installed.\n"
|
||||||
|
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||||
|
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
||||||
|
)
|
||||||
|
elif is_sm75:
|
||||||
|
raise ImportError(
|
||||||
|
"Flash Attention is not installed.\n"
|
||||||
|
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||||
|
"or install flash attention with `cd server && make install install-flash-attention`"
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
|
||||||
|
for idx in range(torch.cuda.device_count()):
|
||||||
|
name = torch.cuda.get_device_name(idx)
|
||||||
|
if "MI210" not in name and "MI250" not in name:
|
||||||
|
raise ImportError(
|
||||||
|
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
|
||||||
|
)
|
||||||
|
raise ImportError(
|
||||||
|
f"AMD GPU with Rocm capability {major} {minor} is not supported"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
SUPPORTS_WINDOWING = ENGINE != "v1"
|
||||||
|
if ENGINE == "ck":
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
causal=True,
|
||||||
|
):
|
||||||
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise ValueError(
|
||||||
|
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||||
|
)
|
||||||
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
causal,
|
||||||
|
window_size_left,
|
||||||
|
0,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif ENGINE == "triton":
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
causal=True,
|
||||||
|
):
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise ValueError(
|
||||||
|
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||||
|
)
|
||||||
|
output, _ = triton_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
causal,
|
||||||
|
softmax_scale,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
):
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"window_size_left is only available with flash attn v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
|
if k.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if k.shape[1] == 1:
|
||||||
|
k = k.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = k.shape
|
||||||
|
k = (
|
||||||
|
k.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
if v.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if v.shape[1] == 1:
|
||||||
|
v = v.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = v.shape
|
||||||
|
v = (
|
||||||
|
v.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
return flash_attn_cuda.fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
80
server/text_generation_server/layers/attention/xpu.py
Normal file
80
server/text_generation_server/layers/attention/xpu.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
import torch
|
||||||
|
|
||||||
|
SUPPORTS_WINDOWING = False
|
||||||
|
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
):
|
||||||
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise ValueError(
|
||||||
|
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||||
|
)
|
||||||
|
return ipex.llm.functional.varlen_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_and_cache(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
):
|
||||||
|
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
||||||
|
key, value, key_cache, value_cache, slots
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def paged_attention(
|
||||||
|
out: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
kv_head_mapping: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
cu_seqlen_q: torch.Tensor,
|
||||||
|
cu_seqlen_k: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
):
|
||||||
|
query = query.contiguous()
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||||
|
out,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
kv_head_mapping,
|
||||||
|
softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
cu_seqlen_q,
|
||||||
|
block_size,
|
||||||
|
max_s,
|
||||||
|
None,
|
||||||
|
)
|
@ -80,15 +80,11 @@ try:
|
|||||||
from text_generation_server.models.flash_phi import FlashPhi
|
from text_generation_server.models.flash_phi import FlashPhi
|
||||||
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
|
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
|
||||||
from text_generation_server.models.flash_dbrx import FlashDbrx
|
from text_generation_server.models.flash_dbrx import FlashDbrx
|
||||||
from text_generation_server.utils.flash_attn import (
|
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
||||||
HAS_FLASH_ATTN_V2_CUDA,
|
|
||||||
HAS_FLASH_ATTN_V2_ROCM,
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
||||||
|
SUPPORTS_WINDOWING = False
|
||||||
FLASH_ATTENTION = False
|
FLASH_ATTENTION = False
|
||||||
HAS_FLASH_ATTN_V2_CUDA = False
|
|
||||||
HAS_FLASH_ATTN_V2_ROCM = False
|
|
||||||
|
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
__all__.append(FlashGPT2)
|
__all__.append(FlashGPT2)
|
||||||
@ -262,6 +258,7 @@ def get_model(
|
|||||||
dtype: Optional[str],
|
dtype: Optional[str],
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
|
global FLASH_ATTENTION
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
if quantize in ["awq", "exl2", "gptq"]:
|
if quantize in ["awq", "exl2", "gptq"]:
|
||||||
# These quantizers only work with float16 params.
|
# These quantizers only work with float16 params.
|
||||||
@ -412,6 +409,12 @@ def get_model(
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Sharding is currently not supported with `exl2` quantization"
|
"Sharding is currently not supported with `exl2` quantization"
|
||||||
)
|
)
|
||||||
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
|
if sliding_window != -1 and not SUPPORTS_WINDOWING:
|
||||||
|
logger.warning(
|
||||||
|
f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"
|
||||||
|
)
|
||||||
|
FLASH_ATTENTION = False
|
||||||
|
|
||||||
if model_type == MAMBA:
|
if model_type == MAMBA:
|
||||||
return Mamba(
|
return Mamba(
|
||||||
@ -699,11 +702,7 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == MISTRAL:
|
if model_type == MISTRAL:
|
||||||
sliding_window = config_dict.get("sliding_window", -1)
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
if (
|
if FLASH_ATTENTION:
|
||||||
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
|
|
||||||
or HAS_FLASH_ATTN_V2_CUDA
|
|
||||||
or HAS_FLASH_ATTN_V2_ROCM
|
|
||||||
):
|
|
||||||
return FlashMistral(
|
return FlashMistral(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
@ -726,11 +725,7 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == MIXTRAL:
|
if model_type == MIXTRAL:
|
||||||
sliding_window = config_dict.get("sliding_window", -1)
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
if (
|
if FLASH_ATTENTION:
|
||||||
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
|
|
||||||
or HAS_FLASH_ATTN_V2_CUDA
|
|
||||||
or HAS_FLASH_ATTN_V2_ROCM
|
|
||||||
):
|
|
||||||
return FlashMixtral(
|
return FlashMixtral(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
@ -753,11 +748,7 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == STARCODER2:
|
if model_type == STARCODER2:
|
||||||
sliding_window = config_dict.get("sliding_window", -1)
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
if (
|
if FLASH_ATTENTION:
|
||||||
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
|
|
||||||
or HAS_FLASH_ATTN_V2_CUDA
|
|
||||||
or HAS_FLASH_ATTN_V2_ROCM
|
|
||||||
):
|
|
||||||
return FlashStarcoder2(
|
return FlashStarcoder2(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
@ -781,11 +772,7 @@ def get_model(
|
|||||||
|
|
||||||
if model_type == QWEN2:
|
if model_type == QWEN2:
|
||||||
sliding_window = config_dict.get("sliding_window", -1)
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
if (
|
if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
|
||||||
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
|
|
||||||
or HAS_FLASH_ATTN_V2_CUDA
|
|
||||||
or HAS_FLASH_ATTN_V2_ROCM
|
|
||||||
):
|
|
||||||
return FlashQwen2(
|
return FlashQwen2(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
|
@ -25,7 +25,7 @@ from torch import nn
|
|||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import paged_attention, attention
|
||||||
from text_generation_server.models.globals import FLASH_DECODING
|
from text_generation_server.models.globals import FLASH_DECODING
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
@ -291,7 +291,7 @@ class FlashCohereAttention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
@ -302,7 +302,7 @@ class FlashCohereAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -27,7 +27,11 @@ from text_generation_server.utils.import_utils import SYSTEM
|
|||||||
if SYSTEM != "xpu":
|
if SYSTEM != "xpu":
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
FastLinear,
|
FastLinear,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -424,9 +428,7 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
@ -434,7 +436,7 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
@ -445,7 +447,7 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -26,7 +26,11 @@ from transformers.activations import ACT2FN
|
|||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
@ -221,9 +225,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
@ -231,7 +233,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
@ -243,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -25,7 +25,11 @@ from torch import nn
|
|||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
@ -213,7 +217,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||||||
key = key.view(-1, self.num_heads, self.head_size)
|
key = key.view(-1, self.num_heads, self.head_size)
|
||||||
value = value.view(-1, self.num_heads, self.head_size)
|
value = value.view(-1, self.num_heads, self.head_size)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
@ -221,7 +225,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
@ -232,7 +236,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -28,7 +28,11 @@ from transformers.activations import ACT2FN
|
|||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.models.globals import FLASH_DECODING
|
from text_generation_server.models.globals import FLASH_DECODING
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -147,9 +151,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
@ -157,7 +159,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
@ -168,7 +170,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig
|
|||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
attention,
|
||||||
|
paged_attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
@ -186,7 +190,7 @@ class MistralAttention(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
kv_to_cache = kv
|
kv_to_cache = kv
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -196,7 +200,7 @@ class MistralAttention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
@ -208,7 +212,7 @@ class MistralAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -33,7 +33,11 @@ from transformers.configuration_utils import PretrainedConfig
|
|||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
FastLinear,
|
FastLinear,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -265,7 +269,7 @@ class MixtralAttention(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
kv_to_cache = kv
|
kv_to_cache = kv
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -275,7 +279,7 @@ class MixtralAttention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
@ -287,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -27,8 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
|
|||||||
from transformers.models.gpt_neox import GPTNeoXConfig
|
from transformers.models.gpt_neox import GPTNeoXConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
from text_generation_server.utils.flash_attn import attention
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
@ -146,9 +149,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
# Inplace rotary
|
# Inplace rotary
|
||||||
self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
|
self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
|
||||||
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(qkv[:, 0])
|
attn_output = torch.empty_like(qkv[:, 0])
|
||||||
@ -156,7 +157,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
qkv[:, 0],
|
qkv[:, 0],
|
||||||
qkv[:, 1],
|
qkv[:, 1],
|
||||||
qkv[:, 2],
|
qkv[:, 2],
|
||||||
@ -167,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
qkv[:, 0],
|
qkv[:, 0],
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
@ -175,6 +176,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
None,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
max_s,
|
max_s,
|
||||||
)
|
)
|
||||||
|
@ -6,7 +6,11 @@ from transformers.activations import ACT2FN
|
|||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
@ -185,16 +189,14 @@ class FlashPhiAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Reshape key and value and cache
|
# Reshape key and value and cache
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
@ -205,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -5,7 +5,11 @@ from torch import nn
|
|||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
@ -142,7 +146,7 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
kv_to_cache = kv
|
kv_to_cache = kv
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -152,7 +156,7 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
@ -164,7 +168,7 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -15,7 +15,11 @@ from text_generation_server.layers import (
|
|||||||
)
|
)
|
||||||
from text_generation_server.layers.layernorm import FastLayerNorm
|
from text_generation_server.layers.layernorm import FastLayerNorm
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.utils import flash_attn, paged_attention
|
from text_generation_server.layers.attention import (
|
||||||
|
attention,
|
||||||
|
paged_attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_row(config, prefix: str, weights, bias: bool):
|
def load_row(config, prefix: str, weights, bias: bool):
|
||||||
@ -314,7 +318,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
# Inplace rotary
|
# Inplace rotary
|
||||||
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
kv[:, :, 0].contiguous(),
|
kv[:, :, 0].contiguous(),
|
||||||
kv[:, :, 1].contiguous(),
|
kv[:, :, 1].contiguous(),
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
@ -328,7 +332,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=2, index=0),
|
torch.select(kv, dim=2, index=0),
|
||||||
torch.select(kv, dim=2, index=1),
|
torch.select(kv, dim=2, index=1),
|
||||||
@ -339,7 +343,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -6,7 +6,11 @@ from transformers.activations import ACT2FN
|
|||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
@ -276,7 +280,7 @@ class FlashMQAttention(torch.nn.Module):
|
|||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
key_value = key_value.view(-1, 2, 1, self.head_size)
|
key_value = key_value.view(-1, 2, 1, self.head_size)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
|
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -286,7 +290,7 @@ class FlashMQAttention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(key_value, dim=1, index=0),
|
torch.select(key_value, dim=1, index=0),
|
||||||
torch.select(key_value, dim=1, index=1),
|
torch.select(key_value, dim=1, index=1),
|
||||||
@ -297,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -26,7 +26,11 @@ from transformers.activations import ACT2FN
|
|||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
@ -229,7 +233,7 @@ class Starcoder2Attention(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
kv_to_cache = kv
|
kv_to_cache = kv
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -239,7 +243,7 @@ class Starcoder2Attention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
@ -251,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -1,293 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
import math
|
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
if SYSTEM != "xpu":
|
|
||||||
from text_generation_server.utils.flash_attn_triton import triton_attention
|
|
||||||
|
|
||||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
|
||||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
|
||||||
HAS_FLASH_ATTN = False
|
|
||||||
HAS_FLASH_ATTN_V2_CUDA = False
|
|
||||||
HAS_FLASH_ATTN_V2_ROCM = False
|
|
||||||
ROCM_USE_FLASH_ATTN_V2_CK = False
|
|
||||||
ROCM_USE_FLASH_ATTN_V2_TRITON = False
|
|
||||||
|
|
||||||
|
|
||||||
if SYSTEM in {"cuda", "rocm"}:
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
raise ImportError("CUDA is not available")
|
|
||||||
|
|
||||||
major, minor = torch.cuda.get_device_capability()
|
|
||||||
is_sm75 = major == 7 and minor == 5
|
|
||||||
is_sm8x = major == 8 and minor >= 0
|
|
||||||
is_sm90 = major == 9 and minor == 0
|
|
||||||
is_sm94 = major == 9 and minor == 4
|
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
if (
|
|
||||||
os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
|
|
||||||
or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
|
|
||||||
):
|
|
||||||
ROCM_USE_FLASH_ATTN_V2_TRITON = True
|
|
||||||
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
|
|
||||||
else:
|
|
||||||
ROCM_USE_FLASH_ATTN_V2_CK = True
|
|
||||||
logger.info(
|
|
||||||
"ROCm: using Flash Attention 2 Composable Kernel implementation."
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
import flash_attn_2_cuda
|
|
||||||
except ImportError:
|
|
||||||
architecture_suffix = f"-{SYSTEM}"
|
|
||||||
raise ImportError(
|
|
||||||
"Flash Attention V2 is not installed.\n"
|
|
||||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
|
||||||
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
|
||||||
)
|
|
||||||
if SYSTEM == "cuda" and not (is_sm8x or is_sm90):
|
|
||||||
raise ImportError(
|
|
||||||
f"GPU with CUDA capability {major} {minor} is not supported for "
|
|
||||||
"Flash Attention V2"
|
|
||||||
)
|
|
||||||
elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94):
|
|
||||||
raise ImportError(
|
|
||||||
f"AMD GPU with compute capability {major} {minor} is not supported for "
|
|
||||||
"Flash Attention V2"
|
|
||||||
)
|
|
||||||
HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
|
|
||||||
HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
|
|
||||||
except ImportError as e:
|
|
||||||
try:
|
|
||||||
import flash_attn_cuda
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"Flash Attention is not installed.\n"
|
|
||||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
|
||||||
"or install flash attention with `cd server && make install install-flash-attention`"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
|
|
||||||
raise ImportError(
|
|
||||||
f"GPU with CUDA capability {major} {minor} is not supported"
|
|
||||||
) from e
|
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
for idx in range(torch.cuda.device_count()):
|
|
||||||
if "MI210" not in torch.cuda.get_device_name(
|
|
||||||
idx
|
|
||||||
) and "MI250" not in torch.cuda.get_device_name(idx):
|
|
||||||
raise ImportError(
|
|
||||||
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.warning(f"Unable to use Flash Attention V2: {e}")
|
|
||||||
HAS_FLASH_ATTN = True
|
|
||||||
|
|
||||||
if SYSTEM == "xpu":
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
):
|
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
|
||||||
|
|
||||||
if window_size_left != -1:
|
|
||||||
raise ValueError(
|
|
||||||
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
|
||||||
)
|
|
||||||
return ipex.llm.functional.varlen_attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
max_s,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
True,
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif HAS_FLASH_ATTN_V2_CUDA:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
):
|
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
cu_seqlens,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
max_s,
|
|
||||||
max_s,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
causal,
|
|
||||||
window_size_left,
|
|
||||||
0,
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
):
|
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
|
||||||
if window_size_left != -1:
|
|
||||||
raise ValueError(
|
|
||||||
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
|
||||||
)
|
|
||||||
|
|
||||||
# RoCm flash API does not take the window_size_left and window_size_right arguments.
|
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
max_s,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
causal,
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
):
|
|
||||||
output, _ = triton_attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
max_s,
|
|
||||||
causal,
|
|
||||||
softmax_scale,
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
elif HAS_FLASH_ATTN:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
):
|
|
||||||
if window_size_left != -1:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"window_size_left is only available with flash attn v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
|
||||||
if k.shape[1] != q.shape[1]:
|
|
||||||
# MQA expand
|
|
||||||
if k.shape[1] == 1:
|
|
||||||
k = k.expand(-1, q.shape[1], -1)
|
|
||||||
# Grouped attention reshape
|
|
||||||
else:
|
|
||||||
original_shape = k.shape
|
|
||||||
k = (
|
|
||||||
k.unsqueeze(2)
|
|
||||||
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
|
||||||
)
|
|
||||||
if v.shape[1] != q.shape[1]:
|
|
||||||
# MQA expand
|
|
||||||
if v.shape[1] == 1:
|
|
||||||
v = v.expand(-1, q.shape[1], -1)
|
|
||||||
# Grouped attention reshape
|
|
||||||
else:
|
|
||||||
original_shape = v.shape
|
|
||||||
v = (
|
|
||||||
v.unsqueeze(2)
|
|
||||||
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
|
||||||
)
|
|
||||||
|
|
||||||
return flash_attn_cuda.fwd(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
max_s,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
True,
|
|
||||||
False,
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("flash attention is not installed")
|
|
@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
def is_xpu_available():
|
def is_xpu_available():
|
||||||
@ -40,7 +41,7 @@ elif is_xpu_available():
|
|||||||
synchronize = torch.xpu.synchronize
|
synchronize = torch.xpu.synchronize
|
||||||
get_free_memory = get_xpu_free_memory
|
get_free_memory = get_xpu_free_memory
|
||||||
else:
|
else:
|
||||||
SYSTEM = "cpu"
|
SYSTEM = "ipex"
|
||||||
|
|
||||||
def noop(*args, **kwargs):
|
def noop(*args, **kwargs):
|
||||||
pass
|
pass
|
||||||
@ -48,3 +49,4 @@ else:
|
|||||||
empty_cache = noop
|
empty_cache = noop
|
||||||
synchronize = noop
|
synchronize = noop
|
||||||
get_free_memory = noop
|
get_free_memory = noop
|
||||||
|
logger.info(f"Detected system {SYSTEM}")
|
||||||
|
Loading…
Reference in New Issue
Block a user