Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
More import sanitation.
parent 4ff509948a
commit 51348713d6
@@ -18,8 +18,6 @@ from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.models.idefics import IDEFICSSharded

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
@@ -42,7 +40,6 @@ __all__ = [
    "OPTSharded",
    "T5Sharded",
    "get_model",
    "IDEFICSSharded",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
@@ -57,6 +54,7 @@ try:
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    from text_generation_server.models.idefics import IDEFICSSharded

except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
@@ -67,6 +65,7 @@ if FLASH_ATTENTION:
    __all__.append(FlashRWSharded)
    __all__.append(FlashSantacoderSharded)
    __all__.append(FlashLlama)
    __all__.append(IDEFICSSharded)


def get_model(
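The two hunks above extend a guarded-import pattern: models that need Flash Attention kernels are imported inside a try/except block, and only added to the module's exports when the import succeeds. Below is a minimal, self-contained sketch of that pattern, not the file itself; the initial contents of __all__, the FLASH_ATTENTION fallback flag, and the loguru logger are assumptions made for illustration.

from loguru import logger

__all__ = ["get_model"]

# Assume the Flash Attention models are importable until proven otherwise.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.idefics import IDEFICSSharded
except ImportError as e:
    # Mirrors the warning emitted in the hunk above when the import fails.
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # Export the sharded Idefics model only when its imports succeeded.
    __all__.append("IDEFICSSharded")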
@@ -252,13 +251,16 @@ def get_model(
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "idefics":
        return IDEFICSSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    if sharded:
        raise ValueError("sharded is not supported for AutoModel")
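For context, a hypothetical trimmed-down version of the dispatch behaviour this hunk gives get_model: the "idefics" model type resolves to IDEFICSSharded only when the Flash Attention imports succeeded, and otherwise raises the shared error message. The two-argument signature and the string return value are stand-ins for illustration, not the real API.

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

def dispatch_model_type(model_type: str, flash_attention: bool) -> str:
    # Mirrors the control flow of the hunk above with placeholder return values.
    if model_type == "idefics":
        if flash_attention:
            return "IDEFICSSharded"
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    raise ValueError(f"Unsupported model type: {model_type}")

# dispatch_model_type("idefics", flash_attention=False) raises
# NotImplementedError("Idefics requires Flash Attention enabled models.")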
@@ -9,13 +9,13 @@ from transformers import (
    AutoProcessor,
)

from text_generation_server.models import IdeficsCausalLM
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_processing import IdeficsProcessor
from transformers import LlamaTokenizerFast
from text_generation_server.models.custom_modeling.idefics_modeling import (
    IdeficsForVisionText2Text,
)
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
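If this file is the Idefics model module that the package __init__ imports back (the hunk itself does not name the file), the change reads as replacing the package-level import with a direct import from the defining module, a common way to keep imports through the package __init__ one-directional. A hedged sketch of the sanitised form, wrapped so it degrades gracefully when the package is not installed:

try:
    # Direct module import rather than going through text_generation_server.models,
    # whose __init__ may in turn import this module.
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
except ImportError:
    IdeficsCausalLM = None  # package not available in this environment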