Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
commit 51348713d6
parent 4ff509948a

More import sanitation.
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
@@ -18,8 +18,6 @@ from text_generation_server.models.galactica import GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
-from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
-from text_generation_server.models.idefics import IDEFICSSharded

 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -42,7 +40,6 @@ __all__ = [
     "OPTSharded",
     "T5Sharded",
     "get_model",
-    "IDEFICSSharded",
 ]

 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
@@ -57,6 +54,7 @@ try:
     from text_generation_server.models.flash_santacoder import (
         FlashSantacoderSharded,
     )
+    from text_generation_server.models.idefics import IDEFICSSharded

 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
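The hunk above completes a guarded optional import: every model that needs the Flash Attention stack is imported inside a single try block, and one except handler downgrades the server gracefully instead of crashing at import time. A minimal sketch of the overall shape, assuming how the FLASH_ATTENTION flag is initialized (only the import lines are taken verbatim from this diff):

# Sketch of the optional-import guard in models/__init__.py.
# NOTE: where FLASH_ATTENTION is first set is assumed, not shown in this diff.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
    # Moved into the guard by this commit: IDEFICS needs the flash stack.
    from text_generation_server.models.idefics import IDEFICSSharded
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False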
@@ -67,6 +65,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashRWSharded)
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
+    __all__.append(IDEFICSSharded)


 def get_model(
@@ -252,13 +251,16 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
     elif model_type == "idefics":
-        return IDEFICSSharded(
-            model_id,
-            revision,
-            quantize=quantize,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
+        if FLASH_ATTENTION:
+            return IDEFICSSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
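With this get_model change, requesting an Idefics model on a build where the flash imports failed now raises a clear NotImplementedError instead of a NameError on the never-imported IDEFICSSharded. A hypothetical caller's view (get_model's exact signature is assumed here; the keyword names are taken from the hunk above, and the model id is illustrative):

# Hypothetical usage sketch, not from this diff.
from text_generation_server.models import get_model

try:
    model = get_model(
        model_id="HuggingFaceM4/idefics-9b",  # illustrative model id
        revision=None,
        sharded=False,
        quantize=None,
        dtype=None,
        trust_remote_code=False,
    )
except NotImplementedError as err:
    # On a non-flash build: "Idefics requires Flash Attention enabled models."
    print(err)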
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
@@ -9,13 +9,13 @@ from transformers import (
     AutoProcessor,
 )

-from text_generation_server.models import IdeficsCausalLM
 from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
 from text_generation_server.models.custom_modeling.idefics_processing import IdeficsProcessor
 from transformers import LlamaTokenizerFast
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
 )
+from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
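The idefics.py hunk swaps a package-level import for a direct module import. `from text_generation_server.models import IdeficsCausalLM` re-enters the package's __init__.py, which itself imports the idefics modules, so either file could fail with an ImportError partway through initialization (a circular import); importing from the module that defines the class breaks the cycle. A generic sketch of the pattern (pkg, a, b are illustrative names, not from this repo):

# pkg/__init__.py
from pkg.a import A
from pkg.b import B

# pkg/b.py -- fragile: this re-executes pkg/__init__.py, which may not have
# finished running when pkg.b is being imported (circular import).
# from pkg import A

# pkg/b.py -- robust: import straight from the defining module.
from pkg.a import A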