2023-06-08 12:51:52 +00:00
|
|
|
import os
|
2023-01-20 11:24:39 +00:00
|
|
|
import torch
|
|
|
|
|
2023-03-24 13:02:14 +00:00
|
|
|
from loguru import logger
|
2023-06-01 10:07:41 +00:00
|
|
|
from transformers.configuration_utils import PretrainedConfig
|
2023-03-27 07:23:22 +00:00
|
|
|
from transformers.models.auto import modeling_auto
|
2023-01-31 17:53:56 +00:00
|
|
|
from typing import Optional
|
|
|
|
|
2023-03-07 17:52:22 +00:00
|
|
|
from text_generation_server.models.model import Model
|
|
|
|
from text_generation_server.models.causal_lm import CausalLM
|
2023-04-03 17:06:42 +00:00
|
|
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
2023-06-08 12:51:52 +00:00
|
|
|
from text_generation_server.models.bloom import BLOOMSharded
|
2023-03-07 17:52:22 +00:00
|
|
|
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
2023-05-30 16:25:19 +00:00
|
|
|
from text_generation_server.models.rw import RW
|
2023-06-08 12:51:52 +00:00
|
|
|
from text_generation_server.models.opt import OPTSharded
|
|
|
|
from text_generation_server.models.galactica import GalacticaSharded
|
2023-03-07 17:52:22 +00:00
|
|
|
from text_generation_server.models.santacoder import SantaCoder
|
|
|
|
from text_generation_server.models.t5 import T5Sharded
|
2023-06-08 12:51:52 +00:00
|
|
|
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
2023-01-20 11:24:39 +00:00
|
|
|
|
2023-03-24 13:02:14 +00:00
|
|
|
try:
    # Flash Attention is attempted whenever CUDA is visible, unless the user
    # opts out with USE_FLASH_ATTENTION=false (compared case-insensitively).
    if (
        torch.cuda.is_available()
        and os.getenv("USE_FLASH_ATTENTION", "").lower() != "false"
    ):
        major, minor = torch.cuda.get_device_capability()

        # Compute capabilities accepted here: exactly 7.5, any 8.x, exactly 9.0.
        is_sm75 = (major, minor) == (7, 5)
        is_sm8x = major == 8 and minor >= 0
        is_sm90 = (major, minor) == (9, 0)
        supported = any((is_sm75, is_sm8x, is_sm90))

        if not supported:
            # Raised as ImportError on purpose so that the handler below logs
            # it and falls back to the non-flash implementations.
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            )

        # Any of these imports may also raise ImportError (e.g. missing kernels).
        from text_generation_server.models.flash_rw import FlashRWSharded
        from text_generation_server.models.flash_neox import FlashNeoXSharded
        from text_generation_server.models.flash_llama import FlashLlama
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoderSharded,
        )

        FLASH_ATTENTION = True
    else:
        FLASH_ATTENTION = False
except ImportError:
    # Log the full traceback at warning level, then continue without flash models.
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False
|
2023-03-24 13:02:14 +00:00
|
|
|
|
2023-01-20 11:24:39 +00:00
|
|
|
# Public API of this package. Every entry must be a *string* naming a
# module-level attribute: `from ... import *` raises TypeError on any
# non-str item.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

if FLASH_ATTENTION:
    # These names are only bound when the Flash Attention imports succeeded.
    # They are appended as strings; appending the class objects themselves
    # would make star-imports of this package fail.
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
        ]
    )
|
|
|
|
|
2023-04-11 17:16:41 +00:00
|
|
|
# Error-message template used when a model variant needs the Flash Attention
# CUDA kernels but they are unavailable; `{}` receives the variant's name
# (e.g. "Sharded Llama").
FLASH_ATT_ERROR_MESSAGE = "".join(
    (
        "{} requires Flash Attention CUDA kernels to be installed.\n",
        "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) ",
        "or install flash attention with `cd server && make install install-flash-attention`",
    )
)
|
2023-03-24 13:02:14 +00:00
|
|
|
|
2023-01-20 11:24:39 +00:00
|
|
|
# Permit TensorFloat-32 for CUDA matmuls. PyTorch 1.12 and later default this
# to False, so it is switched on explicitly.
torch.backends.cuda.matmul.allow_tf32 = True

# Permit TensorFloat-32 in cuDNN. PyTorch already defaults this to True; it is
# set here so the intent is explicit.
torch.backends.cudnn.allow_tf32 = True

# Turn off gradient tracking. Torch grad mode is thread-local, so this affects
# the thread that imports this module.
torch.set_grad_enabled(mode=False)
|
|
|
|
|
2022-10-28 17:24:00 +00:00
|
|
|
|
2023-01-31 17:53:56 +00:00
|
|
|
def _get_santacoder(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Build a SantaCoder / GPT-BigCode model, preferring the Flash Attention variant.

    Shared by the ``bigcode/`` model-id shortcut and the ``gpt_bigcode``
    model type, which previously duplicated this logic.

    Raises:
        NotImplementedError: if ``sharded`` is requested but Flash Attention
            is unavailable (the non-flash ``SantaCoder`` path is not used for
            sharding).
    """
    if FLASH_ATTENTION:
        return FlashSantacoderSharded(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )
    if sharded:
        raise NotImplementedError(
            FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
        )
    return SantaCoder(
        model_id,
        revision,
        quantize=quantize,
        trust_remote_code=trust_remote_code,
    )


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the serving ``Model`` implementation that matches ``model_id``.

    Dispatch order:
      1. model-id shortcuts (``facebook/galactica`` substring, ``bigcode/`` prefix);
      2. the ``model_type`` field of the model's config (bloom, gpt_neox, llama,
         RefinedWeb/RefinedWebModel, opt, t5, gpt_bigcode);
      3. generic transformers auto mappings (causal LM, then seq2seq), only
         when not sharded and not using ``gptq``;
      4. the config's ``auto_map`` entries, only when ``trust_remote_code``.

    Args:
        model_id: Model identifier passed through to the model classes and to
            ``PretrainedConfig.get_config_dict``.
        revision: Optional revision passed through unchanged.
        sharded: Whether a sharded implementation is required.
        quantize: Quantization mode passed through unchanged; ``"gptq"`` is
            rejected for the generic auto-model fallback.
        trust_remote_code: Forwarded to config loading and model classes; also
            gates the ``auto_map`` fallback.

    Returns:
        An instance of one of the ``Model`` subclasses imported by this module.

    Raises:
        NotImplementedError: a sharded variant needs Flash Attention that is
            unavailable, or a RefinedWeb config does not support sharding.
        ValueError: sharding or ``gptq`` requested for the generic auto-model
            fallback, or the model type is not supported.
    """
    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    if model_id.startswith("bigcode/"):
        return _get_santacoder(
            model_id, revision, sharded, quantize, trust_remote_code
        )

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return _get_santacoder(
            model_id, revision, sharded, quantize, trust_remote_code
        )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        if sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "llama":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    if model_type in ["RefinedWeb", "RefinedWebModel"]:
        if sharded:
            if FLASH_ATTENTION:
                # Sharding is refused for alibi configs, and for RefinedWebModel
                # configs whose "multi_query" is true (or missing).
                if config_dict.get("alibi", False) or (
                    model_type == "RefinedWebModel"
                    and config_dict.get("multi_query", True)
                ):
                    raise NotImplementedError("sharded is not supported for this model")
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded RefinedWeb")
            )
        # Non-sharded: the flash path is used only for non-alibi configs.
        if FLASH_ATTENTION and not config_dict.get("alibi", False):
            return FlashRWSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        return RW(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "opt":
        return OPTSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    if model_type == "t5":
        if sharded:
            return T5Sharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    # Generic fallbacks below support neither sharding nor GPTQ.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    # Remote-code models may advertise their auto classes via "auto_map".
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map:
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")
|