fix: warn window_size_left when using flash attn 1

drbh 2024-07-30 20:24:48 +00:00
parent 4b1005c7e1
commit 5123925101


@@ -3,6 +3,7 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
+import warnings
 
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -289,9 +290,11 @@ else:
         softcap=None,
     ):
         if window_size_left != -1:
-            raise NotImplementedError(
-                "window_size_left is only available with flash attn v2"
+            warnings.warn(
+                "window_size_left is only available with flash attn v2. It will be ignored.",
+                UserWarning,
             )
         if softcap is not None:
             raise NotImplementedError("softcap is only available with flash attn v2")
@@ -338,3 +341,6 @@ else:
             0,
             None,
         )
+
+
+SUPPORTS_WINDOWING = True
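
With this change, the flash attn v1 fallback emits a UserWarning and ignores window_size_left instead of raising NotImplementedError, and the module now reports SUPPORTS_WINDOWING = True. A minimal sketch of the resulting behavior, using a hypothetical attention_v1 stand-in rather than the real text_generation_server wrapper:

import warnings


def attention_v1(q, k, v, window_size_left=-1, softcap=None):
    # Hypothetical stand-in for the flash attn v1 code path touched by this diff.
    if window_size_left != -1:
        # New behavior: warn and ignore, instead of raising NotImplementedError.
        warnings.warn(
            "window_size_left is only available with flash attn v2. It will be ignored.",
            UserWarning,
        )
    if softcap is not None:
        raise NotImplementedError("softcap is only available with flash attn v2")
    return None  # the real implementation would call the v1 kernel here


# Callers that still want a hard failure can promote the warning to an error:
with warnings.catch_warnings():
    warnings.simplefilter("error", UserWarning)
    try:
        attention_v1(None, None, None, window_size_left=256)
    except UserWarning:
        print("window_size_left rejected because warnings are treated as errors")

Under the default warning filter the call simply proceeds without a sliding window, so passing window_size_left no longer blocks users on flash attn 1.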