text-generation-inference/flake.nix
Daniël de Kok 3c9df21ff8
Add support for compressed-tensors w8a8 int checkpoints (#2745)
* Add support for compressed-tensors w8a8 int checkpoints

This change adds a loader for w8a8 int checkpoints. One large benefit of
int8 support is that the corresponding cutlass matmul kernels also work on
compute capability 7.5.

Evaluation on neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8:

|     Tasks     |Version|     Filter     |n-shot|        Metric         |   |Value |   |Stderr|
|---------------|------:|----------------|-----:|-----------------------|---|-----:|---|------|
|gsm8k_cot_llama|      3|flexible-extract|     8|exact_match            |↑  |0.8431|±  |0.0100|
|               |       |strict-match    |     8|exact_match            |↑  |0.8393|±  |0.0101|
|ifeval         |      4|none            |     0|inst_level_loose_acc   |↑  |0.8597|±  |   N/A|
|               |       |none            |     0|inst_level_strict_acc  |↑  |0.8201|±  |   N/A|
|               |       |none            |     0|prompt_level_loose_acc |↑  |0.7967|±  |0.0173|
|               |       |none            |     0|prompt_level_strict_acc|↑  |0.7468|±  |0.0187|

These results are in the same ballpark as vLLM.

As usual, lots of thanks to Neural Magic/vLLM for the kernels.

* Always use dynamic input quantization for w8a8 int

It's far less flaky and gives better output.

* Use marlin-kernels 0.3.5

* Fix a typo

Co-authored-by: drbh <david.richard.holtz@gmail.com>

* Small fixes

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
2024-11-18 17:20:31 +01:00

174 lines
4.8 KiB
Nix

{
# Flake inputs. Every input that carries its own nixpkgs is redirected to the
# nixpkgs pinned by tgi-nix, so a single package set is shared.
inputs = {
  tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.5";
  nixpkgs.follows = "tgi-nix/nixpkgs";

  crate2nix.url = "github:nix-community/crate2nix";
  crate2nix.inputs.nixpkgs.follows = "tgi-nix/nixpkgs";

  rust-overlay.url = "github:oxalica/rust-overlay";
  rust-overlay.inputs.nixpkgs.follows = "tgi-nix/nixpkgs";

  flake-utils.url = "github:numtide/flake-utils";
  nix-filter.url = "github:numtide/nix-filter";
};
outputs =
{
self,
crate2nix,
nix-filter,
nixpkgs,
flake-utils,
rust-overlay,
tgi-nix,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
# Cargo workspace description produced by crate2nix for the current system,
# generated from the flake's own source tree with all features enabled.
cargoNix = crate2nix.tools."${system}".appliedCargoNix {
  src = ./.;
  name = "tgi";
  additionalCargoNixArgs = [
    "--all-features"
  ];
};
# nixpkgs instantiated for `system`, using the config exported by tgi-nix and
# the overlays below (applied in the order listed).
pkgs = import nixpkgs {
  overlays = [
    rust-overlay.overlays.default
    tgi-nix.overlays.default
    (import ./nix/overlay.nix)
  ];
  config = tgi-nix.lib.config;
  inherit system;
};
# Per-crate build overrides shared by every workspace member built below.
crateOverrides = import ./nix/crate-overrides.nix {
  inherit nix-filter pkgs;
};
# Build one workspace member from cargoNix with the shared crate overrides.
buildWorkspaceMember =
  member:
  cargoNix.workspaceMembers.${member}.build.override {
    inherit crateOverrides;
  };
benchmark = buildWorkspaceMember "text-generation-benchmark";
launcher = buildWorkspaceMember "text-generation-launcher";
# Router executable wrapped in a shell script that sets PYTHONPATH to the
# Python packages listed below before starting the real binary.
router =
  let
    # The router binary as built from the Cargo workspace, before wrapping.
    routerUnwrapped = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
      inherit crateOverrides;
    };
    # Python search path covering the listed packages.
    packagePath =
      with pkgs.python3.pkgs;
      makePythonPath [
        protobuf
        sentencepiece
        torch
        transformers
      ];
  in
  pkgs.writeShellApplication {
    name = "text-generation-router";
    # `exec` replaces the wrapper shell with the router process, so signals
    # sent to the wrapper's PID reach the router itself and no intermediate
    # shell process is left running.
    text = ''
      export PYTHONPATH="${packagePath}"
      exec ${routerUnwrapped}/bin/text-generation-router "$@"
    '';
  };
# Python components, instantiated through the Python package set's callPackage.
pythonCallPackage = pkgs.python3.pkgs.callPackage;
server = pythonCallPackage ./nix/server.nix { inherit nix-filter; };
client = pythonCallPackage ./nix/client.nix { };
in
{
# Rust checks: the build phase runs `cargo check`; the check phase runs the
# formatter check, the test suite and clippy. The output is an empty marker
# file, since only success or failure of the phases matters.
checks.rust = pkgs.rustPlatform.buildRustPackage {
  name = "rust-checks";
  src = ./.;
  cargoLock.lockFile = ./Cargo.lock;

  nativeBuildInputs = [
    pkgs.clippy
    pkgs.pkg-config
    pkgs.protobuf
    pkgs.python3
    pkgs.rustfmt
  ];
  buildInputs = [ pkgs.openssl.dev ];

  buildPhase = ''
    cargo check
  '';
  checkPhase = ''
    cargo fmt -- --check
    cargo test -j $NIX_BUILD_CORES
    cargo clippy
  '';
  installPhase = "touch $out";
};
# Formatter exposed as this flake's `formatter` output (the one `nix fmt` runs).
formatter = pkgs.nixfmt-rfc-style;
# Development shells. `default` is an alias of `pure`.
devShells =
  let
    inherit (pkgs) callPackage mkShell;

    # The TGI components built above; shared by the `pure` and `test` shells.
    tgiComponents = [
      benchmark
      launcher
      router
      server
    ];

    # Shell containing only the Nix-built TGI components.
    pure = mkShell { buildInputs = tgiComponents; };
  in
  {
    inherit pure;
    default = pure;

    # Shell for running tests: the TGI components plus the client, Rust
    # tooling and Python test tools.
    test = mkShell {
      buildInputs =
        tgiComponents
        ++ [
          client
          pkgs.openssl.dev
          pkgs.pkg-config
          pkgs.cargo
          pkgs.rustfmt
          pkgs.clippy
        ]
        # Names are looked up in the Python package set first, then in pkgs.
        ++ (
          with pkgs;
          with python3.pkgs;
          [
            docker
            pytest
            pytest-asyncio
            syrupy
            pre-commit
            ruff
          ]
        );
    };

    # Shells defined in ./nix/impure-shell.nix, parameterised by the server.
    impure = callPackage ./nix/impure-shell.nix { inherit server; };
    impureWithCuda = callPackage ./nix/impure-shell.nix {
      withCuda = true;
      inherit server;
    };
    impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
      server = server.override { flash-attn = pkgs.python3.pkgs.flash-attn-v1; };
    };
  };
# Packages exported by the flake.
packages = rec {
  # Launcher entry point; the server and router are made available on PATH
  # through runtimeInputs.
  default = pkgs.writeShellApplication {
    name = "text-generation-inference";
    runtimeInputs = [
      server
      router
    ];
    # `exec` replaces the wrapper shell with the launcher process, so signals
    # sent to the wrapper's PID reach the launcher itself and no intermediate
    # shell process is left running.
    text = ''
      exec ${launcher}/bin/text-generation-launcher "$@"
    '';
  };
  # Docker images built by ./nix/docker.nix around the default package; the
  # second variant passes `stream = true`.
  dockerImage = pkgs.callPackage ./nix/docker.nix {
    text-generation-inference = default;
  };
  dockerImageStreamed = pkgs.callPackage ./nix/docker.nix {
    text-generation-inference = default;
    stream = true;
  };
};
}
);
}