Update server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

This commit is contained in:
fxmarty 2023-11-08 19:07:58 +09:00 committed by GitHub
parent 891fe74099
commit 0eea83be44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -43,7 +43,7 @@ if IS_CUDA_SYSTEM:
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
if not HAS_FLASH_ATTN_V2_ROCM and not HAS_FLASH_ATTN_V2_ROCM:
if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
raise ImportError("Mistral model requires flash attn v2")