Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Refactor model instantiation for Mistral type
Modified the instantiation logic for the 'mistral' model type to support V100 GPU architectures. If FLASH_ATTENTION is not available, the code falls back to the generic CausalLM, ensuring functionality regardless of the underlying hardware. This change circumvents the incompatibility issue with the flash-attention package on V100 GPUs.
This commit is contained in: parent 2b9b04008b, commit 60679fa297
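The hunk loaded below only shows the FlashMistral branch; the CausalLM fallback described in the commit message sits on the non-flash path inside get_model(). A minimal sketch of the intended dispatch, assuming the standard text_generation_server import paths — the load_mistral helper is hypothetical and used here only for illustration:

# A minimal sketch of the dispatch this commit describes, not the verbatim patch.
# "load_mistral" is a hypothetical helper; in the repository this logic lives
# inside get_model(). Import paths are assumed to match the usual
# text_generation_server layout.
from text_generation_server.models import FLASH_ATTENTION, CausalLM
from text_generation_server.models.flash_mistral import FlashMistral


def load_mistral(model_id, revision, quantize, dtype, trust_remote_code):
    if FLASH_ATTENTION:
        # flash-attention kernels imported successfully (e.g. A100/H100)
        return FlashMistral(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    # flash-attention unavailable (e.g. V100): fall back to the generic
    # CausalLM implementation so the server can still load the model
    return CausalLM(
        model_id,
        revision,
        quantize=quantize,
        dtype=dtype,
        trust_remote_code=trust_remote_code,
    )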
@@ -44,7 +44,7 @@ __all__ = [
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-FLASH_ATTENTION = False
+FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
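The FLASH_ATTENTION flag toggled in this hunk guards all of the flash model classes. A simplified sketch of the surrounding try/except, assuming the flag is cleared when any flash import fails (the real module imports additional flash classes and reports the failure through its logger):

# Simplified sketch of the guard around the flash imports; the actual module
# pulls in more flash classes and logs the error instead of swallowing it.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_rw import FlashRWSharded
    from text_generation_server.models.flash_neox import FlashNeoXSharded
except ImportError:
    # On hardware such as V100 the flash-attention package cannot be imported,
    # so the flash model classes are unavailable and the flag is cleared;
    # get_model() then falls back to non-flash implementations.
    FLASH_ATTENTION = False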
@@ -248,17 +248,8 @@ def get_model(
         )
 
     if model_type == "mistral":
-        # if MISTRAL:
-        #     return FlashMistral(
-        #         model_id,
-        #         revision,
-        #         quantize=quantize,
-        #         dtype=dtype,
-        #         trust_remote_code=trust_remote_code,
-        #     )
-        # raise NotImplementedError("Mistral model requires flash attention v2")
         if FLASH_ATTENTION:
-            return FlashLlama(
+            return FlashMistral(
                 model_id,
                 revision,
                 quantize=quantize,