More import sanitation.

Nicolas Patry 2023-08-15 15:09:19 +00:00
parent 4ff509948a
commit 51348713d6
2 changed files with 13 additions and 11 deletions

View File

@@ -18,8 +18,6 @@ from text_generation_server.models.galactica import GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
-from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
-from text_generation_server.models.idefics import IDEFICSSharded
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -42,7 +40,6 @@ __all__ = [
     "OPTSharded",
     "T5Sharded",
     "get_model",
-    "IDEFICSSharded",
 ]
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
@@ -57,6 +54,7 @@ try:
     from text_generation_server.models.flash_santacoder import (
         FlashSantacoderSharded,
     )
+    from text_generation_server.models.idefics import IDEFICSSharded
 
 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
@@ -67,6 +65,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashRWSharded)
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
+    __all__.append(IDEFICSSharded)
 
 
 def get_model(
@@ -252,13 +251,16 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
     elif model_type == "idefics":
-        return IDEFICSSharded(
-            model_id,
-            revision,
-            quantize=quantize,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
+        if FLASH_ATTENTION:
+            return IDEFICSSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
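
Taken together, the changes to the first file move IDEFICSSharded into the same optional-import pattern as the other Flash Attention models: it is imported inside the try/except guard, exported only when FLASH_ATTENTION is available, and get_model falls back to the shared error message otherwise. Below is a minimal sketch of that gating pattern; the FLASH_ATTENTION assignments inside the try/except and the simplified get_model signature are illustrative assumptions, not lines taken from this diff.

    # Sketch of the optional-import gating shown above (illustrative, not the
    # full models/__init__.py). Assumes FLASH_ATTENTION is set by the guarded
    # import block.
    from loguru import logger

    FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

    try:
        # Flash-Attention-backed models are imported behind a guard so that a
        # missing optional dependency does not break importing the package.
        from text_generation_server.models.idefics import IDEFICSSharded

        FLASH_ATTENTION = True  # assumption: set when the guarded imports succeed
    except ImportError as e:
        logger.warning(f"Could not import Flash Attention enabled models: {e}")
        FLASH_ATTENTION = False


    def get_model(model_id, revision=None, quantize=None, dtype=None,
                  trust_remote_code=False, model_type="idefics"):
        # Simplified: the real get_model derives model_type from the model's
        # config; only the "idefics" branch from the diff is reproduced here.
        if model_type == "idefics":
            if FLASH_ATTENTION:
                return IDEFICSSharded(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
        raise ValueError(f"Unsupported model_type: {model_type}")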

View File

@@ -9,13 +9,13 @@ from transformers import (
     AutoProcessor,
 )
 
-from text_generation_server.models import IdeficsCausalLM
 from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
 from text_generation_server.models.custom_modeling.idefics_processing import IdeficsProcessor
 from transformers import LlamaTokenizerFast
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
 )
+from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
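
In the second file, IdeficsCausalLM is now imported from its defining module rather than through the text_generation_server.models package, which matches the first hunk above removing that package-level import from __init__.py. The hypothetical stub below illustrates the direct-module import, assuming (as the imports suggest) that this file's sharded class builds on IdeficsCausalLM; the class body is trimmed and is not the real implementation.

    # Hypothetical stub: the real class also handles tokenizer/processor
    # loading, weight sharding, and generation.
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM


    class IDEFICSSharded(IdeficsCausalLM):
        # Importing IdeficsCausalLM straight from idefics_causal_lm keeps this
        # module importable regardless of what models/__init__.py re-exports.
        pass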