From 8c0f9312f3493019ffb8f38e1c75bf9661ee5d51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 26 Sep 2024 13:34:01 +0000 Subject: [PATCH] nix: add flash-attn-v1 to the server environment --- flake.lock | 14 ++++++------ flake.nix | 45 ++++-------------------------------- nix/impure-shell.nix | 54 ++++++++++++++++++++++++++++++++++++++++++++ nix/server.nix | 1 + 4 files changed, 66 insertions(+), 48 deletions(-) create mode 100644 nix/impure-shell.nix diff --git a/flake.lock b/flake.lock index d811be5e..14e23b77 100644 --- a/flake.lock +++ b/flake.lock @@ -978,16 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1726743157, - "narHash": "sha256-7OczwJsA47o+aUftMwkoh8R31DlNSl2FgRjqE8zAggk=", - "owner": "danieldk", - "repo": "tgi-nix", - "rev": "bcc9fd01cf81bc42cebb999a736a377adfa8942f", + "lastModified": 1727353315, + "narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=", + "owner": "huggingface", + "repo": "text-generation-inference-nix", + "rev": "1d42c4125ebafb87707118168995675cc5050b9d", "type": "github" }, "original": { - "owner": "danieldk", - "repo": "tgi-nix", + "owner": "huggingface", + "repo": "text-generation-inference-nix", "type": "github" } } diff --git a/flake.nix b/flake.nix index 260b2554..1b396453 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:danieldk/tgi-nix"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { @@ -132,49 +132,12 @@ pre-commit ruff ]); - }; - impure = mkShell { - buildInputs = - [ - openssl.dev - pkg-config - (rust-bin.stable.latest.default.override { - extensions = [ - "rust-analyzer" - "rust-src" - ]; - }) - protobuf - ] - ++ (with python3.pkgs; [ - venvShellHook - docker - pip - ipdb - click - pyright - pytest - pytest-asyncio - redocly - ruff - syrupy - ]); + impure = callPackage ./nix/impure-shell.nix { inherit server; }; - inputsFrom = [ server ]; - - venvDir = "./.venv"; - - postVenvCreation = '' - unset SOURCE_DATE_EPOCH - ( cd server ; python -m pip install --no-dependencies -e . ) - ( cd clients/python ; python -m pip install --no-dependencies -e . ) - ''; - postShellHook = '' - unset SOURCE_DATE_EPOCH - export PATH=$PATH:~/.cargo/bin - ''; + impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix { + server = server.override { flash-attn = python3.pkgs.flash-attn-v1; }; }; }; diff --git a/nix/impure-shell.nix b/nix/impure-shell.nix new file mode 100644 index 00000000..a4dad4ba --- /dev/null +++ b/nix/impure-shell.nix @@ -0,0 +1,54 @@ +{ + mkShell, + openssl, + pkg-config, + protobuf, + python3, + pyright, + redocly, + ruff, + rust-bin, + server, +}: + +mkShell { + buildInputs = + [ + openssl.dev + pkg-config + (rust-bin.stable.latest.default.override { + extensions = [ + "rust-analyzer" + "rust-src" + ]; + }) + protobuf + pyright + redocly + ruff + ] + ++ (with python3.pkgs; [ + venvShellHook + docker + pip + ipdb + click + pytest + pytest-asyncio + syrupy + ]); + + inputsFrom = [ server ]; + + venvDir = "./.venv"; + + postVenvCreation = '' + unset SOURCE_DATE_EPOCH + ( cd server ; python -m pip install --no-dependencies -e . ) + ( cd clients/python ; python -m pip install --no-dependencies -e . ) + ''; + postShellHook = '' + unset SOURCE_DATE_EPOCH + export PATH=$PATH:~/.cargo/bin + ''; +} diff --git a/nix/server.nix b/nix/server.nix index 5921da7f..7406d563 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -13,6 +13,7 @@ flash-attn, flash-attn-layer-norm, flash-attn-rotary, + flash-attn-v1, grpc-interceptor, grpcio-reflection, grpcio-status,