Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
feat(server): Add bitsandbytes 4bit quantization (#626)
This PR introduces two bitsandbytes 4bit quantization options, `bitsandbytes-fp4` and `bitsandbytes-nf4`. I wasn't sure how to best integrate these new options with the launcher and server, so I opted to add new `--quantize` values rather than combining the existing `bitsandbytes` flag with new `nf4` and `fp4` dtype flags. The latter would have required extra control flow to allow `--quantize` and `--dtype` to be specified together, but only when `bitsandbytes` was combined with `nf4` or `fp4`. I built and tested using the Docker image on a g4.12xlarge with Falcon-7B and Falcon-40B.

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
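For reference, a minimal sketch (not part of this PR) of what the two data types do at the bitsandbytes level, using the `quantize_4bit`/`dequantize_4bit` functional API; it assumes bitsandbytes >= 0.40 and a CUDA device:

```python
# Hedged sketch: block-wise 4bit quantization with the FP4 and NF4 data
# types that `--quantize bitsandbytes-fp4` / `--quantize bitsandbytes-nf4`
# select. Assumes bitsandbytes >= 0.40 and a CUDA device.
import torch
import bitsandbytes.functional as F

weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

for quant_type in ("fp4", "nf4"):
    # quantize_4bit packs two 4bit codes per byte and returns the
    # per-block statistics needed to dequantize later.
    qweight, quant_state = F.quantize_4bit(weight, quant_type=quant_type)
    restored = F.dequantize_4bit(qweight, quant_state, quant_type=quant_type)
    err = (weight - restored).abs().mean().item()
    print(f"{quant_type}: {qweight.numel()} packed bytes, mean abs error {err:.5f}")
```

NF4 uses quantiles of a normal distribution as its 16 code values, which tends to fit pretrained weight distributions better than the generic 4bit floating-point grid of FP4 (see the QLoRA paper linked in the README change below).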
parent 3ef5ffbc64
commit 4b3e24f843
README.md:

````diff
@@ -252,6 +252,8 @@ You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
 make run-falcon-7b-instruct-quantize
 ```
 
+4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
+
 ## Develop
 
 ```shell
````
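For example, the `run-falcon-7b-instruct` Makefile targets serve `tiiuae/falcon-7b-instruct`; the equivalent direct invocation with one of the new flags would look like `text-generation-launcher --model-id tiiuae/falcon-7b-instruct --quantize bitsandbytes-nf4` (illustrative command line, assuming defaults for the remaining options).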
launcher/src/main.rs:

```diff
@@ -22,6 +22,8 @@ mod env_runtime;
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum Quantization {
     Bitsandbytes,
+    BitsandbytesNF4,
+    BitsandbytesFP4,
     Gptq,
 }
 
@@ -32,6 +34,12 @@ impl std::fmt::Display for Quantization {
         Quantization::Bitsandbytes => {
             write!(f, "bitsandbytes")
         }
+        Quantization::BitsandbytesNF4 => {
+            write!(f, "bitsandbytes-nf4")
+        }
+        Quantization::BitsandbytesFP4 => {
+            write!(f, "bitsandbytes-fp4")
+        }
         Quantization::Gptq => {
             write!(f, "gptq")
         }
@@ -96,7 +104,8 @@ struct Args {
     num_shard: Option<usize>,
 
     /// Whether you want the model to be quantized. This will use `bitsandbytes` for
-    /// quantization on the fly, or `gptq`.
+    /// quantization on the fly, or `gptq`. 4bit quantization is available through
+    /// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,
 
```
server/poetry.lock (generated, 2894 lines): file diff suppressed because it is too large.
server/pyproject.toml:

```diff
@@ -8,7 +8,7 @@ authors = ["Olivier Dehaene <olivier@huggingface.co>"]
 text-generation-server = 'text_generation_server.cli:app'
 
 [tool.poetry.dependencies]
-python = "^3.9"
+python = ">=3.9,<3.13"
 protobuf = "^4.21.7"
 grpcio = "^1.51.1"
 grpcio-status = "^1.51.1"
@@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
 grpc-interceptor = "^0.15.0"
 typer = "^0.6.1"
 accelerate = { version = "^0.19.0", optional = true }
-bitsandbytes = { version = "^0.38.1", optional = true }
+bitsandbytes = { version = "^0.40.0", optional = true }
 safetensors = "0.3.1"
 loguru = "^0.6.0"
 opentelemetry-api = "^1.15.0"
@@ -30,6 +30,7 @@ transformers = "4.29.2"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
+scipy = "^1.11.1"
 
 [tool.poetry.extras]
 accelerate = ["accelerate"]
```
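The bitsandbytes bump to `^0.40.0` is what provides the 4bit machinery used below (`Params4bit`, `bnb.matmul_4bit`). The new `scipy` dependency and the tightened `python = ">=3.9,<3.13"` constraint appear to go together: newer bitsandbytes imports scipy at runtime, and scipy `^1.11` only supports Python below 3.13, which would also explain why every marker in the regenerated `requirements.txt` below reads `python_version < "3.13"`.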
server/requirements.txt:

```diff
@@ -1,72 +1,45 @@
-accelerate==0.19.0 ; python_version >= "3.9" and python_version < "4.0"
-aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "4.0"
-aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
-async-timeout==4.0.2 ; python_version >= "3.9" and python_version < "4.0"
-attrs==23.1.0 ; python_version >= "3.9" and python_version < "4.0"
-backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
-bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
-certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
-charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
-click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
-colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
-datasets==2.14.0 ; python_version >= "3.9" and python_version < "4.0"
-deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
-dill==0.3.7 ; python_version >= "3.9" and python_version < "4.0"
-einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
-filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
-frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "4.0"
-fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
-fsspec[http]==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
-googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
-grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
-grpcio-reflection==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
-grpcio-status==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
-grpcio==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
-hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
-huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
-idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
-jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
-loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
-markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0"
-mpmath==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
-multidict==6.0.4 ; python_version >= "3.9" and python_version < "4.0"
-multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "4.0"
-networkx==3.1 ; python_version >= "3.9" and python_version < "4.0"
-numpy==1.25.0 ; python_version < "4.0" and python_version >= "3.9"
-opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
-opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
-opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
-opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
-opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
-opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
-opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
-opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
-opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
-packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
-pandas==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
-protobuf==4.23.3 ; python_version >= "3.9" and python_version < "4.0"
-psutil==5.9.5 ; python_version >= "3.9" and python_version < "4.0"
-pyarrow==12.0.1 ; python_version >= "3.9" and python_version < "4.0"
-python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
-pytz==2023.3 ; python_version >= "3.9" and python_version < "4.0"
-pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
-regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
-requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
-safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
-sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
-setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
-six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
-sympy==1.12 ; python_version >= "3.9" and python_version < "4.0"
-texttable==1.6.7 ; python_version >= "3.9" and python_version < "4.0"
-tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
-torch==2.0.1 ; python_version >= "3.9" and python_version < "4.0"
-tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
-transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
-typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
-typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0"
-tzdata==2023.3 ; python_version >= "3.9" and python_version < "4.0"
-urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
-win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
-wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
-xxhash==3.2.0 ; python_version >= "3.9" and python_version < "4.0"
-yarl==1.9.2 ; python_version >= "3.9" and python_version < "4.0"
+backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
+charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
+click==8.1.6 ; python_version >= "3.9" and python_version < "3.13"
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
+deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
+einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.12.2 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "3.13"
+grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
+hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+numpy==1.25.1 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.23.4 ; python_version >= "3.9" and python_version < "3.13"
+pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
+regex==2023.6.3 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.3.1 ; python_version >= "3.9" and python_version < "3.13"
+scipy==1.11.1 ; python_version >= "3.9" and python_version < "3.13"
+sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==68.0.0 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.65.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.29.2 ; python_version >= "3.9" and python_version < "3.13"
+typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
+win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
+wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
```
server/text_generation_server/cli.py:

```diff
@@ -13,6 +13,8 @@ app = typer.Typer()
 
 class Quantization(str, Enum):
     bitsandbytes = "bitsandbytes"
+    bitsandbytes_nf4 = "bitsandbytes-nf4"
+    bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
 
 
```
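These string values must stay in sync with the launcher's `Display` impl above: the launcher serializes the selected `Quantization` variant to its hyphenated string when spawning `text-generation-server`, and typer parses it back through this enum.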
server/text_generation_server/models/__init__.py:

```diff
@@ -255,7 +255,10 @@ def get_model(
         raise ValueError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
+    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
+        raise ValueError(
+            "4bit quantization is not supported for AutoModel"
+        )
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
             model_id,
```
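Like `gptq`, the 4bit options are only wired into the custom modeling code through `get_linear` (next file), so the generic `AutoModel` fallback rejects them up front rather than silently loading full-precision weights.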
server/text_generation_server/utils/layers.py:

```diff
@@ -9,7 +9,7 @@ from typing import List
 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
-    from bitsandbytes.nn import Int8Params
+    from bitsandbytes.nn import Int8Params, Params4bit
 
 except ImportError:
     HAS_BITS_AND_BYTES = False
@@ -140,6 +140,39 @@ class Linear8bitLt(nn.Module):
         return out
 
 
+class Linear4bit(nn.Module):
+    def __init__(self, weight, bias, quant_type):
+        super().__init__()
+        self.weight = Params4bit(
+            weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
+        )
+        self.compute_dtype = None
+        self.weight.cuda(weight.device)
+        self.bias = bias
+
+    def forward(self, x: torch.Tensor):
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        if getattr(self.weight, "quant_state", None) is None:
+            print(
+                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
+            )
+        inp_dtype = x.dtype
+        if self.compute_dtype is not None:
+            x = x.to(self.compute_dtype)
+
+        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
+        out = bnb.matmul_4bit(
+            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
+        )
+
+        out = out.to(inp_dtype)
+
+        return out
+
+
 def get_linear(weight, bias, quantize):
     if quantize is None:
         linear = FastLinear(weight, bias)
@@ -152,6 +185,18 @@ def get_linear(weight, bias, quantize):
         )
         if bias is not None:
             linear.bias = nn.Parameter(bias)
+    elif quantize == "bitsandbytes-fp4":
+        linear = Linear4bit(
+            weight,
+            bias,
+            quant_type="fp4",
+        )
+    elif quantize == "bitsandbytes-nf4":
+        linear = Linear4bit(
+            weight,
+            bias,
+            quant_type="nf4",
+        )
     elif quantize == "gptq":
         try:
             qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
```
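To make the new codepath concrete, here is a hedged usage sketch (illustrative, not from the PR): it wraps an fp16 `nn.Linear`'s parameters in the `Linear4bit` layer defined above and compares one forward pass against the full-precision reference. It assumes a CUDA device and bitsandbytes >= 0.40.

```python
# Illustrative only: exercising the Linear4bit layer added in this PR.
# Assumes a CUDA device, bitsandbytes >= 0.40, and Linear4bit imported
# from server/text_generation_server/utils/layers.py.
import torch
from torch import nn

fp16_linear = nn.Linear(4096, 4096, dtype=torch.float16, device="cuda")

# Params4bit quantizes the weight block-wise to NF4 when Linear4bit.__init__
# moves it to the GPU; the bias is kept in floating point.
q_linear = Linear4bit(fp16_linear.weight, fp16_linear.bias, quant_type="nf4")

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
with torch.no_grad():
    ref = fp16_linear(x)  # full-precision reference
    out = q_linear(x)     # dispatches to bnb.matmul_4bit

# The outputs should agree up to quantization error.
print((ref - out).abs().mean().item())
```

Note the design choice in `get_linear`: quantization is dispatched per layer on the `quantize` string, so the two new branches slot in alongside `bitsandbytes` (int8) and `gptq` without touching the model code that calls `get_linear`.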