Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

Use attention kernels from the Hub

Commit b267caa537 (parent 758ff3c598)
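
This commit stops installing a pinned attention-kernels wheel and instead fetches the prebuilt attention kernel from the Hugging Face Hub at run time. A minimal sketch of the pattern the diffs below adopt, assuming the hf-kernels package is installed and the Hub (or a warm local cache) is reachable:

    # Sketch only: load the prebuilt attention kernel for the running
    # torch/CUDA/ABI combination. The revision is pinned by kernels.lock,
    # so repeated loads resolve to the same build.
    from hf_kernels import load_kernel

    attention_kernels = load_kernel("kernels-community/attention")

    # The loaded module stands in for the old attention-kernels package,
    # e.g. attention_kernels.reshape_and_cache(...) as used further down.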
nix/server.nix:

@@ -3,7 +3,6 @@
   buildPythonPackage,
   poetry-core,
   mypy-protobuf,
-  attention-kernels,
   awq-inference-engine,
   causal-conv1d,
   compressed-tensors,
@@ -23,7 +22,6 @@
   hf-transfer,
   loguru,
   mamba-ssm,
   moe-kernels,
   opentelemetry-api,
   opentelemetry-exporter-otlp,
   opentelemetry-instrumentation-grpc,
@@ -78,7 +76,6 @@ buildPythonPackage {
   pythonRemoveDeps = [ "scipy" ];

   dependencies = [
-    attention-kernels
     awq-inference-engine
     eetq
     causal-conv1d
server/kernels.lock:

@@ -1,4 +1,250 @@
[
  {
    "repo_id": "kernels-community/attention",
    "sha": "20100e6a97f0fa1465560aa21eecbf4b04d3d93a",
    "files": [
      {
        "filename": "build/torch25-cxx11-cu118-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch25-cxx11-cu118-x86_64-linux/attention/_attention_6yvgebnqctora.abi3.so",
        "blob_id": "29733cfb726d11a1d278fb0f3679c010cf5210e2"
      },
      {
        "filename": "build/torch25-cxx11-cu118-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch25-cxx11-cu118-x86_64-linux/attention/_ops.py",
        "blob_id": "1379d7cc10c5fafa877e3ea73be33d3eed57b449"
      },
      {
        "filename": "build/torch25-cxx11-cu118-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      },
      {
        "filename": "build/torch25-cxx11-cu121-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch25-cxx11-cu121-x86_64-linux/attention/_attention_4jg2igd54wzge.abi3.so",
        "blob_id": "a58d380aa758b8e6842e89013229bee3711286ef"
      },
      {
        "filename": "build/torch25-cxx11-cu121-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch25-cxx11-cu121-x86_64-linux/attention/_ops.py",
        "blob_id": "9dee16955e9d988953733fae4e743d92886c92b1"
      },
      {
        "filename": "build/torch25-cxx11-cu121-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      },
      {
        "filename": "build/torch25-cxx11-cu124-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch25-cxx11-cu124-x86_64-linux/attention/_attention_syg6kbhkhc4xk.abi3.so",
        "blob_id": "369150e0964eaca52c0c7906addf9f18d8ec7270"
      },
      {
        "filename": "build/torch25-cxx11-cu124-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch25-cxx11-cu124-x86_64-linux/attention/_ops.py",
        "blob_id": "0bac0403831e313bcf9cbab1a35c2cbe4d5ef08f"
      },
      {
        "filename": "build/torch25-cxx11-cu124-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      },
      {
        "filename": "build/torch25-cxx98-cu118-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch25-cxx98-cu118-x86_64-linux/attention/_attention_hhzgzhvc7zviy.abi3.so",
        "blob_id": "05529e8bcee239db92984acb3e19926697c64a3f"
      },
      {
        "filename": "build/torch25-cxx98-cu118-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch25-cxx98-cu118-x86_64-linux/attention/_ops.py",
        "blob_id": "270fd3d0005a3e44dc6625c3ab4948a7fa7892bb"
      },
      {
        "filename": "build/torch25-cxx98-cu118-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      },
      {
        "filename": "build/torch25-cxx98-cu121-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch25-cxx98-cu121-x86_64-linux/attention/_attention_gbi5gm244waic.abi3.so",
        "blob_id": "cb6cccabe445cbf7bfd797b4645300e5a2a4ec38"
      },
      {
        "filename": "build/torch25-cxx98-cu121-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch25-cxx98-cu121-x86_64-linux/attention/_ops.py",
        "blob_id": "a517876400c08f9800107c61d6ca3f57e0bdc2e6"
      },
      {
        "filename": "build/torch25-cxx98-cu121-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      },
      {
        "filename": "build/torch25-cxx98-cu124-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch25-cxx98-cu124-x86_64-linux/attention/_attention_ill75rmpj7yds.abi3.so",
        "blob_id": "bf93abf5555357ad397844421fcfc66ae0743166"
      },
      {
        "filename": "build/torch25-cxx98-cu124-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch25-cxx98-cu124-x86_64-linux/attention/_ops.py",
        "blob_id": "f49b90de8bda122b2049bf57f5012b60e05364fe"
      },
      {
        "filename": "build/torch25-cxx98-cu124-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      },
      {
        "filename": "build/torch26-cxx11-cu118-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch26-cxx11-cu118-x86_64-linux/attention/_attention_6qe5ft3kiteru.abi3.so",
        "blob_id": "0bbd1dc682174c9d7fba2ee7426e1183e668ab79"
      },
      {
        "filename": "build/torch26-cxx11-cu118-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch26-cxx11-cu118-x86_64-linux/attention/_ops.py",
        "blob_id": "f9b2a39308433746718b31f0d9830b27f72f5242"
      },
      {
        "filename": "build/torch26-cxx11-cu118-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      },
      {
        "filename": "build/torch26-cxx11-cu124-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch26-cxx11-cu124-x86_64-linux/attention/_attention_ftq3cjdxqfw4m.abi3.so",
        "blob_id": "d7fa42c3682924a46e9c5b4a7e847a6b4415c5c8"
      },
      {
        "filename": "build/torch26-cxx11-cu124-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch26-cxx11-cu124-x86_64-linux/attention/_ops.py",
        "blob_id": "27b44593d2252bfe5399c8dcd883aa497223f158"
      },
      {
        "filename": "build/torch26-cxx11-cu124-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      },
      {
        "filename": "build/torch26-cxx11-cu126-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch26-cxx11-cu126-x86_64-linux/attention/_attention_lkibbjh726iwm.abi3.so",
        "blob_id": "4a4cccfd49090ac213bbf562a9c4bb2ff2920eb0"
      },
      {
        "filename": "build/torch26-cxx11-cu126-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch26-cxx11-cu126-x86_64-linux/attention/_ops.py",
        "blob_id": "ac89377661ed1c5f2eca40cf199a15209af0c05c"
      },
      {
        "filename": "build/torch26-cxx11-cu126-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      },
      {
        "filename": "build/torch26-cxx98-cu118-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch26-cxx98-cu118-x86_64-linux/attention/_attention_vbhagz24hyij6.abi3.so",
        "blob_id": "4d87629674e87a746aaec4ccadb26bb2a72f2d43"
      },
      {
        "filename": "build/torch26-cxx98-cu118-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch26-cxx98-cu118-x86_64-linux/attention/_ops.py",
        "blob_id": "2f05f1ffd05c49971dfc9b45971efb5a055c7e52"
      },
      {
        "filename": "build/torch26-cxx98-cu118-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      },
      {
        "filename": "build/torch26-cxx98-cu124-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch26-cxx98-cu124-x86_64-linux/attention/_attention_sfjvhlixssyce.abi3.so",
        "blob_id": "ee6153972f28bd997e1fc4a7eaaf425fd5adc918"
      },
      {
        "filename": "build/torch26-cxx98-cu124-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch26-cxx98-cu124-x86_64-linux/attention/_ops.py",
        "blob_id": "530d483cdf8243f6c863ca49c0e87018634e69d0"
      },
      {
        "filename": "build/torch26-cxx98-cu124-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      },
      {
        "filename": "build/torch26-cxx98-cu126-x86_64-linux/attention/__init__.py",
        "blob_id": "9de56043369487facc1f163df6bd319c9806e5ca"
      },
      {
        "filename": "build/torch26-cxx98-cu126-x86_64-linux/attention/_attention_g7oqtcveiuapk.abi3.so",
        "blob_id": "fe58b4ce4158bf5ee55371329396ac8e573cfc85"
      },
      {
        "filename": "build/torch26-cxx98-cu126-x86_64-linux/attention/_custom_ops.py",
        "blob_id": "a0c0b8db085468dee5100c98d14106a9ee917bf2"
      },
      {
        "filename": "build/torch26-cxx98-cu126-x86_64-linux/attention/_ops.py",
        "blob_id": "1e504e67dd25c4aa79bcc509316f3f23e6e3e6ef"
      },
      {
        "filename": "build/torch26-cxx98-cu126-x86_64-linux/attention/platforms.py",
        "blob_id": "aa06132e74cd7fb634044a76e528979b02a3559b"
      }
    ]
  },
  {
    "repo_id": "kernels-community/moe",
    "sha": "f1e9385163758eb1934677cb9e94a59f2d87bb09",
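
Each lock entry records the kernel repository, the exact revision, and every build-variant file with its blob id. Resolving one locked file amounts to a revision-pinned Hub download; a sketch using huggingface_hub (hf_kernels' actual loader presumably adds caching, blob verification, and variant selection on top):

    from huggingface_hub import hf_hub_download

    # Fetch one locked file at the exact revision recorded above; the
    # sha pin is what makes the resolution reproducible.
    path = hf_hub_download(
        repo_id="kernels-community/attention",
        filename="build/torch26-cxx11-cu124-x86_64-linux/attention/_ops.py",
        revision="20100e6a97f0fa1465560aa21eecbf4b04d3d93a",
    )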
@ -38,6 +38,7 @@ requires = ["hf-kernels", "setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.kernels.dependencies]
|
||||
"kernels-community/attention" = ">=0.0.1"
|
||||
"kernels-community/moe" = ">=0.0.3"
|
||||
"kernels-community/quantization" = ">=0.0.2"
|
||||
|
||||
@@ -74,14 +75,6 @@ gen = [
     "mypy-protobuf>=3.6.0",
 ]

-[tool.uv.sources]
-attention-kernels = [
-    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp39-cp39-linux_x86_64.whl", marker = "python_version == '3.9'" },
-    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp310-cp310-linux_x86_64.whl", marker = "python_version == '3.10'" },
-    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp311-cp311-linux_x86_64.whl", marker = "python_version == '3.11'" },
-    { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.5-cp312-cp312-linux_x86_64.whl", marker = "python_version == '3.12'" },
-]
-
 [tool.pytest.ini_options]
 markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
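
The removed [tool.uv.sources] block pinned one wheel per Python version. Hub kernels instead ship one prebuilt variant per torch/C++-ABI/CUDA/architecture combination (the build/... directories in kernels.lock) and target Python's stable ABI (the .abi3.so suffix), so per-Python wheels are no longer needed. A hypothetical reconstruction of the variant naming, for illustration only (the real selection logic lives inside hf_kernels):

    import platform

    import torch

    def build_variant() -> str:
        # Produces names like "torch26-cxx11-cu124-x86_64-linux",
        # matching the directories listed in kernels.lock.
        # Assumes a CUDA build of torch.
        major, minor = torch.__version__.split(".")[:2]
        abi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
        cuda = torch.version.cuda.replace(".", "")
        return f"torch{major}{minor}-{abi}-cu{cuda}-{platform.machine()}-linux"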
server/text_generation_server/layers/attention/cuda.py:

@@ -1,3 +1,4 @@
+from hf_kernels import load_kernel
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
 from text_generation_server.utils.import_utils import SYSTEM
@@ -107,7 +108,7 @@ def paged_attention(
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")
     input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    import attention_kernels
+    attention_kernels = load_kernel("kernels-community/attention")

     out = torch.empty_like(query)

@@ -130,8 +131,8 @@
             max_s,
             None,
             kv_cache_dtype,
-            kv_scales.key_scale_cpu,
-            kv_scales.value_scale_cpu,
+            torch.tensor(kv_scales.key_scale_cpu if can_scale else 1.0),
+            torch.tensor(kv_scales.value_scale_cpu if can_scale else 1.0),
         )
     else:
         # Run PagedAttention V2.
@@ -164,8 +165,8 @@
             max_s,
             None,
             kv_cache_dtype,
-            kv_scales.key_scale_cpu,
-            kv_scales.value_scale_cpu,
+            torch.tensor(kv_scales.key_scale_cpu if can_scale else 1.0),
+            torch.tensor(kv_scales.value_scale_cpu if can_scale else 1.0),
         )
     return out

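Besides the import swap, both call sites now pass the KV scales as scalar tensors rather than Python floats, falling back to a neutral 1.0 when scaling cannot be applied; presumably the Hub kernel's op signature expects tensor scales. An illustrative helper (hypothetical name) capturing the pattern:

    import torch

    def kv_scale_tensors(kv_scales, can_scale: bool):
        # Wrap the CPU float scales into 0-dim tensors; use a no-op
        # scale of 1.0 when FP8 scaling is not in effect.
        k = torch.tensor(kv_scales.key_scale_cpu if can_scale else 1.0)
        v = torch.tensor(kv_scales.value_scale_cpu if can_scale else 1.0)
        return k, v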
server/text_generation_server/layers/attention/kv_cache.py:

@@ -1,6 +1,7 @@
 from typing import Tuple
 from dataclasses import dataclass, field

+from hf_kernels import load_kernel
 from loguru import logger
 import torch

@@ -119,7 +120,7 @@ class KVCache:
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
         elif self.dtype == torch.float8_e4m3fn and (
-            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            (ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda")
             or (ATTENTION == "paged" and SYSTEM == "rocm")
         ):
             log_once(logger.info, "Using FP8 KV cache scales")
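
The condition change enables FP8 KV-cache scales for the paged attention backend on CUDA (previously only flashinfer), now that the Hub kernel accepts the scale tensors. Restated as a standalone predicate (a hypothetical free function mirroring KVCache.can_scale above):

    import torch

    def can_scale(dtype, attention: str, system: str,
                  key_scale: float, value_scale: float) -> bool:
        # Scales of exactly 1.0 are a no-op, so skip the scaled path.
        if key_scale == 1.0 and value_scale == 1.0:
            return False
        # FP8 KV caches honor scales on CUDA (paged or flashinfer)
        # and on ROCm (paged only).
        return dtype == torch.float8_e4m3fn and (
            (attention in ("paged", "flashinfer") and system == "cuda")
            or (attention == "paged" and system == "rocm")
        )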
@@ -221,7 +222,7 @@ def paged_reshape_and_cache(

     if SYSTEM == "cuda":
         try:
-            import attention_kernels
+            attention_kernels = load_kernel("kernels-community/attention")
         except Exception as e:
             raise ImportError(
                 f"Could not import attention_kernels. Make sure your installation is correct. Complete error: {e}"
@@ -232,7 +233,14 @@
             kv_cache_dtype = "fp8"

         attention_kernels.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slots,
+            kv_cache_dtype,
+            torch.tensor(k_scale),
+            torch.tensor(v_scale),
         )
     elif SYSTEM == "rocm":
         try:
server/text_generation_server/layers/fp8.py:

@@ -148,7 +148,6 @@ def fp8_quantize(
     shape = weight.shape
     qweight, scale = marlin_kernels.scaled_fp8_quant(
         weight.reshape(-1, shape[-1]),
         dtype=quant_dtype,
         scale=scale,
         scale_ub=scale_upper_bound,
     # TODO: don't do this when we have to use the Torch kernel.
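
For intuition, the per-tensor FP8 quantization that scaled_fp8_quant performs can be restated in pure torch; a reference sketch for illustration only (not TGI's code path, which dispatches to the fused kernel):

    import torch

    def fp8_quantize_ref(weight: torch.Tensor, scale: torch.Tensor | None = None):
        # Representable range of float8_e4m3fn (roughly +/-448).
        finfo = torch.finfo(torch.float8_e4m3fn)
        if scale is None:
            # Dynamic per-tensor scale: map the max magnitude onto the
            # FP8 range.
            scale = weight.abs().max() / finfo.max
        # Divide by the scale, clamp into range, and cast down to FP8.
        qweight = (weight / scale).clamp(finfo.min, finfo.max)
        return qweight.to(torch.float8_e4m3fn), scale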