Changing return everywhere.

commit 8fa8cda660 (parent a26e57f9f3)
Nicolas Patry, 2024-07-01 12:08:59 +00:00
12 changed files with 13 additions and 13 deletions
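
The diff below applies the same one-line change to every decode-path call site: the result of paged_attention is assigned to attn_output instead of being discarded. A minimal sketch of the rationale, using a hypothetical paged_attention stand-in (the real kernel takes more arguments), assuming some attention backends return a fresh tensor rather than writing into the preallocated output buffer:

import torch

def paged_attention(out: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real kernel: this version allocates and returns a new
    # tensor instead of filling `out` in place, which is exactly the case the
    # old call sites missed.
    return query * 2.0

query = torch.ones(4)
attn_output = torch.empty(4)

# Old pattern: return value discarded; attn_output keeps whatever values
# torch.empty() happened to leave in the buffer.
paged_attention(attn_output, query)

# New pattern: the returned tensor is used, regardless of whether the
# backend wrote in place or allocated a new result.
attn_output = paged_attention(attn_output, query)
print(attn_output)  # tensor([2., 2., 2., 2.])

Capturing the return value is correct for both behaviors, so the call sites no longer depend on any one backend mutating its output argument.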

@@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -229,7 +229,7 @@ class MistralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 qkv[:, 0],
                 kv_cache[0],

@@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -301,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],

@@ -255,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],