From f1004498d0d1722ff1eff3b7b6a4604b1c830a4e Mon Sep 17 00:00:00 2001
From: xihajun
Date: Thu, 23 Nov 2023 16:34:06 +0000
Subject: [PATCH] Test on mistral with CausalLM

---
 .../text_generation_server/models/__init__.py | 24 +++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 5b1b5715..57d26a42 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -44,7 +44,7 @@ __all__ = [
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-FLASH_ATTENTION = True
+FLASH_ATTENTION = False
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -248,15 +248,31 @@ def get_model(
             )
 
     if model_type == "mistral":
-        if MISTRAL:
-            return FlashMistral(
+        # if MISTRAL:
+        #     return FlashMistral(
+        #         model_id,
+        #         revision,
+        #         quantize=quantize,
+        #         dtype=dtype,
+        #         trust_remote_code=trust_remote_code,
+        #     )
+        #     raise NotImplementedError("Mistral model requires flash attention v2")
+        if FLASH_ATTENTION:
+            return FlashLlama(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            return CausalLM(
                 model_id,
                 revision,
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        raise NotImplementedError("Mistral model requires flash attention v2")
 
     if model_type == "opt":
         return OPTSharded(
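
Note: the snippet below is a standalone, illustrative sketch of the dispatch behaviour this patch introduces; it is not part of the patched file and does not import text_generation_server. The FLASH_ATTENTION flag, the route_mistral helper, and the example model id are hypothetical stand-ins showing which loader get_model() would now select for model_type == "mistral".

# Illustrative sketch only: mirrors the new routing in get_model().
# FLASH_ATTENTION stands in for the module-level flag this patch hard-codes to False.
FLASH_ATTENTION = False

def route_mistral(model_id: str) -> str:
    """Return the name of the loader get_model() would now pick for a Mistral model."""
    if FLASH_ATTENTION:
        return "FlashLlama"  # flash-attention path reused from the Llama loader
    return "CausalLM"        # plain transformers fallback added by this patch

if __name__ == "__main__":
    print(route_mistral("mistralai/Mistral-7B-v0.1"))  # -> CausalLM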