mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
# Add AWQ quantization inference support
Fixes
https://github.com/huggingface/text-generation-inference/issues/781
This PR (partially) adds support for AWQ quantization for inference.
More information on AWQ [here](https://arxiv.org/abs/2306.00978). In
general, AWQ is faster and more accurate than GPTQ, which is currently
supported by TGI.
This PR installs the custom 4-bit GEMM CUDA kernels released by the AWQ authors
(a one-line change in `requirements.txt`).
A quick way to test this PR is to bring up TGI as follows:
```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq
text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
--trust-remote-code --port 8080 \
--max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
--quantize awq
```
Please note:
* This PR was tested with FlashAttention v2 and vLLM.
* This PR adds support for AWQ inference, not for quantizing models.
Quantization needs to be done outside of TGI; instructions are
[here](https://github.com/mit-han-lab/llm-awq/tree/f084f40bd9).
* This PR only adds support for `FlashLlama` models for now.
* Multi-GPU setup has not been tested.
* No integration tests have been added so far; they will be added later if
the maintainers are interested in this change.
* This PR can be tested on any of the models released
[here](https://huggingface.co/abhinavkulkarni?sort_models=downloads#models).
Please refer to the linked issue for benchmarks for
[abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq](https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq)
vs
[TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).
Please note that AWQ has released faster (and, in the case of Llama, fused)
kernels for 4-bit GEMM, currently at the top of the `main` branch at
https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit
that has been tested to work. We can switch to the latest commit later on.
## Who can review?
@OlivierDehaene OR @Narsil
---------
Co-authored-by: Abhinav Kulkarni <abhinav@concentric.ai>
316 lines
9.8 KiB
Python
import os
|
|
import torch
|
|
|
|
from loguru import logger
|
|
from transformers.configuration_utils import PretrainedConfig
|
|
from transformers.models.auto import modeling_auto
|
|
from typing import Optional
|
|
|
|
from text_generation_server.models.model import Model
|
|
from text_generation_server.models.causal_lm import CausalLM
|
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
|
from text_generation_server.models.bloom import BLOOMSharded
|
|
from text_generation_server.models.mpt import MPTSharded
|
|
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
|
from text_generation_server.models.rw import RW
|
|
from text_generation_server.models.opt import OPTSharded
|
|
from text_generation_server.models.galactica import GalacticaSharded
|
|
from text_generation_server.models.santacoder import SantaCoder
|
|
from text_generation_server.models.t5 import T5Sharded
|
|
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
|
|
|
# This module only serves inference, so turn gradient tracking off.
# NOTE: PyTorch grad mode is thread-local; this affects the importing thread.
torch.set_grad_enabled(False)

# Opt in to TensorFloat-32 for CUDA matmuls (off by default since PyTorch 1.12).
torch.backends.cuda.matmul.allow_tf32 = True

# TensorFloat-32 for cuDNN is already on by default; set it explicitly anyway.
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
# Public names exported by a star-import of this package. The Flash Attention
# model classes are added at import time, only when they import successfully.
__all__ = [
    "Model", "BLOOMSharded", "CausalLM", "FlashCausalLM", "GalacticaSharded",
    "Seq2SeqLM", "SantaCoder", "OPTSharded", "T5Sharded", "get_model",
]
|
|
|
|
# Template for errors raised when a model variant needs the Flash Attention
# implementations; formatted with the variant name, e.g. "Sharded Llama".
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Optimistically assume the Flash Attention models are importable. If any of
# the imports below fails, log a warning and record the failure so that
# `get_model` can choose the non-flash code paths (or refuse) instead.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # `__all__` must hold names (str), not the class objects themselves:
    # `from <package> import *` raises TypeError on any non-str entry.
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
        ]
    )
|
|
|
|
|
|
def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Pick and instantiate the model class that serves ``model_id``.

    Selection happens in three stages:

    1. Model-id patterns: ids containing ``"facebook/galactica"`` and ids
       starting with ``"bigcode/"`` are dispatched without reading the config.
    2. The ``model_type`` read from the model config. Flash Attention
       implementations are preferred whenever ``FLASH_ATTENTION`` is true.
    3. A generic transformers AutoModel fallback (``CausalLM`` /
       ``Seq2SeqLM``). It supports neither sharding nor GPTQ, AWQ or
       bitsandbytes 4-bit quantization.

    Every model class is constructed the same way:
    ``cls(model_id, revision, quantize=..., dtype=..., trust_remote_code=...)``.

    Args:
        model_id: Model identifier, forwarded to the model class and to
            ``PretrainedConfig.get_config_dict``.
        revision: Optional revision, forwarded alongside ``model_id``.
        sharded: Whether a sharded deployment was requested.
        quantize: Quantization method name or ``None``. The values checked
            here are ``"gptq"``, ``"awq"``, ``"bitsandbytes-fp4"`` and
            ``"bitsandbytes-nf4"``. Any value is forwarded to the model class.
        dtype: ``"float16"`` (also used when ``None``) or ``"bfloat16"``.
        trust_remote_code: Forwarded to config loading and to the model class.
            It is also required for the config ``auto_map`` fallback.

    Returns:
        The instantiated ``Model`` subclass.

    Raises:
        RuntimeError: ``dtype`` is not one of the accepted values.
        KeyError: the loaded config has no ``"model_type"`` entry.
        NotImplementedError: the requested variant needs Flash Attention (or
            sharding support) that is not available.
        ValueError: sharding or an unsupported quantization was requested for
            the AutoModel fallback, or the model type is not supported at all.
    """
    if dtype is None or dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    def _load(model_cls):
        # Shared constructor call used for every supported model class.
        return model_cls(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def _load_santacoder():
        # Used for both the "bigcode/" id prefix and the "gpt_bigcode" type.
        if FLASH_ATTENTION:
            return _load(FlashSantacoderSharded)
        if sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        return _load(SantaCoder)

    # Stage 1: dispatch on the model id alone.
    if "facebook/galactica" in model_id:
        return _load(GalacticaSharded)

    if model_id.startswith("bigcode/"):
        return _load_santacoder()

    # Stage 2: dispatch on the config's model_type.
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return _load_santacoder()

    if model_type == "bloom":
        return _load(BLOOMSharded)

    if model_type == "mpt":
        return _load(MPTSharded)

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return _load(FlashNeoXSharded)
        if sharded:
            return _load(GPTNeoxSharded)
        return _load(CausalLM)

    if model_type in ("llama", "baichuan"):
        if FLASH_ATTENTION:
            return _load(FlashLlama)
        if sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")
            )
        return _load(CausalLM)

    if model_type in ("RefinedWeb", "RefinedWebModel", "falcon"):
        uses_alibi = config_dict.get("alibi", False)
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon")
                )
            # The flash implementation is not used when the config sets alibi.
            if uses_alibi:
                raise NotImplementedError("sharded is not supported for this model")
            return _load(FlashRWSharded)
        if FLASH_ATTENTION and not uses_alibi:
            return _load(FlashRWSharded)
        return _load(RW)

    if model_type == "opt":
        return _load(OPTSharded)

    if model_type == "t5":
        return _load(T5Sharded)

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return _load(IDEFICSSharded)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    # Stage 3: generic AutoModel fallback. Reject the options it cannot honour.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise ValueError("awq quantization is not supported for AutoModel")
    if quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise ValueError("4bit quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return _load(CausalLM)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return _load(Seq2SeqLM)

    # Custom architectures advertise their auto classes in the config's
    # auto_map; honour it only when remote code is trusted.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return _load(CausalLM)
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return _load(Seq2SeqLM)

    raise ValueError(f"Unsupported model type {model_type}")
|