diff --git a/README.md b/README.md index 869cc6682..0ba496753 100644 --- a/README.md +++ b/README.md @@ -252,6 +252,8 @@ You can also quantize the weights with bitsandbytes to reduce the VRAM requireme make run-falcon-7b-instruct-quantize ``` +4-bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command-line argument to `text-generation-launcher`. + ## Develop ```shell diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 146d83d6b..757627120 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -22,6 +22,8 @@ mod env_runtime; #[derive(Clone, Copy, Debug, ValueEnum)] enum Quantization { Bitsandbytes, + BitsandbytesNF4, + BitsandbytesFP4, Gptq, } @@ -32,6 +34,12 @@ impl std::fmt::Display for Quantization { Quantization::Bitsandbytes => { write!(f, "bitsandbytes") } + Quantization::BitsandbytesNF4 => { + write!(f, "bitsandbytes-nf4") + } + Quantization::BitsandbytesFP4 => { + write!(f, "bitsandbytes-fp4") + } Quantization::Gptq => { write!(f, "gptq") } @@ -116,7 +124,8 @@ struct Args { num_shard: Option, /// Whether you want the model to be quantized. This will use `bitsandbytes` for - /// quantization on the fly, or `gptq`. + /// quantization on the fly, or `gptq`. 4bit quantization is available through + /// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options. #[clap(long, env, value_enum)] quantize: Option, diff --git a/server/poetry.lock b/server/poetry.lock index 70f09f325..f447914a9 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -192,13 +192,13 @@ files = [ [[package]] name = "bitsandbytes" -version = "0.38.1" -description = "8-bit optimizers and matrix multiplication routines." +version = "0.40.2" +description = "k-bit optimizers and matrix multiplication routines."
optional = true python-versions = "*" files = [ - {file = "bitsandbytes-0.38.1-py3-none-any.whl", hash = "sha256:5f532e7b1353eb7049ae831da2eb62ed8a1e0444116bd51b9e088a6e0bc7a34a"}, - {file = "bitsandbytes-0.38.1.tar.gz", hash = "sha256:ba95a806b5065ea3263558e188f07eacb32ad691842932fb0d36a879883167ce"}, + {file = "bitsandbytes-0.40.2-py3-none-any.whl", hash = "sha256:f0ae26f40c9230c9add9e7c70a10a5ced36fb6deff39906aec1ce4fd25e6ddc0"}, + {file = "bitsandbytes-0.40.2.tar.gz", hash = "sha256:808ac966272c63bccb2be6d77365275a4c28f1fa348d33656e670de3cab40fc4"}, ] [[package]] @@ -1751,6 +1751,42 @@ tensorflow = ["tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "numpy (>=1.21.6)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)"] torch = ["torch (>=1.10)"] +[[package]] +name = "scipy" +version = "1.11.1" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = "<3.13,>=3.9" +files = [ + {file = "scipy-1.11.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:aec8c62fbe52914f9cf28d846cf0401dd80ab80788bbab909434eb336ed07c04"}, + {file = "scipy-1.11.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:3b9963798df1d8a52db41a6fc0e6fa65b1c60e85d73da27ae8bb754de4792481"}, + {file = "scipy-1.11.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e8eb42db36526b130dfbc417609498a6192381abc1975b91e3eb238e0b41c1a"}, + {file = "scipy-1.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:366a6a937110d80dca4f63b3f5b00cc89d36f678b2d124a01067b154e692bab1"}, + {file = "scipy-1.11.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:08d957ca82d3535b3b9ba6c8ff355d78fe975271874e2af267cb5add5bd78625"}, + {file = "scipy-1.11.1-cp310-cp310-win_amd64.whl", hash = "sha256:e866514bc2d660608447b6ba95c8900d591f2865c07cca0aa4f7ff3c4ca70f30"}, + {file = "scipy-1.11.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:ba94eeef3c9caa4cea7b402a35bb02a5714ee1ee77eb98aca1eed4543beb0f4c"}, + {file = "scipy-1.11.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:512fdc18c65f76dadaca139348e525646d440220d8d05f6d21965b8d4466bccd"}, + {file = "scipy-1.11.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cce154372f0ebe88556ed06d7b196e9c2e0c13080ecb58d0f35062dc7cc28b47"}, + {file = "scipy-1.11.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4bb943010203465ac81efa392e4645265077b4d9e99b66cf3ed33ae12254173"}, + {file = "scipy-1.11.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:249cfa465c379c9bb2c20123001e151ff5e29b351cbb7f9c91587260602c58d0"}, + {file = "scipy-1.11.1-cp311-cp311-win_amd64.whl", hash = "sha256:ffb28e3fa31b9c376d0fb1f74c1f13911c8c154a760312fbee87a21eb21efe31"}, + {file = "scipy-1.11.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:39154437654260a52871dfde852adf1b93b1d1bc5dc0ffa70068f16ec0be2624"}, + {file = "scipy-1.11.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b588311875c58d1acd4ef17c983b9f1ab5391755a47c3d70b6bd503a45bfaf71"}, + {file = "scipy-1.11.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d51565560565a0307ed06fa0ec4c6f21ff094947d4844d6068ed04400c72d0c3"}, + {file = "scipy-1.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b41a0f322b4eb51b078cb3441e950ad661ede490c3aca66edef66f4b37ab1877"}, + {file = "scipy-1.11.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:396fae3f8c12ad14c5f3eb40499fd06a6fef8393a6baa352a652ecd51e74e029"}, + {file = "scipy-1.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:be8c962a821957fdde8c4044efdab7a140c13294997a407eaee777acf63cbf0c"}, + {file = "scipy-1.11.1.tar.gz", hash = "sha256:fb5b492fa035334fd249f0973cc79ecad8b09c604b42a127a677b45a9a3d4289"}, +] + +[package.dependencies] +numpy = ">=1.21.6,<1.28.0" + +[package.extras] +dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", 
"pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] +doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] +test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + [[package]] name = "sentencepiece" version = "0.1.99" @@ -2425,5 +2461,5 @@ quantize = ["accelerate", "datasets", "texttable"] [metadata] lock-version = "2.0" -python-versions = "^3.9" -content-hash = "93fd0873b3e16c10b67a84216a84f5eb2f5067cb3ff8cb912446cc6a7fa9c030" +python-versions = ">=3.9,<3.13" +content-hash = "2abb80833b678452cfc73464fc5b2e48d74b2672bd987240041a33c724a74000" diff --git a/server/pyproject.toml b/server/pyproject.toml index 56afda0a7..88b93f4e5 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -8,7 +8,7 @@ authors = ["Olivier Dehaene "] text-generation-server = 'text_generation_server.cli:app' [tool.poetry.dependencies] -python = "^3.9" +python = ">=3.9,<3.13" protobuf = "^4.21.7" grpcio = "^1.51.1" grpcio-status = "^1.51.1" @@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.0" typer = "^0.6.1" accelerate = { version = "^0.19.0", optional = true } -bitsandbytes = { version = "^0.38.1", optional = true } +bitsandbytes = { version = "^0.40.0", optional = true } safetensors = "0.3.1" loguru = "^0.6.0" opentelemetry-api = "^1.15.0" @@ -32,6 +32,7 @@ texttable = { version = "^1.6.7", optional = true } datasets = { version = "^2.14.0", optional = true } peft = "^0.4.0" torch = {version = "^2.0.1+cu118", source = "pytorch-gpu-src"} +scipy = "^1.11.1" [tool.poetry.extras] accelerate = ["accelerate"] diff --git a/server/requirements.txt b/server/requirements.txt index f9414b25e..beab4bf6c 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,58 +1,59 @@ --extra-index-url https://download.pytorch.org/whl/cu118 
-accelerate==0.19.0 ; python_version >= "3.9" and python_version < "4.0" -backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0" -certifi==2023.7.22 ; python_version >= "3.9" and python_version < "4.0" -charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "4.0" -click==8.1.6 ; python_version >= "3.9" and python_version < "4.0" -cmake==3.27.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "4.0" -colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows") -deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0" -einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0" -filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0" -fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0" -googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "4.0" -grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0" -grpcio-reflection==1.56.2 ; python_version >= "3.9" and python_version < "4.0" -grpcio-status==1.56.2 ; python_version >= "3.9" and python_version < "4.0" -grpcio==1.56.2 ; python_version >= "3.9" and python_version < "4.0" -hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0" -huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0" -idna==3.4 ; python_version >= "3.9" and python_version < "4.0" -jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0" -lit==16.0.6 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "4.0" -loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0" -markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0" -mpmath==1.3.0 ; python_version >= "3.9" and python_version < "4.0" -networkx==3.1 ; python_version >= "3.9" and python_version < "4.0" 
-numpy==1.25.2 ; python_version >= "3.9" and python_version < "4.0" -opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0" -opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0" -opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0" -opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "4.0" -opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "4.0" -opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "4.0" -opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0" -opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0" -opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0" -packaging==23.1 ; python_version >= "3.9" and python_version < "4.0" -peft==0.4.0 ; python_version >= "3.9" and python_version < "4.0" -protobuf==4.23.4 ; python_version >= "3.9" and python_version < "4.0" -psutil==5.9.5 ; python_version >= "3.9" and python_version < "4.0" -pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0" -regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0" -requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0" -safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0" -sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0" -setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0" -sympy==1.12 ; python_version >= "3.9" and python_version < "4.0" -tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0" -torch==2.0.1+cu118 ; python_version >= "3.9" and python_version < "4.0" -tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0" -transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0" -triton==2.0.0 ; platform_system == "Linux" and 
platform_machine == "x86_64" and python_version >= "3.9" and python_version < "4.0" -typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0" -typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0" -urllib3==2.0.4 ; python_version >= "3.9" and python_version < "4.0" -win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" -wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0" +accelerate==0.19.0 ; python_version >= "3.9" and python_version < "3.13" +backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" +certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13" +charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" +click==8.1.6 ; python_version >= "3.9" and python_version < "3.13" +cmake==3.27.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") +deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" +einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +filelock==3.12.2 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13" +grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "3.13" +grpcio-reflection==1.56.2 ; python_version >= "3.9" and python_version < "3.13" +grpcio-status==1.56.2 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.56.2 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "3.13" +idna==3.4 ; python_version >= "3.9" and python_version < 
"3.13" +jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13" +lit==16.0.6 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" +loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13" +mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" +networkx==3.1 ; python_version >= "3.9" and python_version < "3.13" +numpy==1.25.2 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" +opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" +packaging==23.1 ; python_version >= "3.9" and python_version < "3.13" +peft==0.4.0 ; python_version >= "3.9" and python_version < "3.13" +protobuf==4.23.4 ; python_version >= "3.9" and python_version < "3.13" +psutil==5.9.5 ; python_version >= "3.9" and python_version < "3.13" +pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" +regex==2023.6.3 ; python_version >= "3.9" and python_version < "3.13" +requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" +safetensors==0.3.1 ; python_version >= "3.9" and python_version < "3.13" +scipy==1.11.1 ; python_version >= "3.9" and 
python_version < "3.13" +sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" +setuptools==68.0.0 ; python_version >= "3.9" and python_version < "3.13" +sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" +tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13" +torch==2.0.1+cu118 ; python_version >= "3.9" and python_version < "3.13" +tqdm==4.65.0 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.29.2 ; python_version >= "3.9" and python_version < "3.13" +triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13" +typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" +win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" +wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index eba807bc6..459ba8c4f 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -14,6 +14,8 @@ app = typer.Typer() class Quantization(str, Enum): bitsandbytes = "bitsandbytes" + bitsandbytes_nf4 = "bitsandbytes-nf4" + bitsandbytes_fp4 = "bitsandbytes-fp4" gptq = "gptq" diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index e9260eede..71efcab74 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -255,7 +255,10 @@ def get_model( raise ValueError( "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) - + elif (quantize == "bitsandbytes-fp4") or (quantize == 
"bitsandbytes-nf4"): + raise ValueError( + "4bit quantization is not supported for AutoModel" + ) if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( model_id, diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index d7b4c0cc8..97257f957 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -9,7 +9,7 @@ from typing import List HAS_BITS_AND_BYTES = True try: import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params + from bitsandbytes.nn import Int8Params, Params4bit except ImportError: HAS_BITS_AND_BYTES = False @@ -140,6 +140,39 @@ class Linear8bitLt(nn.Module): return out +class Linear4bit(nn.Module): + def __init__(self, weight, bias, quant_type): + super().__init__() + self.weight = Params4bit( + weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type + ) + self.compute_dtype = None + self.weight.cuda(weight.device) + self.bias = bias + + def forward(self, x: torch.Tensor): + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if getattr(self.weight, "quant_state", None) is None: + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." 
+ ) + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + bias = None if self.bias is None else self.bias.to(self.compute_dtype) + out = bnb.matmul_4bit( + x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state + ) + + out = out.to(inp_dtype) + + return out + + def get_linear(weight, bias, quantize): if quantize is None: linear = FastLinear(weight, bias) @@ -152,6 +185,18 @@ def get_linear(weight, bias, quantize): ) if bias is not None: linear.bias = nn.Parameter(bias) + elif quantize == "bitsandbytes-fp4": + linear = Linear4bit( + weight, + bias, + quant_type="fp4", + ) + elif quantize == "bitsandbytes-nf4": + linear = Linear4bit( + weight, + bias, + quant_type="nf4", + ) elif quantize == "gptq": try: qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight