More import sanitation.

This commit is contained in:
Nicolas Patry 2023-08-15 15:09:19 +00:00
parent 4ff509948a
commit 51348713d6
2 changed files with 13 additions and 11 deletions

View File

@@ -18,8 +18,6 @@ from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.models.idefics import IDEFICSSharded
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
@@ -42,7 +40,6 @@ __all__ = [
"OPTSharded",
"T5Sharded",
"get_model",
"IDEFICSSharded",
]
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
@@ -57,6 +54,7 @@ try:
from text_generation_server.models.flash_santacoder import (
FlashSantacoderSharded,
)
from text_generation_server.models.idefics import IDEFICSSharded
except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
@@ -67,6 +65,7 @@ if FLASH_ATTENTION:
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama)
__all__.append(IDEFICSSharded)
def get_model(
@@ -252,6 +251,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == "idefics":
if FLASH_ATTENTION:
return IDEFICSSharded(
model_id,
revision,
@@ -259,6 +259,8 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if sharded:
raise ValueError("sharded is not supported for AutoModel")

View File

@@ -9,13 +9,13 @@ from transformers import (
AutoProcessor,
)
from text_generation_server.models import IdeficsCausalLM
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_processing import IdeficsProcessor
from transformers import LlamaTokenizerFast
from text_generation_server.models.custom_modeling.idefics_modeling import (
IdeficsForVisionText2Text,
)
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,