From 9c9ef37c56935a4fb98138236b42c25ffc18be4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 22 Oct 2024 11:02:55 +0200 Subject: [PATCH] Add `impureWithCuda` dev shell (#2677) * Add `impureWithCuda` dev shell This shell is handy when developing some kernels jointly with TGI - it adds nvcc and a bunch of commonly-used CUDA libraries to the environment. We don't add this to the normal impure shell to keep the development environment as clean as possible (avoid accidental dependencies, etc.). * Add cuDNN --- flake.nix | 5 +++++ nix/impure-shell.nix | 45 +++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/flake.nix b/flake.nix index edef442f..f26a983e 100644 --- a/flake.nix +++ b/flake.nix @@ -137,6 +137,11 @@ impure = callPackage ./nix/impure-shell.nix { inherit server; }; + impureWithCuda = callPackage ./nix/impure-shell.nix { + inherit server; + withCuda = true; + }; + impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix { server = server.override { flash-attn = python3.pkgs.flash-attn-v1; }; }; diff --git a/nix/impure-shell.nix b/nix/impure-shell.nix index abed544a..9df4b111 100644 --- a/nix/impure-shell.nix +++ b/nix/impure-shell.nix @@ -1,7 +1,12 @@ { + lib, mkShell, black, + cmake, isort, + ninja, + which, + cudaPackages, openssl, pkg-config, protobuf, @@ -11,14 +16,17 @@ ruff, rust-bin, server, + + # Enable dependencies for building CUDA packages. Useful for e.g. + # developing marlin/moe-kernels in-place. + withCuda ? false, }: mkShell { - buildInputs = + nativeBuildInputs = [ black isort - openssl.dev pkg-config (rust-bin.stable.latest.default.override { extensions = [ @@ -31,6 +39,19 @@ mkShell { redocly ruff ] + ++ (lib.optionals withCuda [ + cmake + ninja + which + + # For most Torch-based extensions, setting CUDA_HOME is enough, but + # some custom CMake builds (e.g. vLLM) also need to have nvcc in PATH. + cudaPackages.cuda_nvcc + ]); + buildInputs = + [ + openssl.dev + ] ++ (with python3.pkgs; [ venvShellHook docker @@ -40,10 +61,27 @@ mkShell { pytest pytest-asyncio syrupy - ]); + ]) + ++ (lib.optionals withCuda ( + with cudaPackages; + [ + cuda_cccl + cuda_cudart + cuda_nvtx + cudnn + libcublas + libcusolver + libcusparse + ] + )); inputsFrom = [ server ]; + env = lib.optionalAttrs withCuda { + CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}"; + TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" python3.pkgs.torch.cudaCapabilities; + }; + venvDir = "./.venv"; postVenvCreation = '' @@ -51,6 +89,7 @@ mkShell { ( cd server ; python -m pip install --no-dependencies -e . ) ( cd clients/python ; python -m pip install --no-dependencies -e . ) ''; + postShellHook = '' unset SOURCE_DATE_EPOCH export PATH=$PATH:~/.cargo/bin