2023-01-20 11:24:39 +00:00
|
|
|
import torch
|
|
|
|
|
2023-03-24 13:02:14 +00:00
|
|
|
from loguru import logger
|
2023-06-01 10:07:41 +00:00
|
|
|
from transformers.configuration_utils import PretrainedConfig
|
2023-03-27 07:23:22 +00:00
|
|
|
from transformers.models.auto import modeling_auto
|
2023-01-31 17:53:56 +00:00
|
|
|
from typing import Optional
|
|
|
|
|
2023-12-11 11:46:30 +00:00
|
|
|
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
2023-03-07 17:52:22 +00:00
|
|
|
from text_generation_server.models.model import Model
|
|
|
|
from text_generation_server.models.causal_lm import CausalLM
|
2023-04-03 17:06:42 +00:00
|
|
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
2023-06-08 12:51:52 +00:00
|
|
|
from text_generation_server.models.bloom import BLOOMSharded
|
2023-07-03 11:01:46 +00:00
|
|
|
from text_generation_server.models.mpt import MPTSharded
|
2023-03-07 17:52:22 +00:00
|
|
|
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
2023-05-30 16:25:19 +00:00
|
|
|
from text_generation_server.models.rw import RW
|
2023-06-08 12:51:52 +00:00
|
|
|
from text_generation_server.models.opt import OPTSharded
|
|
|
|
from text_generation_server.models.galactica import GalacticaSharded
|
2023-03-07 17:52:22 +00:00
|
|
|
from text_generation_server.models.santacoder import SantaCoder
|
|
|
|
from text_generation_server.models.t5 import T5Sharded
|
2023-06-08 12:51:52 +00:00
|
|
|
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
2024-01-25 14:37:53 +00:00
|
|
|
from text_generation_server.models.phi import Phi
|
2023-01-20 11:24:39 +00:00
|
|
|
|
2023-06-19 07:53:45 +00:00
|
|
|
# Since PyTorch 1.12, TF32 tensor-core math for CUDA matmuls is opt-in
# (default False); switch it on to trade a little precision for throughput.
torch.backends.cuda.matmul.allow_tf32 = True
# cuDNN already allows TF32 by default; set it explicitly so the intent is
# visible here and does not silently depend on library defaults.
torch.backends.cudnn.allow_tf32 = True

# Serving models needs no gradients, so autograd tracking is switched off.
torch.set_grad_enabled(False)
|
|
|
|
|
|
|
|
# Public names re-exported by ``from text_generation_server.models import *``.
# Every entry must be a *string* naming a module attribute. Classes that
# depend on optional imports (flash attention, Mamba) are added later in this
# module, only when those imports succeed.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    # Always-imported model classes that were previously missing from the
    # export list, although they are imported unconditionally at the top.
    "MPTSharded",
    "RW",
    "GPTNeoxSharded",
    "Phi",
    "get_model",
]
|
|
|
|
|
2023-07-18 14:21:18 +00:00
|
|
|
# Error template for model variants that can only run through the
# flash-attention implementations; ``{}`` receives a human-readable model name.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

# Assume the flash-attention models are importable. If any import below fails,
# the handler logs the reason and turns both capability flags off.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import FlashLlama
    from text_generation_server.models.flash_santacoder import FlashSantacoderSharded
    from text_generation_server.models.idefics import IDEFICSSharded
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
    from text_generation_server.models.flash_phi import FlashPhi
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as import_err:
    logger.warning(f"Could not import Flash Attention enabled models: {import_err}")
    FLASH_ATTENTION = False
    HAS_FLASH_ATTN_V2_CUDA = False
|
2023-03-24 13:02:14 +00:00
|
|
|
|
2023-04-03 17:06:42 +00:00
|
|
|
if FLASH_ATTENTION:
    # Export the flash-attention model classes by *name*. ``__all__`` must hold
    # strings: ``from ... import *`` raises TypeError when it meets a class
    # object in the list, as the previous ``__all__.append(FlashNeoXSharded)``
    # style did.
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
            "FlashMistral",
            "FlashMixtral",
            "FlashPhi",
        ]
    )
|
2023-12-11 13:43:40 +00:00
|
|
|
|
2024-02-08 09:19:45 +00:00
|
|
|
# The Mamba model is optional: if importing it fails, log the reason and leave
# it unregistered. ``get_model`` checks MAMBA_AVAILABLE before using it.
MAMBA_AVAILABLE = True
try:
    from text_generation_server.models.mamba import Mamba
except ImportError as e:
    logger.warning(f"Could not import Mamba: {e}")
    MAMBA_AVAILABLE = False

if MAMBA_AVAILABLE:
    # ``__all__`` entries must be attribute names (str), not the class object.
    __all__.append("Mamba")
|
2023-12-11 13:43:40 +00:00
|
|
|
|
2024-02-08 17:41:25 +00:00
|
|
|
|
2023-01-31 17:53:56 +00:00
|
|
|
def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the ``Model`` implementation that serves ``model_id``.

    A few hub-id prefixes are handled first (``facebook/galactica``,
    ``bigcode/``). Otherwise the checkpoint config is fetched, and its
    ``model_type`` picks a dedicated implementation. Flash-attention variants
    are used when available. Architectures known to ``transformers`` fall back
    to the generic ``CausalLM`` / ``Seq2SeqLM`` wrappers.

    Args:
        model_id: Hub id or local path of the checkpoint. If its config
            contains ``medusa_num_heads``, it is treated as a Medusa head repo.
            The base model named by ``base_model_name_or_path`` is then loaded
            at revision ``"main"``, and the original id is passed on as
            ``use_medusa`` to the implementations that accept it.
        revision: Revision of ``model_id`` to load, or ``None``.
        sharded: Whether sharded loading was requested. Several architectures
            support this only with flash attention.
        quantize: Quantization method name (for example ``"gptq"``, ``"awq"``,
            ``"eetq"``, ``"bitsandbytes-nf4"``), or ``None``.
        speculate: Number of speculative tokens. ``None`` means 0 for regular
            checkpoints and the number of Medusa heads for Medusa checkpoints.
        dtype: ``"float16"``, ``"bfloat16"``, or ``None`` to let each model
            choose its own default.
        trust_remote_code: Forwarded to config and model loading.

    Returns:
        An instance of the selected ``Model`` subclass.

    Raises:
        RuntimeError: ``dtype`` is unknown, ``speculate`` exceeds the number
            of Medusa heads, or the model type cannot be determined.
        NotImplementedError: The requested combination (sharding,
            quantization, missing flash attention or Mamba) is unsupported.
        ValueError: The model type is not supported at all.
    """
    if dtype is None:
        # Keep it as default for now and let
        # every model resolve their own default dtype.
        pass
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    # Default speculation. Medusa checkpoints override this further below,
    # once the number of heads is known.
    set_speculate(speculate if speculate is not None else 0)

    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

    use_medusa = None
    if "medusa_num_heads" in config_dict:
        # This repo only provides Medusa heads: remember it, then switch to the
        # base model named in its config and reload that model's config.
        use_medusa = model_id
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        method = "medusa"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict.get("model_type", None)
    if model_type is None:
        # TODO: fix how we determine model type for Mamba
        if "ssm_cfg" in config_dict:
            # *only happens in Mamba case
            model_type = "ssm"
        else:
            raise RuntimeError(
                f"Could not determine model type for {model_id} revision {revision}"
            )

    if model_type == "ssm":
        if not MAMBA_AVAILABLE:
            # ``Mamba`` is unbound when its import failed at module load.
            # Fail with a clear message instead of a NameError.
            raise NotImplementedError(
                "Mamba models are not available: importing Mamba failed at startup"
            )
        return Mamba(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == "phi":
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == "phi-msft":
        if FLASH_ATTENTION:
            raise NotImplementedError(
                "Legacy phi-msft is not supported with Flash Attention"
            )
        else:
            return Phi(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
        else:
            # The flash implementation does not handle ALiBi checkpoints.
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )

    if model_type == "mistral":
        sliding_window = config_dict.get("sliding_window", -1)
        # A sliding window needs flash-attn v2 on CUDA. Without one, any flash
        # attention build will do. Otherwise, fall through to the generic path.
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "mixtral":
        sliding_window = config_dict.get("sliding_window", -1)
        if (
            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
        ) or HAS_FLASH_ATTN_V2_CUDA:
            return FlashMixtral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        return T5Sharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    # Generic transformers fallback: no sharding and only a subset of
    # quantization schemes are supported here.
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
        raise NotImplementedError("Eetq quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    # Custom architectures advertise their auto classes through ``auto_map``.
    # Only honor it when remote code is explicitly trusted.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map:
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")
|