mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Test on mistral with CausalLM
This commit is contained in:
parent
3c02262f29
commit
f1004498d0
@@ -44,7 +44,7 @@ __all__ = [
|
|||||||
|
|
||||||
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
|
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
|
||||||
|
|
||||||
FLASH_ATTENTION = True
|
FLASH_ATTENTION = False
|
||||||
try:
|
try:
|
||||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
from text_generation_server.models.flash_rw import FlashRWSharded
|
||||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
||||||
@@ -248,15 +248,31 @@ def get_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_type == "mistral":
|
if model_type == "mistral":
|
||||||
if MISTRAL:
|
# if MISTRAL:
|
||||||
return FlashMistral(
|
# return FlashMistral(
|
||||||
|
# model_id,
|
||||||
|
# revision,
|
||||||
|
# quantize=quantize,
|
||||||
|
# dtype=dtype,
|
||||||
|
# trust_remote_code=trust_remote_code,
|
||||||
|
# )
|
||||||
|
# raise NotImplementedError("Mistral model requires flash attention v2")
|
||||||
|
if FLASH_ATTENTION:
|
||||||
|
return FlashLlama(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return CausalLM(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
raise NotImplementedError("Mistral model requires flash attention v2")
|
|
||||||
|
|
||||||
if model_type == "opt":
|
if model_type == "opt":
|
||||||
return OPTSharded(
|
return OPTSharded(
|
||||||
|
Loading…
Reference in New Issue
Block a user