Use rotary kernel from the Hub (#3041)

This commit is contained in:
Daniël de Kok 2025-02-21 13:55:31 +01:00 committed by GitHub
parent 1cae3197c4
commit 97c5f7e685
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 195 additions and 109 deletions

View File

@ -2,16 +2,10 @@
"nodes": { "nodes": {
"cachix": { "cachix": {
"inputs": { "inputs": {
"devenv": [ "devenv": ["crate2nix"],
"crate2nix" "flake-compat": ["crate2nix"],
],
"flake-compat": [
"crate2nix"
],
"nixpkgs": "nixpkgs", "nixpkgs": "nixpkgs",
"pre-commit-hooks": [ "pre-commit-hooks": ["crate2nix"]
"crate2nix"
]
}, },
"locked": { "locked": {
"lastModified": 1709700175, "lastModified": 1709700175,
@ -30,19 +24,10 @@
}, },
"cachix_2": { "cachix_2": {
"inputs": { "inputs": {
"devenv": [ "devenv": ["crate2nix", "crate2nix_stable"],
"crate2nix", "flake-compat": ["crate2nix", "crate2nix_stable"],
"crate2nix_stable"
],
"flake-compat": [
"crate2nix",
"crate2nix_stable"
],
"nixpkgs": "nixpkgs_2", "nixpkgs": "nixpkgs_2",
"pre-commit-hooks": [ "pre-commit-hooks": ["crate2nix", "crate2nix_stable"]
"crate2nix",
"crate2nix_stable"
]
}, },
"locked": { "locked": {
"lastModified": 1716549461, "lastModified": 1716549461,
@ -61,16 +46,8 @@
}, },
"cachix_3": { "cachix_3": {
"inputs": { "inputs": {
"devenv": [ "devenv": ["crate2nix", "crate2nix_stable", "crate2nix_stable"],
"crate2nix", "flake-compat": ["crate2nix", "crate2nix_stable", "crate2nix_stable"],
"crate2nix_stable",
"crate2nix_stable"
],
"flake-compat": [
"crate2nix",
"crate2nix_stable",
"crate2nix_stable"
],
"nixpkgs": "nixpkgs_3", "nixpkgs": "nixpkgs_3",
"pre-commit-hooks": [ "pre-commit-hooks": [
"crate2nix", "crate2nix",
@ -101,10 +78,7 @@
"flake-compat": "flake-compat_3", "flake-compat": "flake-compat_3",
"flake-parts": "flake-parts_3", "flake-parts": "flake-parts_3",
"nix-test-runner": "nix-test-runner_3", "nix-test-runner": "nix-test-runner_3",
"nixpkgs": [ "nixpkgs": ["tgi-nix", "nixpkgs"],
"tgi-nix",
"nixpkgs"
],
"pre-commit-hooks": "pre-commit-hooks_3" "pre-commit-hooks": "pre-commit-hooks_3"
}, },
"locked": { "locked": {
@ -219,11 +193,7 @@
"devshell_2": { "devshell_2": {
"inputs": { "inputs": {
"flake-utils": "flake-utils_3", "flake-utils": "flake-utils_3",
"nixpkgs": [ "nixpkgs": ["crate2nix", "crate2nix_stable", "nixpkgs"]
"crate2nix",
"crate2nix_stable",
"nixpkgs"
]
}, },
"locked": { "locked": {
"lastModified": 1717408969, "lastModified": 1717408969,
@ -242,10 +212,7 @@
"devshell_3": { "devshell_3": {
"inputs": { "inputs": {
"flake-utils": "flake-utils_4", "flake-utils": "flake-utils_4",
"nixpkgs": [ "nixpkgs": ["crate2nix", "nixpkgs"]
"crate2nix",
"nixpkgs"
]
}, },
"locked": { "locked": {
"lastModified": 1711099426, "lastModified": 1711099426,
@ -343,11 +310,7 @@
}, },
"flake-parts_2": { "flake-parts_2": {
"inputs": { "inputs": {
"nixpkgs-lib": [ "nixpkgs-lib": ["crate2nix", "crate2nix_stable", "nixpkgs"]
"crate2nix",
"crate2nix_stable",
"nixpkgs"
]
}, },
"locked": { "locked": {
"lastModified": 1719745305, "lastModified": 1719745305,
@ -365,10 +328,7 @@
}, },
"flake-parts_3": { "flake-parts_3": {
"inputs": { "inputs": {
"nixpkgs-lib": [ "nixpkgs-lib": ["crate2nix", "nixpkgs"]
"crate2nix",
"nixpkgs"
]
}, },
"locked": { "locked": {
"lastModified": 1712014858, "lastModified": 1712014858,
@ -559,11 +519,7 @@
}, },
"gitignore_3": { "gitignore_3": {
"inputs": { "inputs": {
"nixpkgs": [ "nixpkgs": ["crate2nix", "pre-commit-hooks", "nixpkgs"]
"crate2nix",
"pre-commit-hooks",
"nixpkgs"
]
}, },
"locked": { "locked": {
"lastModified": 1709087332, "lastModified": 1709087332,
@ -770,22 +726,10 @@
}, },
"pre-commit-hooks_2": { "pre-commit-hooks_2": {
"inputs": { "inputs": {
"flake-compat": [ "flake-compat": ["crate2nix", "crate2nix_stable", "flake-compat"],
"crate2nix",
"crate2nix_stable",
"flake-compat"
],
"gitignore": "gitignore_2", "gitignore": "gitignore_2",
"nixpkgs": [ "nixpkgs": ["crate2nix", "crate2nix_stable", "nixpkgs"],
"crate2nix", "nixpkgs-stable": ["crate2nix", "crate2nix_stable", "nixpkgs"]
"crate2nix_stable",
"nixpkgs"
],
"nixpkgs-stable": [
"crate2nix",
"crate2nix_stable",
"nixpkgs"
]
}, },
"locked": { "locked": {
"lastModified": 1719259945, "lastModified": 1719259945,
@ -803,20 +747,11 @@
}, },
"pre-commit-hooks_3": { "pre-commit-hooks_3": {
"inputs": { "inputs": {
"flake-compat": [ "flake-compat": ["crate2nix", "flake-compat"],
"crate2nix",
"flake-compat"
],
"flake-utils": "flake-utils_5", "flake-utils": "flake-utils_5",
"gitignore": "gitignore_3", "gitignore": "gitignore_3",
"nixpkgs": [ "nixpkgs": ["crate2nix", "nixpkgs"],
"crate2nix", "nixpkgs-stable": ["crate2nix", "nixpkgs"]
"nixpkgs"
],
"nixpkgs-stable": [
"crate2nix",
"nixpkgs"
]
}, },
"locked": { "locked": {
"lastModified": 1712055707, "lastModified": 1712055707,
@ -837,20 +772,14 @@
"crate2nix": "crate2nix", "crate2nix": "crate2nix",
"flake-utils": "flake-utils_6", "flake-utils": "flake-utils_6",
"nix-filter": "nix-filter", "nix-filter": "nix-filter",
"nixpkgs": [ "nixpkgs": ["tgi-nix", "nixpkgs"],
"tgi-nix",
"nixpkgs"
],
"rust-overlay": "rust-overlay", "rust-overlay": "rust-overlay",
"tgi-nix": "tgi-nix" "tgi-nix": "tgi-nix"
} }
}, },
"rust-overlay": { "rust-overlay": {
"inputs": { "inputs": {
"nixpkgs": [ "nixpkgs": ["tgi-nix", "nixpkgs"]
"tgi-nix",
"nixpkgs"
]
}, },
"locked": { "locked": {
"lastModified": 1738549608, "lastModified": 1738549608,
@ -978,16 +907,16 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1740036032, "lastModified": 1740049068,
"narHash": "sha256-nqo3U8uNlFIgrOz8wCfgk08Oi+RzQxxFDPipeVMyM/E=", "narHash": "sha256-heYzYOt+TSnRKHIV24s74yEjLkTbBfjNCWHdQEX++eI=",
"owner": "huggingface", "owner": "huggingface",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"rev": "e9fb0e818a7e9a54cdab8d9c7c0cef5037fe084a", "rev": "143e8451efa22b120f97e6698508e9a0aed82769",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "huggingface", "owner": "huggingface",
"ref": "flashinfer-0.2.0.post2", "ref": "hub-rotary",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"type": "github" "type": "github"
} }

View File

@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
}; };
nix-filter.url = "github:numtide/nix-filter"; nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/flashinfer-0.2.0.post2"; tgi-nix.url = "github:huggingface/text-generation-inference-nix/hub-rotary";
nixpkgs.follows = "tgi-nix/nixpkgs"; nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
rust-overlay = { rust-overlay = {

View File

@ -11,7 +11,6 @@
flashinfer, flashinfer,
flash-attn, flash-attn,
flash-attn-layer-norm, flash-attn-layer-norm,
flash-attn-rotary,
flash-attn-v1, flash-attn-v1,
grpc-interceptor, grpc-interceptor,
grpcio-reflection, grpcio-reflection,
@ -36,6 +35,7 @@
pydantic, pydantic,
quantization, quantization,
quantization-eetq, quantization-eetq,
rotary,
safetensors, safetensors,
tokenizers, tokenizers,
torch, torch,
@ -87,7 +87,6 @@ buildPythonPackage {
flashinfer flashinfer
flash-attn flash-attn
flash-attn-layer-norm flash-attn-layer-norm
flash-attn-rotary
grpc-interceptor grpc-interceptor
grpcio-reflection grpcio-reflection
grpcio-status grpcio-status
@ -111,6 +110,7 @@ buildPythonPackage {
pydantic pydantic
quantization quantization
quantization-eetq quantization-eetq
rotary
safetensors safetensors
sentencepiece sentencepiece
tokenizers tokenizers

View File

@ -6934,5 +6934,155 @@
"blob_id": "005b5a6e3cd5f7bcfd4aa5d7d80d60a5ed9fab88" "blob_id": "005b5a6e3cd5f7bcfd4aa5d7d80d60a5ed9fab88"
} }
] ]
},
{
"repo_id": "kernels-community/rotary",
"sha": "4db658e027ec752840bb3f557ee076413b8db03f",
"files": [
{
"filename": "build/torch25-cxx11-cu118-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch25-cxx11-cu118-x86_64-linux/rotary/_ops.py",
"blob_id": "4fe035c87ea1300ffedcfce17338167dd946e0e8"
},
{
"filename": "build/torch25-cxx11-cu118-x86_64-linux/rotary/_rotary_5yzc45v7kk3yu.abi3.so",
"blob_id": "f315754ccb3e8b9dfb4d8954aefaac61b2a4e8bc"
},
{
"filename": "build/torch25-cxx11-cu121-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch25-cxx11-cu121-x86_64-linux/rotary/_ops.py",
"blob_id": "d45359065a7cdc43e2d512d38fc0bfcd88138835"
},
{
"filename": "build/torch25-cxx11-cu121-x86_64-linux/rotary/_rotary_tbiepw2a2ep3e.abi3.so",
"blob_id": "9bc986ca760b6a57d05e891c3def1769341a2c29"
},
{
"filename": "build/torch25-cxx11-cu124-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch25-cxx11-cu124-x86_64-linux/rotary/_ops.py",
"blob_id": "7421978682125139f1169f3f71789e0cb44d3b45"
},
{
"filename": "build/torch25-cxx11-cu124-x86_64-linux/rotary/_rotary_6w5syhrhmerj6.abi3.so",
"blob_id": "35caf7d755000ca2afac33cc7d4f344112751f9f"
},
{
"filename": "build/torch25-cxx98-cu118-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch25-cxx98-cu118-x86_64-linux/rotary/_ops.py",
"blob_id": "d6f569b471eae628e738b3504c8a9a18b4973d97"
},
{
"filename": "build/torch25-cxx98-cu118-x86_64-linux/rotary/_rotary_joujmbgvsytzg.abi3.so",
"blob_id": "30c4aaa83f2549f7363631a51b3341cdf0612f15"
},
{
"filename": "build/torch25-cxx98-cu121-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch25-cxx98-cu121-x86_64-linux/rotary/_ops.py",
"blob_id": "f1f71f34bf0f3c5dffb7c147b48f6396f8054310"
},
{
"filename": "build/torch25-cxx98-cu121-x86_64-linux/rotary/_rotary_mi2o7e7sishyw.abi3.so",
"blob_id": "e022e9c1101bdb89d43913979107f7a56717ea6d"
},
{
"filename": "build/torch25-cxx98-cu124-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch25-cxx98-cu124-x86_64-linux/rotary/_ops.py",
"blob_id": "a46c19bd5adfb85d5b7795b3b9277e416f31d8ce"
},
{
"filename": "build/torch25-cxx98-cu124-x86_64-linux/rotary/_rotary_rngiohfhfwuge.abi3.so",
"blob_id": "1621cba0150465f67aa931fe3a55e38928b48bcb"
},
{
"filename": "build/torch26-cxx11-cu118-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch26-cxx11-cu118-x86_64-linux/rotary/_ops.py",
"blob_id": "3296d23431d1ec084e8644ff5d3d203a74d82ea1"
},
{
"filename": "build/torch26-cxx11-cu118-x86_64-linux/rotary/_rotary_alv7mzltcxxpq.abi3.so",
"blob_id": "2f8b3b93bb7c8fae22c8e08c67771683f549f170"
},
{
"filename": "build/torch26-cxx11-cu124-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch26-cxx11-cu124-x86_64-linux/rotary/_ops.py",
"blob_id": "0bae33b64c71d6a6ad748be66a410a662bb5b28a"
},
{
"filename": "build/torch26-cxx11-cu124-x86_64-linux/rotary/_rotary_c4eyapeep6gty.abi3.so",
"blob_id": "3a15076dbd1f9a05f1089cc2cb13c03e5838f2b9"
},
{
"filename": "build/torch26-cxx11-cu126-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch26-cxx11-cu126-x86_64-linux/rotary/_ops.py",
"blob_id": "5c5d7e4497962e0a3e9531ec6a5fdb18e995e0f8"
},
{
"filename": "build/torch26-cxx11-cu126-x86_64-linux/rotary/_rotary_lodp6xeztste6.abi3.so",
"blob_id": "efc4a4a0001fba7b0743d1d4a9774d1fc9089ee5"
},
{
"filename": "build/torch26-cxx98-cu118-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch26-cxx98-cu118-x86_64-linux/rotary/_ops.py",
"blob_id": "3cdda1ceda90e76b30b08cae6ad718aa2c2ec3ef"
},
{
"filename": "build/torch26-cxx98-cu118-x86_64-linux/rotary/_rotary_z27mls7mz4e7m.abi3.so",
"blob_id": "43b0f08ea035cd2fadb6e119802e7d841a523246"
},
{
"filename": "build/torch26-cxx98-cu124-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch26-cxx98-cu124-x86_64-linux/rotary/_ops.py",
"blob_id": "914f1bb6f9499d0245c3f47345f3b95582be28b2"
},
{
"filename": "build/torch26-cxx98-cu124-x86_64-linux/rotary/_rotary_3bktke4p3hz3a.abi3.so",
"blob_id": "29c0bba399e6f348ad8baf93daa5166a6ec6994a"
},
{
"filename": "build/torch26-cxx98-cu126-x86_64-linux/rotary/__init__.py",
"blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be"
},
{
"filename": "build/torch26-cxx98-cu126-x86_64-linux/rotary/_ops.py",
"blob_id": "11056cd0ed09530830b614b8207cbb7fa7ef3288"
},
{
"filename": "build/torch26-cxx98-cu126-x86_64-linux/rotary/_rotary_fvednlzeqgg5s.abi3.so",
"blob_id": "8c7c34d9e603640576ba522dcbed341c0d780a9c"
}
]
} }
] ]

View File

@ -43,6 +43,7 @@ build-backend = "setuptools.build_meta"
"kernels-community/moe" = ">=0.1.1" "kernels-community/moe" = ">=0.1.1"
"kernels-community/quantization" = ">=0.0.3" "kernels-community/quantization" = ">=0.0.3"
"kernels-community/quantization-eetq" = ">=0.0.1" "kernels-community/quantization-eetq" = ">=0.0.1"
"kernels-community/rotary" = ">=0.0.1"
[project.scripts] [project.scripts]
text-generation-server = "text_generation_server.cli:app" text-generation-server = "text_generation_server.cli:app"

View File

@ -5,7 +5,9 @@ from torch import nn
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "cuda": if SYSTEM == "cuda":
import rotary_emb from text_generation_server.utils.kernels import load_kernel
rotary = load_kernel(module="rotary", repo_id="kernels-community/rotary")
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
import vllm._custom_ops as ops import vllm._custom_ops as ops
elif SYSTEM == "ipex": elif SYSTEM == "ipex":
@ -54,12 +56,12 @@ class PositionRotaryEmbedding(nn.Module):
q1 = query[..., :rotary_dim] q1 = query[..., :rotary_dim]
q2 = query[..., rotary_dim : 2 * rotary_dim] q2 = query[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) rotary.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., :rotary_dim] k1 = key[..., :rotary_dim]
k2 = key[..., rotary_dim : 2 * rotary_dim] k2 = key[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) rotary.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

View File

@ -63,17 +63,19 @@ class CohereRotary(PositionRotaryEmbedding):
): ):
# Such controlflows may add some overhead. # Such controlflows may add some overhead.
if SYSTEM == "cuda": if SYSTEM == "cuda":
import rotary_emb from text_generation_server.utils.kernels import load_kernel
rotary = load_kernel(module="rotary", repo_id="kernels-community/rotary")
q1 = query[..., ::2] q1 = query[..., ::2]
q2 = query[..., 1::2] q2 = query[..., 1::2]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) rotary.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., ::2] k1 = key[..., ::2]
k2 = key[..., 1::2] k2 = key[..., 1::2]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) rotary.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
import vllm._custom_ops as ops import vllm._custom_ops as ops

View File

@ -79,17 +79,19 @@ class GPTJRotary(PositionRotaryEmbedding):
): ):
# Such controlflows may add some overhead. # Such controlflows may add some overhead.
if SYSTEM == "cuda": if SYSTEM == "cuda":
import rotary_emb from text_generation_server.utils.kernels import load_kernel
rotary = load_kernel(module="rotary", repo_id="kernels-community/rotary")
q1 = query[..., ::2] q1 = query[..., ::2]
q2 = query[..., 1::2] q2 = query[..., 1::2]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) rotary.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., ::2] k1 = key[..., ::2]
k2 = key[..., 1::2] k2 = key[..., 1::2]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) rotary.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
import vllm._custom_ops as ops import vllm._custom_ops as ops