import torch

from loguru import logger
from transformers import AutoConfig
from transformers.models.auto import modeling_auto
from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.opt import OPT, OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded
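
# The Flash Attention models are optional: they require a CUDA GPU with a
# supported compute capability (SM 7.5, 8.x or 9.0) and the custom kernels to
# be installed. If either condition is not met, FLASH_ATTENTION is set to
# False and the regular (non-flash) implementations are used instead.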
try:
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8 and minor >= 0
        is_sm90 = major == 9 and minor == 0

        supported = is_sm75 or is_sm8x or is_sm90
        if not supported:
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            )

        from text_generation_server.models.flash_neox import (
            FlashNeoX,
            FlashNeoXSharded,
        )
        from text_generation_server.models.flash_llama import (
            FlashLlama,
            FlashLlamaSharded,
        )
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoder,
            FlashSantacoderSharded,
        )

        FLASH_ATTENTION = True
    else:
        FLASH_ATTENTION = False
except ImportError:
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False

__all__ = [
    "Model",
    "BLOOM",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "Galactica",
    "GalacticaSharded",
    "GPTNeoxSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPT",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

if FLASH_ATTENTION:
    # __all__ entries must be strings, not the classes themselves, for
    # `from ... import *` to work.
    __all__.append("FlashNeoX")
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashSantacoder")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")
    __all__.append("FlashLlamaSharded")
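
# Error message raised when a sharded model needs the Flash Attention kernels
# but they are not available in this environment.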
FLASH_ATT_ERROR_MESSAGE = (
    "{} requires Flash Attention CUDA kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)

# The flag below controls whether to allow TF32 on matmul. This flag defaults
# to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    trust_remote_code: bool,
) -> Model:
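    """Instantiate the most appropriate Model subclass for `model_id`.

    Routing is based first on the model id itself (Galactica and `bigcode/`
    checkpoints), then on the `model_type` reported by the Hugging Face
    config, and finally falls back to the generic CausalLM / Seq2SeqLM
    wrappers around AutoModel. Flash Attention variants are preferred
    whenever the kernels are available.
    """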
if "facebook/galactica" in model_id:
|
2023-02-14 12:02:16 +00:00
|
|
|
if sharded:
|
2023-05-23 18:40:39 +00:00
|
|
|
return GalacticaSharded(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2023-02-14 12:02:16 +00:00
|
|
|
else:
|
2023-05-23 18:40:39 +00:00
|
|
|
return Galactica(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2023-02-14 12:02:16 +00:00
|
|
|
|
2023-05-15 08:35:20 +00:00
|
|
|
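    # Santacoder-style checkpoints under the `bigcode` organisation: use the
    # flash implementation when available, otherwise fall back to SantaCoder
    # (sharding requires the flash kernels).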
    if model_id.startswith("bigcode/"):
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
                )
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
            return santacoder_cls(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config.model_type

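    # `gpt_bigcode` is the config-level model type of Santacoder/StarCoder
    # checkpoints; it is routed exactly like the `bigcode/` prefix above.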
if model_type == "gpt_bigcode":
|
|
|
|
if sharded:
|
|
|
|
if not FLASH_ATTENTION:
|
|
|
|
raise NotImplementedError(
|
|
|
|
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
|
|
|
|
)
|
2023-05-23 18:40:39 +00:00
|
|
|
return FlashSantacoderSharded(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2023-05-15 08:35:20 +00:00
|
|
|
else:
|
|
|
|
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
|
2023-05-23 18:40:39 +00:00
|
|
|
return santacoder_cls(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2023-05-15 08:35:20 +00:00
|
|
|
|
2023-03-27 07:23:22 +00:00
|
|
|
if model_type == "bloom":
|
2023-01-31 17:53:56 +00:00
|
|
|
if sharded:
|
2023-05-23 18:40:39 +00:00
|
|
|
return BLOOMSharded(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2023-01-31 17:53:56 +00:00
|
|
|
else:
|
2023-05-23 18:40:39 +00:00
|
|
|
return BLOOM(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2023-02-14 12:02:16 +00:00
|
|
|
|
2023-03-27 07:23:22 +00:00
|
|
|
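    # GPT-NeoX has both flash and non-flash implementations; the flash
    # variants are used whenever the kernels could be imported.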
if model_type == "gpt_neox":
|
2022-10-28 17:24:00 +00:00
|
|
|
if sharded:
|
2023-04-03 17:06:42 +00:00
|
|
|
neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
|
2023-05-23 18:40:39 +00:00
|
|
|
return neox_cls(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2022-10-28 17:24:00 +00:00
|
|
|
else:
|
2023-04-03 17:06:42 +00:00
|
|
|
neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
|
2023-05-23 18:40:39 +00:00
|
|
|
return neox_cls(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2023-02-14 12:02:16 +00:00
|
|
|
|
2023-04-11 14:38:22 +00:00
|
|
|
if model_type == "llama":
|
|
|
|
if sharded:
|
|
|
|
if FLASH_ATTENTION:
|
2023-05-23 18:40:39 +00:00
|
|
|
return FlashLlamaSharded(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2023-04-11 17:16:41 +00:00
|
|
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
|
2023-04-11 14:38:22 +00:00
|
|
|
else:
|
|
|
|
llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
|
2023-05-23 18:40:39 +00:00
|
|
|
return llama_cls(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2023-04-11 14:38:22 +00:00
|
|
|
|
2023-04-11 17:16:41 +00:00
|
|
|
    if model_type == "opt":
        if sharded:
            return OPTSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            return OPT(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

if model_type == "t5":
|
2023-02-07 17:25:17 +00:00
|
|
|
if sharded:
|
2023-05-23 18:40:39 +00:00
|
|
|
return T5Sharded(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2023-02-07 17:25:17 +00:00
|
|
|
else:
|
2023-05-23 18:40:39 +00:00
|
|
|
return Seq2SeqLM(
|
|
|
|
model_id,
|
|
|
|
revision,
|
|
|
|
quantize=quantize,
|
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
|
)
|
2023-02-14 12:02:16 +00:00
|
|
|
|
|
|
|
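    # No dedicated implementation matched: fall back to the generic AutoModel
    # wrappers, which do not support sharding.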
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

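    # Custom modeling code on the Hub: honour the config's `auto_map` entries
    # when the caller explicitly opted in with `trust_remote_code`.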
    auto_map = getattr(config, "auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")