diff --git a/flake.lock b/flake.lock index 6e1875100..7c7dff78a 100644 --- a/flake.lock +++ b/flake.lock @@ -978,15 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1743931123, - "narHash": "sha256-MDQrbJkweLYsMYh44Gx+c1gAZOCR1fmZF1lkavAHDto=", + "lastModified": 1744365621, + "narHash": "sha256-HgO+5SmiLABiRMSe0p9XYeF0xP50TL0jMP9ueEhJWlU=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "1ad3feaadfdedca90278ee7676bca15019519189", + "rev": "e6eb0da8a53e486bf50ba30ce8bd7c770d319076", "type": "github" }, "original": { "owner": "huggingface", + "ref": "flashinfer-0.2.5", "repo": "text-generation-inference-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index c733cdd24..c4a999fa5 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix/flashinfer-0.2.5"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/server/Makefile-flashinfer b/server/Makefile-flashinfer index f311a6569..71ac3edaa 100644 --- a/server/Makefile-flashinfer +++ b/server/Makefile-flashinfer @@ -3,4 +3,4 @@ install-flashinfer: # `pip install flashinfer` cannot resolve it. uv pip install fsspec sympy==1.13.1 numpy uv pip install -U setuptools - TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;8.9;9.0+PTX" FLASHINFER_ENABLE_AOT=1 pip install git+https://github.com/flashinfer-ai/flashinfer.git@v0.2.0.post2#egg=flashinfer-python --no-build-isolation + TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;8.9;9.0+PTX" FLASHINFER_ENABLE_AOT=1 pip install git+https://github.com/flashinfer-ai/flashinfer.git@v0.2.5#egg=flashinfer-python --no-build-isolation diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index f78475d51..a4232a93c 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -92,7 +92,7 @@ def use_prefill_with_paged_kv_state( custom_mask=custom_mask, num_qo_heads=num_heads, num_kv_heads=num_kv_heads, - head_dim=head_size, + head_dim_qk=head_size, kv_data_type=kv_dtype, q_data_type=q_dtype, page_size=page_size,