Use attention kernels from the Hub

Daniël de Kok 2025-01-31 15:36:15 +00:00
parent 758ff3c598
commit b267caa537
6 changed files with 264 additions and 20 deletions
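This commit drops the vendored attention-kernels and moe-kernels wheel dependencies and instead fetches the kernels from the Hugging Face Hub with hf_kernels, pinned to exact revisions in kernels.lock. A minimal sketch of the loading pattern the diff adopts, assuming hf_kernels is installed and the pinned Hub repo is reachable:

    from hf_kernels import load_kernel

    # Downloads (or reuses a cached copy of) the kernel revision pinned in
    # kernels.lock and loads the build variant matching the local
    # Torch/CUDA/platform combination.
    attention_kernels = load_kernel("kernels-community/attention")

    # The returned module exposes the compiled ops, e.g. the cache op used
    # further down in this diff:
    # attention_kernels.reshape_and_cache(key, value, key_cache, value_cache, ...)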

View File

@@ -3,7 +3,6 @@
   buildPythonPackage,
   poetry-core,
   mypy-protobuf,
-  attention-kernels,
   awq-inference-engine,
   causal-conv1d,
   compressed-tensors,
@@ -23,7 +22,6 @@
   hf-transfer,
   loguru,
   mamba-ssm,
-  moe-kernels,
   opentelemetry-api,
   opentelemetry-exporter-otlp,
   opentelemetry-instrumentation-grpc,
@@ -78,7 +76,6 @@ buildPythonPackage {
   pythonRemoveDeps = [ "scipy" ];
   dependencies = [
-    attention-kernels
     awq-inference-engine
     eetq
     causal-conv1d

View File

@@ -1,4 +1,250 @@
[
{
"repo_id": "kernels-community/attention",
"sha": "20100e6a97f0fa1465560aa21eecbf4b04d3d93a",
"files": [
{
"filename": "build/torch25-cxx11-cu118-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch25-cxx11-cu118-x86_64-linux/attention/_attention_6yvgebnqctora.abi3.so",
"blob_id": "29733cfb726d11a1d278fb0f3679c010cf5210e2"
},
{
"filename": "build/torch25-cxx11-cu118-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch25-cxx11-cu118-x86_64-linux/attention/_ops.py",
"blob_id": "1379d7cc10c5fafa877e3ea73be33d3eed57b449"
},
{
"filename": "build/torch25-cxx11-cu118-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
},
{
"filename": "build/torch25-cxx11-cu121-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch25-cxx11-cu121-x86_64-linux/attention/_attention_4jg2igd54wzge.abi3.so",
"blob_id": "a58d380aa758b8e6842e89013229bee3711286ef"
},
{
"filename": "build/torch25-cxx11-cu121-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch25-cxx11-cu121-x86_64-linux/attention/_ops.py",
"blob_id": "9dee16955e9d988953733fae4e743d92886c92b1"
},
{
"filename": "build/torch25-cxx11-cu121-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
},
{
"filename": "build/torch25-cxx11-cu124-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch25-cxx11-cu124-x86_64-linux/attention/_attention_syg6kbhkhc4xk.abi3.so",
"blob_id": "369150e0964eaca52c0c7906addf9f18d8ec7270"
},
{
"filename": "build/torch25-cxx11-cu124-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch25-cxx11-cu124-x86_64-linux/attention/_ops.py",
"blob_id": "0bac0403831e313bcf9cbab1a35c2cbe4d5ef08f"
},
{
"filename": "build/torch25-cxx11-cu124-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
},
{
"filename": "build/torch25-cxx98-cu118-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch25-cxx98-cu118-x86_64-linux/attention/_attention_hhzgzhvc7zviy.abi3.so",
"blob_id": "05529e8bcee239db92984acb3e19926697c64a3f"
},
{
"filename": "build/torch25-cxx98-cu118-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch25-cxx98-cu118-x86_64-linux/attention/_ops.py",
"blob_id": "270fd3d0005a3e44dc6625c3ab4948a7fa7892bb"
},
{
"filename": "build/torch25-cxx98-cu118-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
},
{
"filename": "build/torch25-cxx98-cu121-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch25-cxx98-cu121-x86_64-linux/attention/_attention_gbi5gm244waic.abi3.so",
"blob_id": "cb6cccabe445cbf7bfd797b4645300e5a2a4ec38"
},
{
"filename": "build/torch25-cxx98-cu121-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch25-cxx98-cu121-x86_64-linux/attention/_ops.py",
"blob_id": "a517876400c08f9800107c61d6ca3f57e0bdc2e6"
},
{
"filename": "build/torch25-cxx98-cu121-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
},
{
"filename": "build/torch25-cxx98-cu124-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch25-cxx98-cu124-x86_64-linux/attention/_attention_ill75rmpj7yds.abi3.so",
"blob_id": "bf93abf5555357ad397844421fcfc66ae0743166"
},
{
"filename": "build/torch25-cxx98-cu124-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch25-cxx98-cu124-x86_64-linux/attention/_ops.py",
"blob_id": "f49b90de8bda122b2049bf57f5012b60e05364fe"
},
{
"filename": "build/torch25-cxx98-cu124-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
},
{
"filename": "build/torch26-cxx11-cu118-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch26-cxx11-cu118-x86_64-linux/attention/_attention_6qe5ft3kiteru.abi3.so",
"blob_id": "0bbd1dc682174c9d7fba2ee7426e1183e668ab79"
},
{
"filename": "build/torch26-cxx11-cu118-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch26-cxx11-cu118-x86_64-linux/attention/_ops.py",
"blob_id": "f9b2a39308433746718b31f0d9830b27f72f5242"
},
{
"filename": "build/torch26-cxx11-cu118-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
},
{
"filename": "build/torch26-cxx11-cu124-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch26-cxx11-cu124-x86_64-linux/attention/_attention_ftq3cjdxqfw4m.abi3.so",
"blob_id": "d7fa42c3682924a46e9c5b4a7e847a6b4415c5c8"
},
{
"filename": "build/torch26-cxx11-cu124-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch26-cxx11-cu124-x86_64-linux/attention/_ops.py",
"blob_id": "27b44593d2252bfe5399c8dcd883aa497223f158"
},
{
"filename": "build/torch26-cxx11-cu124-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
},
{
"filename": "build/torch26-cxx11-cu126-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch26-cxx11-cu126-x86_64-linux/attention/_attention_lkibbjh726iwm.abi3.so",
"blob_id": "4a4cccfd49090ac213bbf562a9c4bb2ff2920eb0"
},
{
"filename": "build/torch26-cxx11-cu126-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch26-cxx11-cu126-x86_64-linux/attention/_ops.py",
"blob_id": "ac89377661ed1c5f2eca40cf199a15209af0c05c"
},
{
"filename": "build/torch26-cxx11-cu126-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
},
{
"filename": "build/torch26-cxx98-cu118-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch26-cxx98-cu118-x86_64-linux/attention/_attention_vbhagz24hyij6.abi3.so",
"blob_id": "4d87629674e87a746aaec4ccadb26bb2a72f2d43"
},
{
"filename": "build/torch26-cxx98-cu118-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch26-cxx98-cu118-x86_64-linux/attention/_ops.py",
"blob_id": "2f05f1ffd05c49971dfc9b45971efb5a055c7e52"
},
{
"filename": "build/torch26-cxx98-cu118-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
},
{
"filename": "build/torch26-cxx98-cu124-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch26-cxx98-cu124-x86_64-linux/attention/_attention_sfjvhlixssyce.abi3.so",
"blob_id": "ee6153972f28bd997e1fc4a7eaaf425fd5adc918"
},
{
"filename": "build/torch26-cxx98-cu124-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch26-cxx98-cu124-x86_64-linux/attention/_ops.py",
"blob_id": "530d483cdf8243f6c863ca49c0e87018634e69d0"
},
{
"filename": "build/torch26-cxx98-cu124-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
},
{
"filename": "build/torch26-cxx98-cu126-x86_64-linux/attention/__init__.py",
"blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
},
{
"filename": "build/torch26-cxx98-cu126-x86_64-linux/attention/_attention_g7oqtcveiuapk.abi3.so",
"blob_id": "fe58b4ce4158bf5ee55371329396ac8e573cfc85"
},
{
"filename": "build/torch26-cxx98-cu126-x86_64-linux/attention/_custom_ops.py",
"blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
},
{
"filename": "build/torch26-cxx98-cu126-x86_64-linux/attention/_ops.py",
"blob_id": "1e504e67dd25c4aa79bcc509316f3f23e6e3e6ef"
},
{
"filename": "build/torch26-cxx98-cu126-x86_64-linux/attention/platforms.py",
"blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
}
]
},
{
"repo_id": "kernels-community/moe",
"sha": "f1e9385163758eb1934677cb9e94a59f2d87bb09",

View File

@@ -38,6 +38,7 @@ requires = ["hf-kernels", "setuptools"]
 build-backend = "setuptools.build_meta"
 
 [tool.kernels.dependencies]
+"kernels-community/attention" = ">=0.0.1"
 "kernels-community/moe" = ">=0.0.3"
 "kernels-community/quantization" = ">=0.0.2"
@@ -74,14 +75,6 @@ gen = [
     "mypy-protobuf>=3.6.0",
 ]
 
 [tool.uv.sources]
-attention-kernels = [
-    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp39-cp39-linux_x86_64.whl", marker = "python_version == '3.9'" },
-    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp310-cp310-linux_x86_64.whl", marker = "python_version == '3.10'" },
-    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
-    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
-]
 
 [tool.pytest.ini_options]
 markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
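Note that the removed wheel sources were per-Python-version builds (cp39 through cp312) of attention-kernels, while the Hub artifacts in kernels.lock ship a single _attention_*.abi3.so per Torch/CUDA variant: the stable-ABI shared object works across Python versions, so one prebuilt binary per variant is enough.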

View File

@@ -1,3 +1,4 @@
+from hf_kernels import load_kernel
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.utils.import_utils import SYSTEM
@@ -107,7 +108,7 @@ def paged_attention(
         if softcap is not None:
             raise RuntimeError("Paged attention doesn't support softcapping")
         input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-        import attention_kernels
+        attention_kernels = load_kernel("kernels-community/attention")
 
         out = torch.empty_like(query)
@@ -130,8 +131,8 @@
                 max_s,
                 None,
                 kv_cache_dtype,
-                kv_scales.key_scale_cpu,
-                kv_scales.value_scale_cpu,
+                torch.tensor(kv_scales.key_scale_cpu if can_scale else 1.0),
+                torch.tensor(kv_scales.value_scale_cpu if can_scale else 1.0),
             )
         else:
             # Run PagedAttention V2.
@@ -164,8 +165,8 @@
                 max_s,
                 None,
                 kv_cache_dtype,
-                kv_scales.key_scale_cpu,
-                kv_scales.value_scale_cpu,
+                torch.tensor(kv_scales.key_scale_cpu if can_scale else 1.0),
+                torch.tensor(kv_scales.value_scale_cpu if can_scale else 1.0),
             )
     return out
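Besides the loading change, the call sites above change the scale arguments: they are now wrapped in tensors (the Hub kernel evidently expects tensor scales rather than Python floats), falling back to a neutral 1.0 when FP8 KV-cache scaling is inactive. A minimal sketch of that pattern, with a hypothetical helper name:

    import torch

    # Hypothetical helper mirroring the inline expressions above.
    def scale_arg(scale_cpu: float, can_scale: bool) -> torch.Tensor:
        # Wrap the CPU-side scalar for the kernel; use the neutral scale 1.0
        # when FP8 KV-cache scaling is not in use.
        return torch.tensor(scale_cpu if can_scale else 1.0)

    # e.g. scale_arg(kv_scales.key_scale_cpu, can_scale) in the calls above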

View File

@@ -1,6 +1,7 @@
 from typing import Tuple
 
 from dataclasses import dataclass, field
 
+from hf_kernels import load_kernel
 from loguru import logger
 import torch
@@ -119,7 +120,7 @@ class KVCache:
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
         elif self.dtype == torch.float8_e4m3fn and (
-            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            (ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda")
             or (ATTENTION == "paged" and SYSTEM == "rocm")
         ):
             log_once(logger.info, "Using FP8 KV cache scales")
@@ -221,7 +222,7 @@ def paged_reshape_and_cache(
     if SYSTEM == "cuda":
         try:
-            import attention_kernels
+            attention_kernels = load_kernel("kernels-community/attention")
         except Exception as e:
             raise ImportError(
                 f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
@@ -232,7 +233,14 @@
             kv_cache_dtype = "fp8"
         attention_kernels.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slots,
+            kv_cache_dtype,
+            torch.tensor(k_scale),
+            torch.tensor(v_scale),
         )
     elif SYSTEM == "rocm":
         try:
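One behavioral change rides along in can_scale: on CUDA, FP8 KV-cache scales are now honored for the paged backend as well, not only for flashinfer. A condensed, hypothetical standalone version of the dtype branch after this change, with names as in the diff:

    import torch

    # Hypothetical extraction of the elif condition in KVCache.can_scale above.
    def fp8_scales_supported(dtype: torch.dtype, attention: str, system: str) -> bool:
        return dtype == torch.float8_e4m3fn and (
            (attention in ("paged", "flashinfer") and system == "cuda")
            or (attention == "paged" and system == "rocm")
        )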

View File

@@ -148,7 +148,6 @@ def fp8_quantize(
         shape = weight.shape
         qweight, scale = marlin_kernels.scaled_fp8_quant(
             weight.reshape(-1, shape[-1]),
-            dtype=quant_dtype,
             scale=scale,
             scale_ub=scale_upper_bound,
             # TODO: don't do this when we have to use the Torch kernel.
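Here the dtype keyword is dropped from the marlin_kernels.scaled_fp8_quant call, presumably because the updated quantization kernel no longer takes an explicit output dtype and picks the FP8 dtype itself.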