diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py
index ddae0a96..6ca7b567 100644
--- a/backends/gaudi/server/text_generation_server/models/__init__.py
+++ b/backends/gaudi/server/text_generation_server/models/__init__.py
@@ -38,6 +38,7 @@ __all__ = [
 ]
 
 from text_generation_server.models.globals import ATTENTION
+VLM_BATCH_TYPES = set()
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 FLASH_ATTENTION = False
@@ -155,6 +156,9 @@ if FLASH_ATTENTION:
 }
 
 
+__all__.append("VLM_BATCH_TYPES")
+
+
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
         "type": "deepseek_v2",