2023-06-08 12:51:52 +00:00
|
|
|
import os
|
2023-01-20 11:24:39 +00:00
|
|
|
import torch
|
|
|
|
|
2023-03-24 13:02:14 +00:00
|
|
|
from loguru import logger
|
2023-06-01 10:07:41 +00:00
|
|
|
from transformers.configuration_utils import PretrainedConfig
|
2023-03-27 07:23:22 +00:00
|
|
|
from transformers.models.auto import modeling_auto
|
2023-01-31 17:53:56 +00:00
|
|
|
from typing import Optional
|
|
|
|
|
2023-03-07 17:52:22 +00:00
|
|
|
from text_generation_server.models.model import Model
|
|
|
|
from text_generation_server.models.causal_lm import CausalLM
|
2023-04-03 17:06:42 +00:00
|
|
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
2023-06-08 12:51:52 +00:00
|
|
|
from text_generation_server.models.bloom import BLOOMSharded
|
2023-07-03 11:01:46 +00:00
|
|
|
from text_generation_server.models.mpt import MPTSharded
|
2023-03-07 17:52:22 +00:00
|
|
|
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
2023-05-30 16:25:19 +00:00
|
|
|
from text_generation_server.models.rw import RW
|
2023-06-08 12:51:52 +00:00
|
|
|
from text_generation_server.models.opt import OPTSharded
|
|
|
|
from text_generation_server.models.galactica import GalacticaSharded
|
2023-03-07 17:52:22 +00:00
|
|
|
from text_generation_server.models.santacoder import SantaCoder
|
|
|
|
from text_generation_server.models.t5 import T5Sharded
|
2023-06-08 12:51:52 +00:00
|
|
|
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
2023-01-20 11:24:39 +00:00
|
|
|
|
2023-06-19 07:53:45 +00:00
|
|
|
# Permit TensorFloat-32 math for CUDA matmuls (PyTorch >= 1.12 ships with this
# switched off) and for cuDNN (already on by default; set here to make the
# intent explicit).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Switch autograd off for the importing thread; grad mode is thread-local in
# PyTorch, so this applies to the thread that imports this module.
torch.set_grad_enabled(False)
|
|
|
|
|
|
|
|
# Public API of this package (names only; ``from ... import *`` requires str
# entries). Flash Attention models are added later in this module only when
# their imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "MPTSharded",
    "RW",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]
|
|
|
|
|
2023-07-18 14:21:18 +00:00
|
|
|
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Flash Attention implementations are imported opportunistically: if any of the
# imports raises ImportError, the failure is logged and FLASH_ATTENTION is set to
# False so that get_model falls back to the non-flash implementations.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # ``__all__`` must contain names (str), not the class objects themselves:
    # ``from package import *`` raises TypeError on non-string entries.
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
        ]
    )
|
|
|
|
|
2022-10-28 17:24:00 +00:00
|
|
|
|
2023-01-31 17:53:56 +00:00
|
|
|
def _resolve_dtype(dtype: Optional[str]) -> torch.dtype:
    """Translate the ``dtype`` option into a ``torch.dtype``.

    ``None`` and ``"float16"`` map to ``torch.float16``; ``"bfloat16"`` maps to
    ``torch.bfloat16``. Anything else raises ``RuntimeError``.
    """
    if dtype is None or dtype == "float16":
        return torch.float16
    if dtype == "bfloat16":
        return torch.bfloat16
    raise RuntimeError(f"Unknown dtype {dtype}")


def _get_santacoder(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: torch.dtype,
    trust_remote_code: bool,
) -> Model:
    """Build a SantaCoder / GPT-BigCode model, preferring the Flash Attention variant.

    Raises:
        NotImplementedError: if ``sharded`` is requested while Flash Attention
            models are unavailable.
    """
    if FLASH_ATTENTION:
        return FlashSantacoderSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if sharded:
        raise NotImplementedError(
            FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
        )
    return SantaCoder(
        model_id,
        revision,
        quantize=quantize,
        dtype=dtype,
        trust_remote_code=trust_remote_code,
    )


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the ``Model`` implementation matching ``model_id``.

    Dispatch order:
      1. ``model_id`` containing ``"facebook/galactica"`` -> GalacticaSharded.
      2. ``model_id`` starting with ``"bigcode/"`` -> SantaCoder family.
      3. Otherwise the checkpoint config is loaded and ``model_type`` selects
         an architecture-specific class (Flash Attention variants preferred).
      4. Unknown types fall back to transformers' AutoModel mappings, then to a
         remote-code ``auto_map`` when ``trust_remote_code`` is set.

    Args:
        model_id: Model identifier, forwarded to
            ``PretrainedConfig.get_config_dict`` and to the model classes.
        revision: Checkpoint revision, forwarded as-is.
        sharded: Whether a sharded model was requested.
        quantize: Quantization method name or ``None``; ``"gptq"`` is rejected
            for the AutoModel fallback.
        dtype: ``"float16"``, ``"bfloat16"`` or ``None`` (float16).
        trust_remote_code: Forwarded to config loading and model constructors.

    Returns:
        The constructed ``Model``.

    Raises:
        RuntimeError: unknown ``dtype``.
        NotImplementedError: sharding requested for an architecture whose
            sharded path needs Flash Attention (or an alibi Falcon model).
        ValueError: sharding or GPTQ requested for the AutoModel fallback, or
            an unsupported model type.
    """
    dtype = _resolve_dtype(dtype)

    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):
        return _get_santacoder(
            model_id, revision, sharded, quantize, dtype, trust_remote_code
        )

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return _get_santacoder(
            model_id, revision, sharded, quantize, dtype, trust_remote_code
        )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "mpt":
        # NOTE(review): unlike every other branch, dtype is not forwarded here;
        # confirm whether MPTSharded accepts a ``dtype`` argument.
        return MPTSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "llama":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type in ("RefinedWeb", "RefinedWebModel", "falcon"):
        # The flash implementation is only used for configs without alibi.
        uses_alibi = config_dict.get("alibi", False)
        if sharded:
            if FLASH_ATTENTION:
                if uses_alibi:
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        if FLASH_ATTENTION and not uses_alibi:
            return FlashRWSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        return RW(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    # Generic AutoModel fallback: single-process, non-GPTQ only.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    # Remote-code models declare their Auto* classes in the config's auto_map.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map:
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")
|