fix: max_past default value must be -1, not 0
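
With flash-attention's convention, `window_size_left=-1` disables the
sliding-window limit entirely, while 0 is never a valid window size.
Defaulting `max_past` to 0 therefore produced an invalid window for models
without a sliding window; `attention` now also rejects 0 explicitly.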

OlivierDehaene 2023-12-15 00:10:58 +01:00
parent 9b78a6eee3
commit f75bbbcc63
3 changed files with 5 additions and 2 deletions


@@ -149,7 +149,7 @@ class MistralAttention(torch.nn.Module):
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else 0
+            config.sliding_window if config.sliding_window is not None else -1
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size


@@ -204,7 +204,7 @@ class MixtralAttention(torch.nn.Module):
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else 0
+            config.sliding_window if config.sliding_window is not None else -1
        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size


@@ -72,6 +72,9 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
+    if window_size_left == 0:
+        raise ValueError("`window_size_left` must be > 0 or -1")
+
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
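
For context, a minimal standalone sketch of the invariant this diff
establishes; `resolve_max_past` and `check_window_size_left` are hypothetical
helper names, not part of the repository, but their logic is copied from the
hunks above:

```python
# Hypothetical sketch of the default-value logic this commit fixes.
# `sliding_window=None` now maps to -1 ("no sliding window"), not 0.

def resolve_max_past(sliding_window):
    # Mirrors the MistralAttention/MixtralAttention constructors above.
    return sliding_window if sliding_window is not None else -1

def check_window_size_left(window_size_left):
    # Mirrors the new guard in `attention`: for the flash-attention kernels,
    # -1 means "no sliding window" and 0 is never a valid window size.
    if window_size_left == 0:
        raise ValueError("`window_size_left` must be > 0 or -1")
    return window_size_left

check_window_size_left(resolve_max_past(None))  # -1: window disabled, accepted
check_window_size_left(resolve_max_past(4096))  # finite sliding window, accepted
# check_window_size_left(0) would raise ValueError
```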