Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Changing return everywhere.
commit 8fa8cda660 (parent a26e57f9f3)
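
Every hunk below applies the same one-line change in the decode branch: the return value of paged_attention is now assigned to attn_output instead of the call being used only for its effect on the pre-allocated buffer. The following is a minimal, hypothetical sketch of that call-site pattern; the stand-in function name, its body, and its reduced argument list are assumptions for illustration, not the real kernel signature.

import torch

# Hypothetical stand-in for the paged_attention wrapper: it keeps the output
# buffer as an argument but also returns the result, so call sites switch from
# `fn(out, ...)` to `out = fn(out, ...)`.
def paged_attention_sketch(
    out: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor
) -> torch.Tensor:
    # Stand-in computation; the real kernel performs paged KV-cache attention.
    out.copy_(query + key_cache.mean())
    return out

attn_output = torch.empty(2, 4)
query = torch.randn(2, 4)
key_cache = torch.randn(2, 4)

# Before this commit: paged_attention_sketch(attn_output, query, key_cache)
# After this commit, the return value is captured:
attn_output = paged_attention_sketch(attn_output, query, key_cache)
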
@@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -229,7 +229,7 @@ class MistralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 qkv[:, 0],
                 kv_cache[0],
@@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -301,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -255,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention(
+            attn_output = paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],