From 60679fa2970b9352832ff7de8cbe49882a664306 Mon Sep 17 00:00:00 2001
From: xihajun
Date: Thu, 23 Nov 2023 17:45:38 +0000
Subject: [PATCH] Refactor model instantiation for Mistral type

Modified the instantiation logic for the 'mistral' model type to support
V100 GPU architectures. If FLASH_ATTENTION is not available, the code falls
back to the generic CausalLM, ensuring functionality regardless of the
underlying hardware. This change circumvents the incompatibility issue with
the flash-attention package on V100 GPUs.
---
 server/text_generation_server/models/__init__.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 57d26a42..5e6e8856 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -44,7 +44,7 @@ __all__ = [
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-FLASH_ATTENTION = False
+FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -248,17 +248,8 @@ def get_model(
         )
 
     if model_type == "mistral":
-        # if MISTRAL:
-        #     return FlashMistral(
-        #         model_id,
-        #         revision,
-        #         quantize=quantize,
-        #         dtype=dtype,
-        #         trust_remote_code=trust_remote_code,
-        #     )
-        #     raise NotImplementedError("Mistral model requires flash attention v2")
         if FLASH_ATTENTION:
-            return FlashLlama(
+            return FlashMistral(
                 model_id,
                 revision,
                 quantize=quantize,
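
Note: the hunk above only shows the flash-attention path; the CausalLM fallback
described in the commit message is not visible in the diff context. A minimal
sketch of the resulting dispatch in get_model, assuming the fallback mirrors the
other model types (the dtype and trust_remote_code arguments and the CausalLM
call are assumptions based on the commit message, not lines from the diff):

    if model_type == "mistral":
        if FLASH_ATTENTION:
            # Flash-attention kernels are available (e.g. A100/H100):
            # use the flash-attention Mistral implementation.
            return FlashMistral(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        # No flash-attention support (e.g. V100): fall back to the generic
        # CausalLM implementation so the server still starts.
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )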