diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 5b1b5715..57d26a42 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -44,7 +44,7 @@ __all__ = [
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-FLASH_ATTENTION = True
+FLASH_ATTENTION = False
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -248,15 +248,31 @@ def get_model(
         )
 
     if model_type == "mistral":
-        if MISTRAL:
-            return FlashMistral(
+        # if MISTRAL:
+        #     return FlashMistral(
+        #         model_id,
+        #         revision,
+        #         quantize=quantize,
+        #         dtype=dtype,
+        #         trust_remote_code=trust_remote_code,
+        #     )
+        # raise NotImplementedError("Mistral model requires flash attention v2")
+        if FLASH_ATTENTION:
+            return FlashLlama(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            return CausalLM(
                 model_id,
                 revision,
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        raise NotImplementedError("Mistral model requires flash attention v2")
 
     if model_type == "opt":
         return OPTSharded(
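
For reviewers, a minimal runnable sketch of the `model_type == "mistral"` dispatch that results from this patch. `FlashLlama` and `CausalLM` are stubbed here so the sketch runs standalone; the real classes come from `text_generation_server.models`. Note that because the patch also hard-codes `FLASH_ATTENTION = False` at module top, the `FlashLlama` branch is effectively dead and Mistral loads through `CausalLM` (assuming nothing later in the file sets the flag back to True).

```python
# Sketch only: standalone framing of the mistral branch after this patch.
# The class bodies below are stubs, not the real TGI implementations.

class FlashLlama:
    """Stub for text_generation_server.models.flash_llama.FlashLlama."""
    def __init__(self, model_id, revision=None, quantize=None, dtype=None,
                 trust_remote_code=False):
        self.model_id = model_id

class CausalLM:
    """Stub for the generic non-flash CausalLM fallback."""
    def __init__(self, model_id, revision=None, quantize=None, dtype=None,
                 trust_remote_code=False):
        self.model_id = model_id

# Hard-coded off by this patch; unless something later in the module
# re-enables it, the mistral branch always falls through to CausalLM.
FLASH_ATTENTION = False

def get_model(model_type, model_id, revision=None, quantize=None, dtype=None,
              trust_remote_code=False):
    if model_type == "mistral":
        if FLASH_ATTENTION:
            # Mistral is routed through the Llama flash-attention path
            # instead of the (now commented-out) FlashMistral class.
            return FlashLlama(model_id, revision, quantize=quantize,
                              dtype=dtype, trust_remote_code=trust_remote_code)
        # Non-flash fallback: the generic causal-LM code path.
        return CausalLM(model_id, revision, quantize=quantize,
                        dtype=dtype, trust_remote_code=trust_remote_code)
    raise NotImplementedError(f"sketch only covers mistral, got {model_type!r}")

print(type(get_model("mistral", "mistralai/Mistral-7B-v0.1")).__name__)  # -> CausalLM
```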