diff --git a/flake.lock b/flake.lock index 512625de..b6cf7e53 100644 --- a/flake.lock +++ b/flake.lock @@ -2,16 +2,10 @@ "nodes": { "cachix": { "inputs": { - "devenv": [ - "crate2nix" - ], - "flake-compat": [ - "crate2nix" - ], + "devenv": ["crate2nix"], + "flake-compat": ["crate2nix"], "nixpkgs": "nixpkgs", - "pre-commit-hooks": [ - "crate2nix" - ] + "pre-commit-hooks": ["crate2nix"] }, "locked": { "lastModified": 1709700175, @@ -30,19 +24,10 @@ }, "cachix_2": { "inputs": { - "devenv": [ - "crate2nix", - "crate2nix_stable" - ], - "flake-compat": [ - "crate2nix", - "crate2nix_stable" - ], + "devenv": ["crate2nix", "crate2nix_stable"], + "flake-compat": ["crate2nix", "crate2nix_stable"], "nixpkgs": "nixpkgs_2", - "pre-commit-hooks": [ - "crate2nix", - "crate2nix_stable" - ] + "pre-commit-hooks": ["crate2nix", "crate2nix_stable"] }, "locked": { "lastModified": 1716549461, @@ -61,16 +46,8 @@ }, "cachix_3": { "inputs": { - "devenv": [ - "crate2nix", - "crate2nix_stable", - "crate2nix_stable" - ], - "flake-compat": [ - "crate2nix", - "crate2nix_stable", - "crate2nix_stable" - ], + "devenv": ["crate2nix", "crate2nix_stable", "crate2nix_stable"], + "flake-compat": ["crate2nix", "crate2nix_stable", "crate2nix_stable"], "nixpkgs": "nixpkgs_3", "pre-commit-hooks": [ "crate2nix", @@ -101,10 +78,7 @@ "flake-compat": "flake-compat_3", "flake-parts": "flake-parts_3", "nix-test-runner": "nix-test-runner_3", - "nixpkgs": [ - "tgi-nix", - "nixpkgs" - ], + "nixpkgs": ["tgi-nix", "nixpkgs"], "pre-commit-hooks": "pre-commit-hooks_3" }, "locked": { @@ -219,11 +193,7 @@ "devshell_2": { "inputs": { "flake-utils": "flake-utils_3", - "nixpkgs": [ - "crate2nix", - "crate2nix_stable", - "nixpkgs" - ] + "nixpkgs": ["crate2nix", "crate2nix_stable", "nixpkgs"] }, "locked": { "lastModified": 1717408969, @@ -242,10 +212,7 @@ "devshell_3": { "inputs": { "flake-utils": "flake-utils_4", - "nixpkgs": [ - "crate2nix", - "nixpkgs" - ] + "nixpkgs": ["crate2nix", "nixpkgs"] }, "locked": { "lastModified": 1711099426, @@ -343,11 +310,7 @@ }, "flake-parts_2": { "inputs": { - "nixpkgs-lib": [ - "crate2nix", - "crate2nix_stable", - "nixpkgs" - ] + "nixpkgs-lib": ["crate2nix", "crate2nix_stable", "nixpkgs"] }, "locked": { "lastModified": 1719745305, @@ -365,10 +328,7 @@ }, "flake-parts_3": { "inputs": { - "nixpkgs-lib": [ - "crate2nix", - "nixpkgs" - ] + "nixpkgs-lib": ["crate2nix", "nixpkgs"] }, "locked": { "lastModified": 1712014858, @@ -559,11 +519,7 @@ }, "gitignore_3": { "inputs": { - "nixpkgs": [ - "crate2nix", - "pre-commit-hooks", - "nixpkgs" - ] + "nixpkgs": ["crate2nix", "pre-commit-hooks", "nixpkgs"] }, "locked": { "lastModified": 1709087332, @@ -770,22 +726,10 @@ }, "pre-commit-hooks_2": { "inputs": { - "flake-compat": [ - "crate2nix", - "crate2nix_stable", - "flake-compat" - ], + "flake-compat": ["crate2nix", "crate2nix_stable", "flake-compat"], "gitignore": "gitignore_2", - "nixpkgs": [ - "crate2nix", - "crate2nix_stable", - "nixpkgs" - ], - "nixpkgs-stable": [ - "crate2nix", - "crate2nix_stable", - "nixpkgs" - ] + "nixpkgs": ["crate2nix", "crate2nix_stable", "nixpkgs"], + "nixpkgs-stable": ["crate2nix", "crate2nix_stable", "nixpkgs"] }, "locked": { "lastModified": 1719259945, @@ -803,20 +747,11 @@ }, "pre-commit-hooks_3": { "inputs": { - "flake-compat": [ - "crate2nix", - "flake-compat" - ], + "flake-compat": ["crate2nix", "flake-compat"], "flake-utils": "flake-utils_5", "gitignore": "gitignore_3", - "nixpkgs": [ - "crate2nix", - "nixpkgs" - ], - "nixpkgs-stable": [ - "crate2nix", - "nixpkgs" - ] + "nixpkgs": ["crate2nix", "nixpkgs"], + "nixpkgs-stable": ["crate2nix", "nixpkgs"] }, "locked": { "lastModified": 1712055707, @@ -837,20 +772,14 @@ "crate2nix": "crate2nix", "flake-utils": "flake-utils_6", "nix-filter": "nix-filter", - "nixpkgs": [ - "tgi-nix", - "nixpkgs" - ], + "nixpkgs": ["tgi-nix", "nixpkgs"], "rust-overlay": "rust-overlay", "tgi-nix": "tgi-nix" } }, "rust-overlay": { "inputs": { - "nixpkgs": [ - "tgi-nix", - "nixpkgs" - ] + "nixpkgs": ["tgi-nix", "nixpkgs"] }, "locked": { "lastModified": 1738549608, @@ -978,16 +907,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1740036032, - "narHash": "sha256-nqo3U8uNlFIgrOz8wCfgk08Oi+RzQxxFDPipeVMyM/E=", + "lastModified": 1740049068, + "narHash": "sha256-heYzYOt+TSnRKHIV24s74yEjLkTbBfjNCWHdQEX++eI=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "e9fb0e818a7e9a54cdab8d9c7c0cef5037fe084a", + "rev": "143e8451efa22b120f97e6698508e9a0aed82769", "type": "github" }, "original": { "owner": "huggingface", - "ref": "flashinfer-0.2.0.post2", + "ref": "hub-rotary", "repo": "text-generation-inference-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index 6068dc5f..943bf736 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix/flashinfer-0.2.0.post2"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix/hub-rotary"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/nix/server.nix b/nix/server.nix index 98193cac..0640fe3a 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -11,7 +11,6 @@ flashinfer, flash-attn, flash-attn-layer-norm, - flash-attn-rotary, flash-attn-v1, grpc-interceptor, grpcio-reflection, @@ -36,6 +35,7 @@ pydantic, quantization, quantization-eetq, + rotary, safetensors, tokenizers, torch, @@ -87,7 +87,6 @@ buildPythonPackage { flashinfer flash-attn flash-attn-layer-norm - flash-attn-rotary grpc-interceptor grpcio-reflection grpcio-status @@ -111,6 +110,7 @@ buildPythonPackage { pydantic quantization quantization-eetq + rotary safetensors sentencepiece tokenizers diff --git a/server/hf-kernels.lock b/server/hf-kernels.lock index 5254cb0c..7dc75943 100644 --- a/server/hf-kernels.lock +++ b/server/hf-kernels.lock @@ -6934,5 +6934,155 @@ "blob_id": "005b5a6e3cd5f7bcfd4aa5d7d80d60a5ed9fab88" } ] + }, + { + "repo_id": "kernels-community/rotary", + "sha": "4db658e027ec752840bb3f557ee076413b8db03f", + "files": [ + { + "filename": "build/torch25-cxx11-cu118-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch25-cxx11-cu118-x86_64-linux/rotary/_ops.py", + "blob_id": "4fe035c87ea1300ffedcfce17338167dd946e0e8" + }, + { + "filename": "build/torch25-cxx11-cu118-x86_64-linux/rotary/_rotary_5yzc45v7kk3yu.abi3.so", + "blob_id": "f315754ccb3e8b9dfb4d8954aefaac61b2a4e8bc" + }, + { + "filename": "build/torch25-cxx11-cu121-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch25-cxx11-cu121-x86_64-linux/rotary/_ops.py", + "blob_id": "d45359065a7cdc43e2d512d38fc0bfcd88138835" + }, + { + "filename": "build/torch25-cxx11-cu121-x86_64-linux/rotary/_rotary_tbiepw2a2ep3e.abi3.so", + "blob_id": "9bc986ca760b6a57d05e891c3def1769341a2c29" + }, + { + "filename": "build/torch25-cxx11-cu124-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch25-cxx11-cu124-x86_64-linux/rotary/_ops.py", + "blob_id": "7421978682125139f1169f3f71789e0cb44d3b45" + }, + { + "filename": "build/torch25-cxx11-cu124-x86_64-linux/rotary/_rotary_6w5syhrhmerj6.abi3.so", + "blob_id": "35caf7d755000ca2afac33cc7d4f344112751f9f" + }, + { + "filename": "build/torch25-cxx98-cu118-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch25-cxx98-cu118-x86_64-linux/rotary/_ops.py", + "blob_id": "d6f569b471eae628e738b3504c8a9a18b4973d97" + }, + { + "filename": "build/torch25-cxx98-cu118-x86_64-linux/rotary/_rotary_joujmbgvsytzg.abi3.so", + "blob_id": "30c4aaa83f2549f7363631a51b3341cdf0612f15" + }, + { + "filename": "build/torch25-cxx98-cu121-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch25-cxx98-cu121-x86_64-linux/rotary/_ops.py", + "blob_id": "f1f71f34bf0f3c5dffb7c147b48f6396f8054310" + }, + { + "filename": "build/torch25-cxx98-cu121-x86_64-linux/rotary/_rotary_mi2o7e7sishyw.abi3.so", + "blob_id": "e022e9c1101bdb89d43913979107f7a56717ea6d" + }, + { + "filename": "build/torch25-cxx98-cu124-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch25-cxx98-cu124-x86_64-linux/rotary/_ops.py", + "blob_id": "a46c19bd5adfb85d5b7795b3b9277e416f31d8ce" + }, + { + "filename": "build/torch25-cxx98-cu124-x86_64-linux/rotary/_rotary_rngiohfhfwuge.abi3.so", + "blob_id": "1621cba0150465f67aa931fe3a55e38928b48bcb" + }, + { + "filename": "build/torch26-cxx11-cu118-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch26-cxx11-cu118-x86_64-linux/rotary/_ops.py", + "blob_id": "3296d23431d1ec084e8644ff5d3d203a74d82ea1" + }, + { + "filename": "build/torch26-cxx11-cu118-x86_64-linux/rotary/_rotary_alv7mzltcxxpq.abi3.so", + "blob_id": "2f8b3b93bb7c8fae22c8e08c67771683f549f170" + }, + { + "filename": "build/torch26-cxx11-cu124-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch26-cxx11-cu124-x86_64-linux/rotary/_ops.py", + "blob_id": "0bae33b64c71d6a6ad748be66a410a662bb5b28a" + }, + { + "filename": "build/torch26-cxx11-cu124-x86_64-linux/rotary/_rotary_c4eyapeep6gty.abi3.so", + "blob_id": "3a15076dbd1f9a05f1089cc2cb13c03e5838f2b9" + }, + { + "filename": "build/torch26-cxx11-cu126-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch26-cxx11-cu126-x86_64-linux/rotary/_ops.py", + "blob_id": "5c5d7e4497962e0a3e9531ec6a5fdb18e995e0f8" + }, + { + "filename": "build/torch26-cxx11-cu126-x86_64-linux/rotary/_rotary_lodp6xeztste6.abi3.so", + "blob_id": "efc4a4a0001fba7b0743d1d4a9774d1fc9089ee5" + }, + { + "filename": "build/torch26-cxx98-cu118-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch26-cxx98-cu118-x86_64-linux/rotary/_ops.py", + "blob_id": "3cdda1ceda90e76b30b08cae6ad718aa2c2ec3ef" + }, + { + "filename": "build/torch26-cxx98-cu118-x86_64-linux/rotary/_rotary_z27mls7mz4e7m.abi3.so", + "blob_id": "43b0f08ea035cd2fadb6e119802e7d841a523246" + }, + { + "filename": "build/torch26-cxx98-cu124-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch26-cxx98-cu124-x86_64-linux/rotary/_ops.py", + "blob_id": "914f1bb6f9499d0245c3f47345f3b95582be28b2" + }, + { + "filename": "build/torch26-cxx98-cu124-x86_64-linux/rotary/_rotary_3bktke4p3hz3a.abi3.so", + "blob_id": "29c0bba399e6f348ad8baf93daa5166a6ec6994a" + }, + { + "filename": "build/torch26-cxx98-cu126-x86_64-linux/rotary/__init__.py", + "blob_id": "eba8039e210c8b710c5c663ef4e7930757f271be" + }, + { + "filename": "build/torch26-cxx98-cu126-x86_64-linux/rotary/_ops.py", + "blob_id": "11056cd0ed09530830b614b8207cbb7fa7ef3288" + }, + { + "filename": "build/torch26-cxx98-cu126-x86_64-linux/rotary/_rotary_fvednlzeqgg5s.abi3.so", + "blob_id": "8c7c34d9e603640576ba522dcbed341c0d780a9c" + } + ] } ] diff --git a/server/pyproject.toml b/server/pyproject.toml index 37cb6b1a..bda9df1b 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -43,6 +43,7 @@ build-backend = "setuptools.build_meta" "kernels-community/moe" = ">=0.1.1" "kernels-community/quantization" = ">=0.0.3" "kernels-community/quantization-eetq" = ">=0.0.1" +"kernels-community/rotary" = ">=0.0.1" [project.scripts] text-generation-server = "text_generation_server.cli:app" diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index cd22a1f1..d312a8b8 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -5,7 +5,9 @@ from torch import nn from text_generation_server.utils.import_utils import SYSTEM if SYSTEM == "cuda": - import rotary_emb + from text_generation_server.utils.kernels import load_kernel + + rotary = load_kernel(module="rotary", repo_id="kernels-community/rotary") elif SYSTEM == "rocm": import vllm._custom_ops as ops elif SYSTEM == "ipex": @@ -54,12 +56,12 @@ class PositionRotaryEmbedding(nn.Module): q1 = query[..., :rotary_dim] q2 = query[..., rotary_dim : 2 * rotary_dim] - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary.apply_rotary(q1, q2, cos, sin, q1, q2, False) k1 = key[..., :rotary_dim] k2 = key[..., rotary_dim : 2 * rotary_dim] - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + rotary.apply_rotary(k1, k2, cos, sin, k1, k2, False) elif SYSTEM == "rocm": # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index ece15e94..ce68fc12 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -63,17 +63,19 @@ class CohereRotary(PositionRotaryEmbedding): ): # Such controlflows may add some overhead. if SYSTEM == "cuda": - import rotary_emb + from text_generation_server.utils.kernels import load_kernel + + rotary = load_kernel(module="rotary", repo_id="kernels-community/rotary") q1 = query[..., ::2] q2 = query[..., 1::2] - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary.apply_rotary(q1, q2, cos, sin, q1, q2, False) k1 = key[..., ::2] k2 = key[..., 1::2] - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + rotary.apply_rotary(k1, k2, cos, sin, k1, k2, False) elif SYSTEM == "rocm": import vllm._custom_ops as ops diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 45b90679..79f5ccb6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -79,17 +79,19 @@ class GPTJRotary(PositionRotaryEmbedding): ): # Such controlflows may add some overhead. if SYSTEM == "cuda": - import rotary_emb + from text_generation_server.utils.kernels import load_kernel + + rotary = load_kernel(module="rotary", repo_id="kernels-community/rotary") q1 = query[..., ::2] q2 = query[..., 1::2] - rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary.apply_rotary(q1, q2, cos, sin, q1, q2, False) k1 = key[..., ::2] k2 = key[..., 1::2] - rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + rotary.apply_rotary(k1, k2, cos, sin, k1, k2, False) elif SYSTEM == "rocm": import vllm._custom_ops as ops