text-generation-inference/server/text_generation_server/models
Commit ba291dad9f by Daniël de Kok (2024-07-19 09:37:39 +02:00)
Improve the handling of quantized weights (#2250)
* Improve the handling of quantized weights

Handling of quantized weights was split between two mechanisms:

- For quantized checkpoints, we used the new weight loader
  infrastructure (sketched below).
- For quantization at load time (EETQ, FP8, bitsandbytes), we
  instead relied on conditionals in `get_linear`.
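For context, a weight loader is essentially a small interface that each
quantizer implements; a minimal sketch, where the method names are
illustrative assumptions rather than the exact TGI API:

```python
from abc import ABC, abstractmethod

# Sketch of a weight loader interface; method names are assumptions.
class WeightsLoader(ABC):
    @abstractmethod
    def get_weights_col(self, weights, prefix: str):
        """Load the (possibly packed/quantized) column-sharded weight at `prefix`."""

    @abstractmethod
    def get_weights_row(self, weights, prefix: str):
        """Load the (possibly packed/quantized) row-sharded weight at `prefix`."""

# A checkpoint quantizer (e.g. AWQ or GPTQ) returns its packed tensors
# and metadata from these methods instead of a single raw tensor.
```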

Weight loaders support context managers to selectively load
particular layers with different weight loaders. This is useful
for models like Idefics2 AWQ, which uses a quantized text model
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
the `quantizer` argument. The context manager also did not work
with EETQ, FP8, and bitsandbytes.
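A rough sketch of how such a context manager can temporarily swap the
active loader; the names `Weights` and `use_loader` are illustrative
assumptions, not necessarily the exact API:

```python
from contextlib import contextmanager

class Weights:
    """Holds checkpoint tensors plus the loader used to materialize them (sketch)."""

    def __init__(self, loader):
        self.loader = loader

    @contextmanager
    def use_loader(self, loader):
        # Temporarily replace the active weight loader, restoring the
        # original one when the block exits.
        previous = self.loader
        self.loader = loader
        try:
            yield
        finally:
            self.loader = previous
```

With this pattern, Idefics2 AWQ can load its text model with the
quantized loader configured at startup, then wrap the vision and
connector loads in a `use_loader(...)` block with an unquantized
default loader.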

This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:

- We can use context managers with all quantizers.
- All the implementation details move down into the quantizer layers;
  `get_linear` does not need to know how to handle quantized linear
  layers.
- All quantizer weights are strongly typed; we don't pass around
  raw tensors (see the sketch after this list).
- We don't have to pass around the `quantizer` string everywhere.
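A sketch of the strong typing this enables; the class names
(`Weight`, `UnquantizedWeight`, `Fp8Weight`, `FastLinear`) are
illustrative assumptions, and the FP8 path below dequantizes on the
fly as a stand-in for a real FP8 kernel:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

import torch

class FastLinear(torch.nn.Module):
    """Plain linear module over a preloaded weight tensor (sketch)."""

    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]):
        super().__init__()
        self.weight = weight
        self.bias = bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(x, self.weight, self.bias)

class Weight(ABC):
    @abstractmethod
    def get_linear(self, bias: Optional[torch.Tensor]) -> torch.nn.Module:
        """Each weight type knows how to build its own linear module."""

@dataclass
class UnquantizedWeight(Weight):
    weight: torch.Tensor

    def get_linear(self, bias):
        return FastLinear(self.weight, bias)

@dataclass
class Fp8Weight(Weight):
    qweight: torch.Tensor  # quantized tensor
    scale: torch.Tensor    # dequantization scale

    def get_linear(self, bias):
        # Stand-in for a real FP8 kernel: dequantize, then wrap.
        return FastLinear(self.qweight.to(torch.float16) * self.scale, bias)

def get_linear(weight: Weight, bias: Optional[torch.Tensor] = None) -> torch.nn.Module:
    # No quantizer string and no isinstance ladder: the typed weight
    # builds the right linear layer itself.
    return weight.get_linear(bias)
```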

* Exclude non-MLP layers when using FP8 quantization with Llama
| File | Last commit | Date |
|------|-------------|------|
| custom_modeling | Improve the handling of quantized weights (#2250) | 2024-07-19 09:37:39 +02:00 |
| __init__.py | Falcon/DBRX: get correct number of key-value heads (#2205) | 2024-07-08 13:22:38 +02:00 |
| bloom.py | Refactor dead code - Removing all flash_xxx.py files. (#2166) | 2024-07-05 10:29:56 +02:00 |
| causal_lm.py | Move quantized weight handling out of the Weights class (#2194) | 2024-07-09 20:04:03 +02:00 |
| flash_causal_lm.py | Move quantized weight handling out of the Weights class (#2194) | 2024-07-09 20:04:03 +02:00 |
| flash_mistral.py | Refactor dead code - Removing all flash_xxx.py files. (#2166) | 2024-07-05 10:29:56 +02:00 |
| galactica.py | Refactor dead code - Removing all flash_xxx.py files. (#2166) | 2024-07-05 10:29:56 +02:00 |
| globals.py | [Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940) | 2024-07-01 23:28:00 +02:00 |
| idefics_causal_lm.py | Enable multiple LoRa adapters (#2010) | 2024-06-25 14:46:27 -04:00 |
| idefics.py | Move quantized weight handling out of the Weights class (#2194) | 2024-07-09 20:04:03 +02:00 |
| mamba.py | Move quantized weight handling out of the Weights class (#2194) | 2024-07-09 20:04:03 +02:00 |
| model.py | Hotfixing after refactor. | 2024-07-05 09:25:29 +00:00 |
| pali_gemma.py | Refactor dead code - Removing all flash_xxx.py files. (#2166) | 2024-07-05 10:29:56 +02:00 |
| seq2seq_lm.py | Move quantized weight handling out of the Weights class (#2194) | 2024-07-09 20:04:03 +02:00 |
| types.py | chore: add pre-commit (#1569) | 2024-02-16 11:58:58 +01:00 |
| vlm_causal_lm.py | Refactor dead code - Removing all flash_xxx.py files. (#2166) | 2024-07-05 10:29:56 +02:00 |