From 51348713d6d6495192d97aeac0d4340ec595689d Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 15 Aug 2023 15:09:19 +0000
Subject: [PATCH] More import sanitation.

---
 .../text_generation_server/models/__init__.py | 22 ++++++++++---------
 .../text_generation_server/models/idefics.py  |  2 +-
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 518e1a15..932ab32e 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -18,8 +18,6 @@ from text_generation_server.models.galactica import GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
-from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
-from text_generation_server.models.idefics import IDEFICSSharded
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -42,7 +40,6 @@ __all__ = [
     "OPTSharded",
     "T5Sharded",
     "get_model",
-    "IDEFICSSharded",
 ]
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
@@ -57,6 +54,7 @@ try:
     from text_generation_server.models.flash_santacoder import (
         FlashSantacoderSharded,
     )
+    from text_generation_server.models.idefics import IDEFICSSharded
 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
 
@@ -67,6 +65,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashRWSharded)
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
+    __all__.append(IDEFICSSharded)
 
 
 def get_model(
@@ -252,13 +251,16 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
     elif model_type == "idefics":
-        return IDEFICSSharded(
-            model_id,
-            revision,
-            quantize=quantize,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
+        if FLASH_ATTENTION:
+            return IDEFICSSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index 07fea1f2..c4de21a7 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -9,13 +9,13 @@ from transformers import (
     AutoProcessor,
 )
 
-from text_generation_server.models import IdeficsCausalLM
 from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
 from text_generation_server.models.custom_modeling.idefics_processing import IdeficsProcessor
 from transformers import LlamaTokenizerFast
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
 )
+from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
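
The __init__.py side of this change applies the same guarded-import idiom already used for the other Flash Attention models: import inside the shared try block, export only when the flag is set, and fail loudly in get_model otherwise. The snippet below is a generic, self-contained sketch of that idiom, not TGI code; the optional_backend module, OptionalModel class, and HAS_BACKEND flag are hypothetical stand-ins for the flash-attention imports and the FLASH_ATTENTION flag (the real module also uses loguru rather than the stdlib logger).

import logging

logger = logging.getLogger(__name__)

__all__ = ["get_model"]

HAS_BACKEND = True
try:
    # Hypothetical optional dependency standing in for the flash-attention imports.
    from optional_backend import OptionalModel
except ImportError as e:
    # Importing the package still succeeds; the optional model is simply unavailable.
    logger.warning(f"Could not import optional backend models: {e}")
    HAS_BACKEND = False

if HAS_BACKEND:
    # Advertise the class only when its optional dependency imported cleanly.
    __all__.append("OptionalModel")


def get_model(model_type: str, model_id: str):
    if model_type == "optional":
        if HAS_BACKEND:
            return OptionalModel(model_id)
        # Without the backend, routing to this model type is an explicit error
        # rather than a NameError on a missing class.
        raise NotImplementedError(f"{model_type} requires the optional backend")
    raise ValueError(f"Unsupported model type: {model_type}")

The idefics.py hunk is the complementary half of the same cleanup: it imports IdeficsCausalLM from its concrete module (text_generation_server.models.idefics_causal_lm) instead of from the models package, so importing idefics.py no longer depends on the package-level re-export that this patch removes.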