Improve support for GPUs with capability < 8

- For models that cannot use flashinfer, fall back to flash-attn v1 + paged
  attention on GPUs with a compute capability older than 8 (sketched below).
- Disable prefix caching when using paged attention.
- When using flash-attn v1, pass the key/value tensors rather than the
  cache, since v1 cannot use block tables.
Daniël de Kok · 2024-09-26 13:35:27 +00:00
parent 7efcb5e0ed · commit bee5ee1f03
26 changed files with 138 additions and 68 deletions
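As a rough sketch of the resulting server-side behaviour (Python, mirroring the constants introduced in the diffs below; the helper name select_prefill_inputs is purely illustrative and not part of the commit), prefill reads from the paged KV cache only when flash-attn v2 or flashinfer is available:

# Minimal sketch based on the constants added below (ATTENTION, V2,
# PREFILL_IN_KV_CACHE). The helper `select_prefill_inputs` is hypothetical.
ATTENTION = "paged"  # e.g. forced by the launcher on compute capability < 8
V2 = False           # flash-attn v2 is unavailable on these GPUs

# flash-attn v1 cannot use block tables, so prefill must not read from the
# paged KV cache in that configuration.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2

def select_prefill_inputs(kv_cache, key, value):
    """Return the tensors the prefill attention call should consume."""
    if PREFILL_IN_KV_CACHE:
        return kv_cache[0], kv_cache[1]
    return key, value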

Cargo.lock (generated)

@ -4243,6 +4243,7 @@ dependencies = [
"hf-hub",
"nix 0.28.0",
"once_cell",
"pyo3",
"reqwest",
"serde",
"serde_json",


@ -33,6 +33,7 @@ metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
[profile.release]
incremental = true


@ -12,6 +12,7 @@ ctrlc = { version = "3.4.1", features = ["termination"] }
hf-hub = "0.3.2"
nix = { version = "0.28.0", features = ["signal"] }
once_cell = "1.19.0"
pyo3 = { workspace = true }
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107"
thiserror = "1.0.59"

launcher/src/gpu.rs (new file)

@ -0,0 +1,22 @@
use std::sync::LazyLock;
pub static COMPUTE_CAPABILITY: LazyLock<Option<(isize, isize)>> =
LazyLock::new(get_cuda_capability);
fn get_cuda_capability() -> Option<(isize, isize)> {
use pyo3::prelude::*;
let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
let torch = py.import_bound("torch.cuda")?;
let get_device_capability = torch.getattr("get_device_capability")?;
get_device_capability.call0()?.extract()
};
match pyo3::Python::with_gil(py_get_capability) {
Ok(capability) => Some(capability),
Err(err) => {
tracing::warn!("Cannot determine GPU compute capability: {}", err);
None
}
}
}
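For reference, the Python that this pyo3 call executes is roughly equivalent to the sketch below (an illustration only; it assumes torch with CUDA support is importable in the launcher's environment):

# Rough Python equivalent of the pyo3 lookup above: query the compute
# capability of the current CUDA device, e.g. (7, 5) for a T4 or (8, 0)
# for an A100, and fall back to None when the query fails.
def get_cuda_capability():
    try:
        import torch.cuda
        return torch.cuda.get_device_capability()
    except Exception as err:
        print(f"Cannot determine GPU compute capability: {err}")
        return None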


@ -26,6 +26,7 @@ use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime;
mod gpu;
fn get_config(
model_id: &str,
@ -65,6 +66,7 @@ fn get_config(
}
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let compute_capability = *gpu::COMPUTE_CAPABILITY;
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
@ -77,6 +79,13 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
prefix_caching = Some("0".to_string());
}
}
let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
"paged"
} else {
"flashdecoding"
};
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
if lora_adapters.is_some() && prefix_caching.is_none() {
@ -89,10 +98,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
// flashinfer ?
if attention.is_none() {
tracing::info!(
"Forcing flash decoding because model {} requires it",
"Forcing attention to '{fallback_attention}' because model {} requires it",
config.model_type.as_ref().unwrap()
);
attention = Some("flashdecoding".to_string());
attention = Some(fallback_attention.to_string());
}
}
Some("t5") => {}
@ -101,8 +110,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
}
_ => {
if attention.is_none() {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
attention = Some("flashdecoding".to_string());
tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
attention = Some(fallback_attention.to_string());
}
if prefix_caching.is_none() {
prefix_caching = Some("0".to_string());
@ -110,8 +119,17 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
}
}
}
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
let attention = attention.unwrap_or("flashinfer".to_string());
let prefix_caching = if attention == "paged"
&& prefix_caching.is_none()
&& compute_capability.is_some()
{
tracing::info!("Disabling prefix caching because it is not supported with 'flashinfer'");
"false".to_string()
} else {
prefix_caching.unwrap_or("true".to_string())
};
(prefix_caching, attention)
}


@ -61,7 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
] }
csv = "1.3.0"
ureq = "=2.9"
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
pyo3 = { workspace = true }
[build-dependencies]


@ -11,11 +11,24 @@ if SYSTEM == "cuda":
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
PREFILL_IN_KV_CACHE,
)
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
from .rocm import (
attention,
paged_attention,
reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
elif SYSTEM == "ipex":
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
from .ipex import (
attention,
paged_attention,
reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@ -24,6 +37,7 @@ __all__ = [
"attention",
"paged_attention",
"reshape_and_cache",
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"Seqlen",
]


@ -287,16 +287,14 @@ elif V2:
else:
def attention(
q,
k,
v,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=None,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
@ -338,16 +336,22 @@ else:
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
True,
causal,
False,
0,
None,
)
return out
# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2


@ -5,6 +5,7 @@ from text_generation_server.layers.attention import Seqlen
from typing import Optional
SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False
def attention(


@ -13,6 +13,9 @@ _PARTITION_SIZE = 512
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
PREFILL_IN_KV_CACHE = False
try:
from vllm._C import cache_ops
except Exception as e:
@ -156,7 +159,6 @@ if ENGINE != "triton":
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
for idx in range(torch.cuda.device_count()):
name = torch.cuda.get_device_name(idx)
if "MI210" not in name and "MI250" not in name:


@ -39,6 +39,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,


@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
attention,
reshape_and_cache,
Seqlen,
PREFILL_IN_KV_CACHE,
)
from text_generation_server.layers import (
FastLinear,
@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,


@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
paged_attention,
reshape_and_cache,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
@ -327,8 +328,8 @@ class DeepseekV2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,


@ -25,7 +25,6 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
@ -39,6 +38,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,


@ -25,12 +25,12 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
PREFILL_IN_KV_CACHE,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,


@ -24,7 +24,7 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.attention import (
paged_attention,
attention,
@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,


@ -38,6 +38,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key,
kv_cache[1] if SYSTEM != "ipex" else value,
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,


@ -27,6 +27,7 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,


@ -41,6 +41,7 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,


@ -39,10 +39,10 @@ from text_generation_server.layers.attention import (
paged_attention,
reshape_and_cache,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import UnquantizedWeight
@ -267,8 +267,8 @@ class MixtralAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,


@ -26,7 +26,6 @@ from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
@ -40,6 +39,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
qkv[:, 0],
kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
seqlen,
block_tables,
self.softmax_scale,


@ -19,13 +19,13 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.import_utils import SYSTEM
class PhiConfig(PretrainedConfig):
@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,


@ -17,11 +17,11 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.import_utils import SYSTEM
def load_attention(config, prefix, weights):
@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,


@ -5,7 +5,6 @@ import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
SpeculativeHead,
TensorParallelColumnLinear,
@ -13,6 +12,7 @@ from text_generation_server.layers import (
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastLayerNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
@ -325,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
seqlen,
block_tables,
self.softmax_scale,


@ -18,11 +18,11 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.utils.import_utils import SYSTEM
def load_multi_mqa(
@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
seqlen,
block_tables,
self.softmax_scale,


@ -39,6 +39,7 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
FastRMSNorm,
@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
from text_generation_server.utils.import_utils import SYSTEM
class Starcoder2Config(PretrainedConfig):
@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,