Add impureWithCuda dev shell

This shell is handy when developing some kernels jointly with TGI - it
adds nvcc and a bunch of commonly-used CUDA libraries to the environment.

We don't add this to the normal impure shell to keep the development
environment as clean as possible (avoid accidental dependencies, etc.).
This commit is contained in:
Daniël de Kok 2024-10-22 07:55:07 +00:00
parent 058d3061f7
commit 7aa90c58dc
2 changed files with 46 additions and 3 deletions

View File

@ -137,6 +137,11 @@
impure = callPackage ./nix/impure-shell.nix { inherit server; };
impureWithCuda = callPackage ./nix/impure-shell.nix {
inherit server;
withCuda = true;
};
impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
};

View File

@ -1,7 +1,12 @@
{
lib,
mkShell,
black,
cmake,
isort,
ninja,
which,
cudaPackages,
openssl,
pkg-config,
protobuf,
@ -11,14 +16,17 @@
ruff,
rust-bin,
server,
# Enable dependencies for building CUDA packages. Useful for e.g.
# developing marlin/moe-kernels in-place.
withCuda ? false,
}:
mkShell {
buildInputs =
nativeBuildInputs =
[
black
isort
openssl.dev
pkg-config
(rust-bin.stable.latest.default.override {
extensions = [
@ -31,6 +39,19 @@ mkShell {
redocly
ruff
]
++ (lib.optionals withCuda [
cmake
ninja
which
# For most Torch-based extensions, setting CUDA_HOME is enough, but
# some custom CMake builds (e.g. vLLM) also need to have nvcc in PATH.
cudaPackages.cuda_nvcc
]);
buildInputs =
[
openssl.dev
]
++ (with python3.pkgs; [
venvShellHook
docker
@ -40,10 +61,26 @@ mkShell {
pytest
pytest-asyncio
syrupy
]);
])
++ (lib.optionals withCuda (
with cudaPackages;
[
cuda_cccl
cuda_cudart
cuda_nvtx
libcublas
libcusolver
libcusparse
]
));
inputsFrom = [ server ];
env = lib.optionalAttrs withCuda {
CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" python3.pkgs.torch.cudaCapabilities;
};
venvDir = "./.venv";
postVenvCreation = ''
@ -51,6 +88,7 @@ mkShell {
( cd server ; python -m pip install --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . )
'';
postShellHook = ''
unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin