Test on mistral with CausalLM

xihajun 2023-11-23 16:34:06 +00:00 committed by GitHub
parent 3c02262f29
commit f1004498d0


@@ -44,7 +44,7 @@ __all__ = [
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
-FLASH_ATTENTION = True
+FLASH_ATTENTION = False
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -248,15 +248,31 @@ def get_model(
             )
     if model_type == "mistral":
-        if MISTRAL:
-            return FlashMistral(
+        # if MISTRAL:
+        #     return FlashMistral(
+        #         model_id,
+        #         revision,
+        #         quantize=quantize,
+        #         dtype=dtype,
+        #         trust_remote_code=trust_remote_code,
+        #     )
+        # raise NotImplementedError("Mistral model requires flash attention v2")
+        if FLASH_ATTENTION:
+            return FlashLlama(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            return CausalLM(
                 model_id,
                 revision,
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        raise NotImplementedError("Mistral model requires flash attention v2")
     if model_type == "opt":
         return OPTSharded(
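
In effect, with FLASH_ATTENTION forced to False in the first hunk, the Mistral branch of get_model now falls through to the generic CausalLM wrapper rather than FlashMistral (and would use FlashLlama when flash attention is available). Below is a minimal sketch of what that fallback amounts to; the model id and dtype are illustrative assumptions, not taken from the commit, while the constructor arguments mirror the call shown in the diff:

import torch
from text_generation_server.models import CausalLM

# What the new else-branch returns for model_type == "mistral" when flash
# attention is unavailable: the generic (non-flash) CausalLM implementation.
model = CausalLM(
    "mistralai/Mistral-7B-v0.1",  # illustrative model id, not from the commit
    None,                         # revision
    quantize=None,
    dtype=torch.float16,
    trust_remote_code=False,
)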