Revamped all this architecture.

Nicolas Patry 2024-05-29 17:36:04 +00:00
parent 7890cd66f7
commit daddd2e90b
21 changed files with 691 additions and 420 deletions

View File

@ -0,0 +1,13 @@
from text_generation_server.utils.import_utils import SYSTEM
import os
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "xpu":
from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
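The point of this new dispatch module is that model code can import one backend-agnostic interface instead of branching on SYSTEM itself. A minimal sketch of the intended import, using only the names exported above (the actual call sites appear in the model files further down in this commit):

from text_generation_server.layers.attention import (
    attention,           # variable-length flash attention, used for prefill
    paged_attention,     # block-table attention over the KV cache, used for decode
    reshape_and_cache,   # writes freshly computed K/V into the paged cache
    SUPPORTS_WINDOWING,  # whether sliding-window attention is available on this backend
)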

View File

@ -2,18 +2,17 @@ import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
if SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex
else:
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
@ -23,22 +22,17 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if SYSTEM == "xpu":
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def attention(
def paged_attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
@ -72,21 +66,6 @@ def attention(
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = cu_seqlen_k
if SYSTEM == "xpu":
query = query.contiguous()
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
@ -174,3 +153,132 @@ def attention(
"auto",
1.0,
)
try:
import flash_attn_2_cuda
V2 = True
except ImportError:
try:
import flash_attn_cuda
V2 = False
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = V2
if V2:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
None,
None,
None,
max_s,
max_s,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
False,
None,
)
else:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
return flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
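As a hedged usage sketch for the prefill entry point above: q, k, v and out are packed over all tokens in the batch and cu_seqlens marks the sequence boundaries. The shapes below are illustrative and assume flash-attn v2 plus a CUDA device are available:

import torch
from text_generation_server.layers.attention import attention

# Two packed sequences of lengths 3 and 5 (8 tokens total), 32 heads of size 128.
q = torch.randn(8, 32, 128, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
attention(q, k, v, out, cu_seqlens, max_s=5, softmax_scale=128 ** -0.5)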

View File

@ -0,0 +1,336 @@
import os
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING
from loguru import logger
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
from .flash_attn_triton import triton_attention
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = cu_seqlen_k
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
if FLASH_DECODING:
max_q = 1
max_k = max_s
import flash_attn_2_cuda
flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
out,
cu_seqlen_q,
cu_seqlen_k,
None,
block_tables,
None,
max_q,
max_k,
0.0,
softmax_scale,
False,
True,
-1,
0,
False,
None,
)
else:
from vllm._C import ops
use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
try:
import flash_attn_2_cuda
if ENGINE == "triton":
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
elif ENGINE == "ck":
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
except ImportError:
try:
import flash_attn_cuda
ENGINE = "v1"
logger.info("ROCm: using Flash Attention 1")
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
for idx in range(torch.cuda.device_count()):
name = torch.cuda.get_device_name(idx)
if "MI210" not in name and "MI250" not in name:
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
raise ImportError(
f"AMD GPU with Rocm capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = ENGINE != "v1"
if ENGINE == "ck":
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
None,
None,
None,
max_s,
max_s,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
False,
None,
)
elif ENGINE == "triton":
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
output, _ = triton_attention(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
causal,
softmax_scale,
)
return output
else:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
return flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
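For reviewers, a hedged summary of how this ROCm module picks its prefill engine, reconstructed only from the code above:

# ENGINE selection at import time (sketch):
#   ROCM_USE_FLASH_ATTN_V2_TRITON in {"true", "1"} and flash_attn_2_cuda imports -> "triton" (triton_attention kernel)
#   flash_attn_2_cuda imports (default)                                          -> "ck" (Composable Kernel)
#   only flash_attn_cuda imports                                                 -> "v1" (SUPPORTS_WINDOWING becomes False)
#   neither imports                                                              -> ImportError with install hints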

View File

@ -0,0 +1,80 @@
import intel_extension_for_pytorch as ipex
import torch
SUPPORTS_WINDOWING = False
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
return ipex.llm.functional.varlen_attention(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
def paged_attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
max_s: int,
):
query = query.contiguous()
block_size = value_cache.shape[3]
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
cu_seqlen_q,
block_size,
max_s,
None,
)
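All three backends now export the same reshape_and_cache hook, and the model files later in this commit call it the same way. A hedged sketch of that call, with shapes inferred from the callers rather than asserted here:

# key / value: K and V for the current batch, roughly (num_tokens, num_kv_heads, head_size)
# kv_cache[0] / kv_cache[1]: the paged key and value caches
# slots: (num_tokens,) flat slot index of each token inside the cache
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)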

View File

@ -80,15 +80,11 @@ try:
from text_generation_server.models.flash_phi import FlashPhi
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
from text_generation_server.models.flash_dbrx import FlashDbrx
from text_generation_server.utils.flash_attn import (
HAS_FLASH_ATTN_V2_CUDA,
HAS_FLASH_ATTN_V2_ROCM,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
SUPPORTS_WINDOWING = False
FLASH_ATTENTION = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
if FLASH_ATTENTION:
__all__.append(FlashGPT2)
@ -262,6 +258,7 @@ def get_model(
dtype: Optional[str],
trust_remote_code: bool,
) -> Model:
global FLASH_ATTENTION
if dtype is None:
if quantize in ["awq", "exl2", "gptq"]:
# These quantizers only work with float16 params.
@ -412,6 +409,12 @@ def get_model(
raise RuntimeError(
"Sharding is currently not supported with `exl2` quantization"
)
sliding_window = config_dict.get("sliding_window", -1)
if sliding_window != -1 and not SUPPORTS_WINDOWING:
logger.warning(
f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"
)
FLASH_ATTENTION = False
if model_type == MAMBA:
return Mamba(
@ -699,11 +702,7 @@ def get_model(
if model_type == MISTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
or HAS_FLASH_ATTN_V2_CUDA
or HAS_FLASH_ATTN_V2_ROCM
):
if FLASH_ATTENTION:
return FlashMistral(
model_id,
revision,
@ -726,11 +725,7 @@ def get_model(
if model_type == MIXTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
or HAS_FLASH_ATTN_V2_CUDA
or HAS_FLASH_ATTN_V2_ROCM
):
if FLASH_ATTENTION:
return FlashMixtral(
model_id,
revision,
@ -753,11 +748,7 @@ def get_model(
if model_type == STARCODER2:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
or HAS_FLASH_ATTN_V2_CUDA
or HAS_FLASH_ATTN_V2_ROCM
):
if FLASH_ATTENTION:
return FlashStarcoder2(
model_id,
revision,
@ -781,11 +772,7 @@ def get_model(
if model_type == QWEN2:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
or HAS_FLASH_ATTN_V2_CUDA
or HAS_FLASH_ATTN_V2_ROCM
):
if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
return FlashQwen2(
model_id,
revision,

View File

@ -25,7 +25,7 @@ from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import paged_attention, attention
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
@ -291,7 +291,7 @@ class FlashCohereAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
key,
value,
@ -302,7 +302,7 @@ class FlashCohereAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -27,7 +27,11 @@ from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu":
from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
FastLinear,
TensorParallelRowLinear,
@ -424,9 +428,7 @@ class DbrxAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
@ -434,7 +436,7 @@ class DbrxAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -445,7 +447,7 @@ class DbrxAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -26,7 +26,11 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -221,9 +225,7 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
@ -231,7 +233,7 @@ class FlashGemmaAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -243,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -25,7 +25,11 @@ from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -213,7 +217,7 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
@ -221,7 +225,7 @@ class FlashGPT2Attention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
key,
value,
@ -232,7 +236,7 @@ class FlashGPT2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -28,7 +28,11 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -147,9 +151,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
@ -157,7 +159,7 @@ class FlashLlamaAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -168,7 +170,7 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -186,7 +190,7 @@ class MistralAttention(torch.nn.Module):
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -196,7 +200,7 @@ class MistralAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -208,7 +212,7 @@ class MistralAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -33,7 +33,11 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from loguru import logger
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
FastLinear,
TensorParallelRowLinear,
@ -265,7 +269,7 @@ class MixtralAttention(torch.nn.Module):
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -275,7 +279,7 @@ class MixtralAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -287,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -27,8 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -146,9 +149,7 @@ class FlashNeoxAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
paged_attention.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
@ -156,7 +157,7 @@ class FlashNeoxAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
@ -167,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
qkv[:, 0],
kv_cache[0],
@ -175,6 +176,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
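The None added above reflects the new paged_attention signature, which now takes cu_seqlen_q and cu_seqlen_k explicitly. A hedged annotation of the decode call as it reads after this commit (the comments are interpretation, based on the CUDA/ROCm implementations earlier in the diff):

paged_attention(
    attn_output,           # out
    qkv[:, 0],             # query, (num_seqs, num_heads, head_size)
    kv_cache[0],           # key cache
    kv_cache[1],           # value cache
    self.kv_head_mapping,
    self.softmax_scale,
    block_tables,
    None,                  # cu_seqlen_q, only consumed on the FLASH_DECODING path
    input_lengths,         # cu_seqlen_k, per-sequence lengths
    max_s,
)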

View File

@ -6,7 +6,11 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -185,16 +189,14 @@ class FlashPhiAttention(torch.nn.Module):
)
# Reshape key and value and cache
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -205,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -5,7 +5,11 @@ from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -142,7 +146,7 @@ class Qwen2Attention(torch.nn.Module):
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -152,7 +156,7 @@ class Qwen2Attention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -164,7 +168,7 @@ class Qwen2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -15,7 +15,11 @@ from text_generation_server.layers import (
)
from text_generation_server.layers.layernorm import FastLayerNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils import flash_attn, paged_attention
from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
)
def load_row(config, prefix: str, weights, bias: bool):
@ -314,7 +318,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
paged_attention.reshape_and_cache(
reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
@ -328,7 +332,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
@ -339,7 +343,7 @@ class FlashRWLargeAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -6,7 +6,11 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -276,7 +280,7 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
paged_attention.reshape_and_cache(
reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -286,7 +290,7 @@ class FlashMQAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
@ -297,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -26,7 +26,11 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -229,7 +233,7 @@ class Starcoder2Attention(torch.nn.Module):
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -239,7 +243,7 @@ class Starcoder2Attention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -251,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -1,293 +0,0 @@
import os
import torch
from loguru import logger
import math
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu":
from text_generation_server.utils.flash_attn_triton import triton_attention
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
ROCM_USE_FLASH_ATTN_V2_CK = False
ROCM_USE_FLASH_ATTN_V2_TRITON = False
if SYSTEM in {"cuda", "rocm"}:
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
is_sm94 = major == 9 and minor == 4
if SYSTEM == "rocm":
if (
os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
):
ROCM_USE_FLASH_ATTN_V2_TRITON = True
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
else:
ROCM_USE_FLASH_ATTN_V2_CK = True
logger.info(
"ROCm: using Flash Attention 2 Composable Kernel implementation."
)
try:
try:
import flash_attn_2_cuda
except ImportError:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
if SYSTEM == "cuda" and not (is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94):
raise ImportError(
f"AMD GPU with compute capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
except ImportError as e:
try:
import flash_attn_cuda
except ImportError:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
elif SYSTEM == "rocm":
for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(
idx
) and "MI250" not in torch.cuda.get_device_name(idx):
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
logger.warning(f"Unable to use Flash Attention V2: {e}")
HAS_FLASH_ATTN = True
if SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
return ipex.llm.functional.varlen_attention(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
elif HAS_FLASH_ATTN_V2_CUDA:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
None,
None,
None,
max_s,
max_s,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
False,
None,
)
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
# RoCm flash API does not take the window_size_left and window_size_right arguments.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
causal,
False,
None,
)
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
output, _ = triton_attention(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
causal,
softmax_scale,
)
return output
elif HAS_FLASH_ATTN:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
return flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
else:
raise NotImplementedError("flash attention is not installed")

View File

@ -1,4 +1,5 @@
import torch
from loguru import logger
def is_xpu_available():
@ -40,7 +41,7 @@ elif is_xpu_available():
synchronize = torch.xpu.synchronize
get_free_memory = get_xpu_free_memory
else:
SYSTEM = "cpu"
SYSTEM = "ipex"
def noop(*args, **kwargs):
pass
@ -48,3 +49,4 @@ else:
empty_cache = noop
synchronize = noop
get_free_memory = noop
logger.info(f"Detected system {SYSTEM}")
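Taken together with the new attention dispatch module at the top of this commit, SYSTEM now routes as follows (a hedged trace based only on the files shown here):

# SYSTEM == "cuda" -> layers.attention.cuda
# SYSTEM == "rocm" -> layers.attention.rocm
# SYSTEM == "xpu"  -> layers.attention.xpu
# SYSTEM == "ipex" (the renamed CPU fallback) -> ImportError raised by the dispatch module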