Refactor model instantiation for Mistral type

Modified the instantiation logic for the 'mistral' model type to support the V100 GPU architecture. If FLASH_ATTENTION is not available, the code falls back to the generic CausalLM, ensuring functionality regardless of the underlying hardware. This works around the incompatibility between the flash-attention package and V100 GPUs.
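
For illustration, a minimal sketch of the dispatch this change aims at, not the exact code in text_generation_server/models/__init__.py: the try/except import guard, the simplified get_model signature, and the CausalLM fallback branch are assumptions based on the commit message (the fallback itself is not visible in the hunks below); only the FlashMistral constructor arguments are taken from the diff.

    from text_generation_server.models.causal_lm import CausalLM

    FLASH_ATTENTION = True
    try:
        # Fails when the flash-attention kernels are missing or the GPU
        # (e.g. V100) is not supported by them.
        from text_generation_server.models.flash_mistral import FlashMistral
    except ImportError:
        FLASH_ATTENTION = False


    def get_model(model_id, revision=None, quantize=None, dtype=None, trust_remote_code=False):
        # In the real code, model_type is resolved from the model's config.
        model_type = "mistral"

        if model_type == "mistral":
            if FLASH_ATTENTION:
                # Flash-attention path, used on supported GPUs.
                return FlashMistral(
                    model_id,
                    revision,
                    quantize=quantize,
                    dtype=dtype,
                    trust_remote_code=trust_remote_code,
                )
            # V100 fallback: generic CausalLM, no flash-attention required.
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )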
xihajun 2023-11-23 17:45:38 +00:00 committed by GitHub
parent 2b9b04008b
commit 60679fa297


@@ -44,7 +44,7 @@ __all__ = [
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
-FLASH_ATTENTION = False
+FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -248,17 +248,8 @@ def get_model(
             )
     if model_type == "mistral":
-        # if MISTRAL:
-        #     return FlashMistral(
-        #         model_id,
-        #         revision,
-        #         quantize=quantize,
-        #         dtype=dtype,
-        #         trust_remote_code=trust_remote_code,
-        #     )
-        # raise NotImplementedError("Mistral model requires flash attention v2")
         if FLASH_ATTENTION:
-            return FlashLlama(
+            return FlashMistral(
                 model_id,
                 revision,
                 quantize=quantize,