# Standard library
import os
from typing import Optional

# Third-party
import torch
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto

# Model implementations that do not require Flash Attention
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded

# Process-wide torch settings, applied once when this module is imported.
# Allow TF32 for matmul (off by default in PyTorch 1.12 and later).
torch.backends.cuda.matmul.allow_tf32 = True

# Allow TF32 for cuDNN (already on by default; set explicitly here).
torch.backends.cudnn.allow_tf32 = True

# Gradients are switched off globally.
torch.set_grad_enabled(False)

# Names exported by `from ... import *`.
__all__ = [
    "Model",
    "BLOOMSharded",
    "CausalLM",
    "FlashCausalLM",
    "GalacticaSharded",
    "Seq2SeqLM",
    "SantaCoder",
    "OPTSharded",
    "T5Sharded",
    "get_model",
]

# Error template; `{}` is filled with the name of the unsupported variant.
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
# The Flash Attention model classes are imported as a group: if any single
# import raises ImportError, the failure is logged and FLASH_ATTENTION is set
# to False so that `get_model` never references the (possibly unbound) names.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # `__all__` must hold *names* (strings). Appending the class objects
    # themselves makes `from ... import *` fail with a TypeError.
    __all__.extend(
        [
            "FlashNeoXSharded",
            "FlashRWSharded",
            "FlashSantacoderSharded",
            "FlashLlama",
            "IDEFICSSharded",
        ]
    )


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the model implementation matching ``model_id``.

    Selection order:

    1. ``model_id`` containing ``"facebook/galactica"`` -> ``GalacticaSharded``.
    2. ``model_id`` starting with ``"bigcode/"`` -> Santacoder variants.
    3. Otherwise the ``model_type`` field of the model's config picks a
       dedicated implementation (Flash Attention variants are preferred when
       ``FLASH_ATTENTION`` is true).
    4. As a last resort, the generic ``CausalLM`` / ``Seq2SeqLM`` wrappers are
       used when the model type is known to transformers' auto mappings, or
       when ``trust_remote_code`` is set and the config's ``auto_map``
       declares a matching auto class.

    Args:
        model_id: Model identifier passed to the config loader and to the
            model constructors.
        revision: Optional revision forwarded unchanged.
        sharded: Whether a sharded model was requested.
        quantize: Optional quantization scheme name, forwarded unchanged.
        dtype: ``None`` or ``"float16"`` (both map to ``torch.float16``) or
            ``"bfloat16"`` (``torch.bfloat16``).
        trust_remote_code: Forwarded to the config loader and constructors;
            also gates the ``auto_map`` fallback.

    Returns:
        The constructed model instance.

    Raises:
        RuntimeError: If ``dtype`` is not one of the accepted values.
        NotImplementedError: If the requested variant needs Flash Attention
            (or a sharding mode) that is not available.
        ValueError: If the generic fallback cannot honour ``sharded`` or
            ``quantize``, or if the model type is unsupported.
        KeyError: If the loaded config has no ``"model_type"`` entry.
    """
    if dtype is None:
        dtype = torch.float16
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    def _load(model_cls):
        # Every implementation is constructed with the same arguments.
        return model_cls(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def _load_santacoder():
        # Shared by the "bigcode/" id prefix and the "gpt_bigcode" model type.
        if FLASH_ATTENTION:
            return _load(FlashSantacoderSharded)
        if sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
            )
        return _load(SantaCoder)

    # These two checks rely on the model id alone and run before the config
    # is fetched.
    if "facebook/galactica" in model_id:
        return _load(GalacticaSharded)

    if model_id.startswith("bigcode/"):
        return _load_santacoder()

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return _load_santacoder()

    if model_type == "bloom":
        return _load(BLOOMSharded)

    if model_type == "mpt":
        return _load(MPTSharded)

    if model_type == "gpt_neox":
        if FLASH_ATTENTION:
            return _load(FlashNeoXSharded)
        if sharded:
            return _load(GPTNeoxSharded)
        return _load(CausalLM)

    if model_type == "llama" or model_type == "baichuan":
        if FLASH_ATTENTION:
            return _load(FlashLlama)
        if sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
        return _load(CausalLM)

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        uses_alibi = config_dict.get("alibi", False)
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon")
                )
            if uses_alibi:
                raise NotImplementedError("sharded is not supported for this model")
            return _load(FlashRWSharded)
        # Non-sharded: the flash implementation is only used without alibi.
        if FLASH_ATTENTION and not uses_alibi:
            return _load(FlashRWSharded)
        return _load(RW)

    if model_type == "opt":
        return _load(OPTSharded)

    if model_type == "t5":
        return _load(T5Sharded)

    if model_type == "idefics":
        if FLASH_ATTENTION:
            return _load(IDEFICSSharded)
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    # Generic AutoModel fallback: reject options it cannot honour first.
    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
    if quantize == "gptq":
        raise ValueError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
        raise ValueError("awq quantization is not supported for AutoModel")
    if quantize in ("bitsandbytes-fp4", "bitsandbytes-nf4"):
        raise ValueError("4bit quantization is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return _load(CausalLM)
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return _load(Seq2SeqLM)

    # Models shipping custom code declare their auto classes in `auto_map`;
    # this is only consulted when the caller opted into remote code.
    auto_map = config_dict.get("auto_map", None)
    if trust_remote_code and auto_map is not None:
        if "AutoModelForCausalLM" in auto_map:
            return _load(CausalLM)
        if "AutoModelForSeq2SeqLM" in auto_map:
            return _load(Seq2SeqLM)

    raise ValueError(f"Unsupported model type {model_type}")