import os
import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded

try:
    if (
        torch.cuda.is_available()
        and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
    ):
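        # Flash Attention kernels are only built for Turing (sm75), Ampere (sm8x)
        # and Hopper (sm90) GPUs; anything else falls back to the non-flash models.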
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8 and minor >= 0
        is_sm90 = major == 9 and minor == 0
        supported = is_sm75 or is_sm8x or is_sm90
        if not supported:
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            )

        from text_generation_server.models.flash_rw import FlashRWSharded
        from text_generation_server.models.flash_neox import FlashNeoXSharded
        from text_generation_server.models.flash_llama import (
            FlashLlama,
        )
        from text_generation_server.models.flash_santacoder import (
            FlashSantacoderSharded,
        )

        FLASH_ATTENTION = True
    else:
        FLASH_ATTENTION = False
except ImportError:
    logger.opt(exception=True).warning(
        "Could not import Flash Attention enabled models"
    )
    FLASH_ATTENTION = False

__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

if FLASH_ATTENTION:
    __all__.append("FlashNeoXSharded")
    __all__.append("FlashRWSharded")
    __all__.append("FlashSantacoderSharded")
    __all__.append("FlashLlama")

FLASH_ATT_ERROR_MESSAGE = (
    "{} requires Flash Attention CUDA kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    trust_remote_code: bool,
) -> Model:
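    """Select and instantiate the Model implementation that best matches `model_id`.

    Flash Attention variants are preferred when their kernels imported successfully;
    otherwise the generic CausalLM / Seq2SeqLM implementations are used as fallbacks.
    """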
    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    if model_id.startswith("bigcode/"):
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
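
    # Load the raw config dict; dispatch below is keyed on its `model_type`.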
    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        if FLASH_ATTENTION:
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        else:
            return SantaCoder(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )
    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return FlashNeoXSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            return GPTNeoxSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    elif model_type == "llama":
        if FLASH_ATTENTION:
            return FlashLlama(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel"]:
        if sharded:
            if FLASH_ATTENTION:
                if config_dict.get("alibi", False) or (
                    model_type == "RefinedWebModel"
                    and config_dict.get("multi_query", True)
                ):
                    raise NotImplementedError(
                        "sharded is not supported for this model"
                    )
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded RefinedWeb")
            )
        else:
            if FLASH_ATTENTION and not config_dict.get("alibi", False):
                return FlashRWSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )
            else:
                return RW(
                    model_id,
                    revision,
                    quantize=quantize,
                    trust_remote_code=trust_remote_code,
                )

    elif model_type == "opt":
        return OPTSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    elif model_type == "t5":
        if sharded:
            return T5Sharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        else:
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
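
    # No dedicated implementation matched the model type: fall back to the
    # generic transformers AutoModel-based paths below.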
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize "
            "it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )
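
    # Repositories that ship custom modeling code (loaded with trust_remote_code)
    # declare their entry points in the config's `auto_map`.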
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map.keys():
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map.keys():
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")