diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 19615a60..6d53a72b 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -858,6 +858,15 @@ def get_model( lora_adapter_ids=lora_adapter_ids, config_class=GPTNeoXConfig, ) + elif FLASH_TRANSFORMERS_BACKEND: + return TransformersFlashCausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) elif sharded: return CausalLM( model_id=model_id,