Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
bitsandbytes: upgrade and enable CUDA Graphs for 4bit by default

Commit 5159d030a9 (parent 82c24f7420)
```diff
@@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects.

 ## Quantization

-TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [Marlin](https://github.com/IST-DASLab/marlin), [EETQ](https://github.com/NetEase-FuXi/EETQ), [EXL2](https://github.com/turboderp/exllamav2), and [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq`, `awq`, `marlin`, `exl2`, `eetq` or `fp8` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). Similarly, when using AWQ quantization, you need to point to one of [these models](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization)
+TGI supports [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes#bitsandbytes), [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [Marlin](https://github.com/IST-DASLab/marlin), [EETQ](https://github.com/NetEase-FuXi/EETQ), [EXL2](https://github.com/turboderp/exllamav2), and [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq`, `awq`, `marlin`, `exl2`, `eetq` or `fp8` depending on the quantization technique you wish to use. When using GPTQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). Similarly, when using AWQ quantization, you need to point to one of [these models](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization)


 ## RoPE Scaling
```
```diff
@@ -12,7 +12,7 @@ Text Generation Inference implements many optimizations and features, such as:
 - Token streaming using Server-Sent Events (SSE)
 - Continuous batching of incoming requests for increased total throughput
 - Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
-- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
+- Quantization with [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 - Logits warper (temperature scaling, top-p, top-k, repetition penalty)
```
```diff
@@ -2075,15 +2075,8 @@ fn main() -> Result<(), LauncherError> {
     let cuda_graphs = match (&args.cuda_graphs, &quantize) {
         (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
         #[allow(deprecated)]
-        (
-            None,
-            Some(
-                Quantization::Bitsandbytes
-                | Quantization::BitsandbytesNf4
-                | Quantization::BitsandbytesFp4,
-            ),
-        ) => {
-            tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
+        (None, Some(Quantization::Bitsandbytes)) => {
+            tracing::warn!("Bitsandbytes 8bit doesn't work with cuda graphs, deactivating them");
             vec![]
         }
         (None, Some(Quantization::Exl2)) => {
```
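With the removed branch above, the launcher now disables CUDA graphs only for 8-bit `bitsandbytes`; `bitsandbytes-nf4` and `bitsandbytes-fp4` keep the default graph sizes. The sketch below is not TGI code, just a minimal illustration (assuming a CUDA GPU and bitsandbytes >= 0.45.0) of what "CUDA graphs for 4-bit" means: a `bnb.matmul_4bit` call is warmed up, captured once with `torch.cuda.CUDAGraph`, and replayed against static buffers.

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Quantize a fp16 weight matrix to NF4 (packed tensor + quantization state).
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
qweight, quant_state = F.quantize_4bit(weight, quant_type="nf4")

# Static input buffer: CUDA graphs replay fixed memory addresses.
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

def forward(inp):
    return bnb.matmul_4bit(inp, qweight.t(), quant_state=quant_state)

# Warm up on a side stream before capture (standard CUDA-graph recipe).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        out = forward(x)
torch.cuda.current_stream().wait_stream(s)

# Capture once, then replay after copying fresh data into the static input.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    out = forward(x)

x.copy_(torch.randn_like(x))
graph.replay()      # re-runs the captured 4-bit matmul on the new input
print(out.shape)    # torch.Size([8, 4096])
```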
server/poetry.lock (generated, 13 changed lines)
```diff
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.

 [[package]]
 name = "accelerate"
```
```diff
@@ -290,22 +290,23 @@ tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]

 [[package]]
 name = "bitsandbytes"
-version = "0.43.3"
+version = "0.45.0"
 description = "k-bit optimizers and matrix multiplication routines."
 optional = true
 python-versions = "*"
 files = [
-    {file = "bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:cc99507c352be0715098b2c7577b690dd158972dc4ea10c7495bac104c7c79f0"},
-    {file = "bitsandbytes-0.43.3-py3-none-win_amd64.whl", hash = "sha256:257f6552f2144748a84e6c44e1f7a98f3da888f675ed74e18fd7f7eb13c6cafa"},
+    {file = "bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:0f0323de1ff1fdf8383e79bdad1283516a4c05a6fd2b44a363bf4e059422305b"},
+    {file = "bitsandbytes-0.45.0-py3-none-win_amd64.whl", hash = "sha256:ebbf96e0ecb466716a65ecdeaef3fa1983575447b9ab66b74e5211892507c6ff"},
 ]

 [package.dependencies]
 numpy = "*"
 torch = "*"
+typing_extensions = ">=4.8.0"

 [package.extras]
 benchmark = ["matplotlib", "pandas"]
-test = ["scipy"]
+test = ["lion_pytorch", "scipy"]

 [[package]]
 name = "certifi"
```
```diff
@@ -4097,4 +4098,4 @@ torch = ["torch"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "c7fdcff2b752cd3beb3995c1ecd15f0f4d9b4e117048b06ab991c6d0e0c86ff3"
+content-hash = "767757fffcf7bec05a8a60dcfe2a3c7d258f26efac3004f3d24c8d543b462413"
```
```diff
@@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
 grpc-interceptor = "^0.15.4"
 typer = "^0.12.5"
 accelerate = {version = "^1.1.0", optional = true}
-bitsandbytes = { version = "^0.43.0", optional = true }
+bitsandbytes = { version = "^0.45.0", optional = true }
 safetensors = "^0.4.5"
 loguru = "^0.7.2"
 opentelemetry-api = "^1.27.0"
```
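For context (not part of the diff): Poetry's caret constraint `^0.45.0` resolves to `>=0.45.0,<0.46.0`. A quick sketch, assuming the `packaging` library is installed, for checking that a local bitsandbytes satisfies the new bound:

```python
import importlib.metadata
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Poetry's "^0.45.0" is equivalent to ">=0.45.0,<0.46.0" for a 0.x release.
spec = SpecifierSet(">=0.45.0,<0.46.0")
installed = Version(importlib.metadata.version("bitsandbytes"))
print(installed, installed in spec)
```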
```diff
@@ -20,23 +20,16 @@ class Linear8bitLt(torch.nn.Module):
         weight,
         bias,
         has_fp16_weights=True,
-        memory_efficient_backward=False,
         threshold=0.0,
         index=None,
     ):
         super().__init__()
-        assert (
-            not memory_efficient_backward
-        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index

         # Necessary for stacked layers
         self.state.threshold = threshold
         self.state.has_fp16_weights = has_fp16_weights
-        self.state.memory_efficient_backward = memory_efficient_backward
-        if threshold > 0.0 and not has_fp16_weights:
-            self.state.use_pool = True

         self.weight = Int8Params(
             weight.data,
```
```diff
@@ -63,12 +56,9 @@ class Linear8bitLt(torch.nn.Module):

         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

-        if not self.state.has_fp16_weights:
-            if self.state.CB is not None and self.state.CxB is not None:
-                # we converted 8-bit row major to turing/ampere format in the first inference pass
-                # we no longer need the row-major weight
-                del self.state.CB
-                self.weight.data = self.state.CxB
+        if not self.state.has_fp16_weights and self.state.CB is not None:
+            self.weight.data = self.state.CB
         return out

```
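The forward pass above now relies on newer bitsandbytes keeping the int8 weights in `state.CB` instead of converting them to the old Turing/Ampere `CxB` layout on the first call. A minimal sketch of that path outside TGI, assuming a CUDA GPU, bitsandbytes >= 0.45.0, and illustrative dimensions:

```python
import torch
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params

# Quantize a fp16 weight to int8 on the way to the GPU (LLM.int8() path).
fp16_weight = torch.randn(1024, 1024, dtype=torch.float16)
weight = Int8Params(fp16_weight, requires_grad=False, has_fp16_weights=False).cuda()

state = bnb.MatmulLtState()
state.threshold = 6.0              # outlier threshold used by LLM.int8()
state.has_fp16_weights = False
# Mirror Linear8bitLt.init_8bit_state(): hand the quantized tensors to the state.
state.CB, state.SCB = weight.CB, weight.SCB
weight.CB = weight.SCB = None

x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
out = bnb.matmul(x, weight, state=state)

# As in the patched forward(), the parameter keeps pointing at the int8 CB tensor;
# no CxB conversion or row-major cleanup is needed anymore.
weight.data = state.CB
print(out.shape, weight.data.dtype)   # torch.Size([4, 1024]) torch.int8
```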
```diff
@@ -106,19 +96,12 @@ class Linear4bit(torch.nn.Module):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)

-        if getattr(self.weight, "quant_state", None) is None:
-            print(
-                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
-            )
         inp_dtype = x.dtype
         if self.compute_dtype is not None:
             x = x.to(self.compute_dtype)

         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-        out = bnb.matmul_4bit(
+        return bnb.matmul_4bit(
             x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
-        )
+        ).to(inp_dtype)

-        out = out.to(inp_dtype)
-
-        return out
```
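The simplified `Linear4bit.forward` above is now a single fused expression: the quantization state lives on the weight, and the result is cast back to the input dtype. A small sketch of the same call pattern with plain bitsandbytes objects (assuming a CUDA GPU and bitsandbytes >= 0.45.0; `Params4bit` quantizes on `.cuda()` much like the layer's constructor does):

```python
import torch
import bitsandbytes as bnb
from bitsandbytes.nn import Params4bit

# NF4-quantize a fp16 weight; .cuda() packs the data and fills .quant_state.
fp16_weight = torch.randn(1024, 1024, dtype=torch.float16)
weight = Params4bit(fp16_weight, requires_grad=False, quant_type="nf4").cuda()

x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")

# Same shape as the patched forward(): one matmul_4bit call, cast back to x.dtype.
out = bnb.matmul_4bit(
    x, weight.t(), bias=None, quant_state=weight.quant_state
).to(x.dtype)
print(out.shape, out.dtype)   # torch.Size([4, 1024]) torch.float16
```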