revert some changes

OlivierDehaene 2023-06-05 18:54:23 +02:00
parent afdfe43346
commit 92a74ea036
2 changed files with 24 additions and 28 deletions


@@ -140,25 +140,22 @@ class FlashLlamaAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
-        query, kv = qkv.split([1, 2], dim=1)
-        query = query.view(-1, self.num_heads, self.head_size)
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(qkv[:, 0])
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 start_seq,
                 end_seq,
@@ -176,16 +173,17 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 start_seq_q,
                 end_seq_q,
@@ -386,7 +384,7 @@ class FlashLlamaModel(torch.nn.Module):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
             )
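
In both files the revert swaps the split/torch.select views for direct indexing into qkv. A minimal standalone sketch (tensor sizes are assumed for illustration, not taken from either model) of why the two spellings address the same storage, so the in-place rotary update is seen through either alias:

import torch

num_tokens, num_heads, head_size = 4, 8, 16
qkv = torch.randn(num_tokens, 3, num_heads, head_size)

# Spelling removed by this commit: split into query / key-value views.
query, kv = qkv.split([1, 2], dim=1)
query = query.view(-1, num_heads, head_size)
key = torch.select(kv, dim=1, index=0)
value = torch.select(kv, dim=1, index=1)

# Spelling the commit reverts to: plain indexing along dim=1.
assert torch.equal(query, qkv[:, 0])
assert torch.equal(key, qkv[:, 1])
assert torch.equal(value, qkv[:, 2])

# All of these are views of the same storage, so an in-place op (like the
# rotary embedding in the diff) through one alias is visible through the others.
qkv[:, 0].mul_(2.0)
assert torch.equal(query, qkv[:, 0])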


@@ -125,25 +125,22 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
-        query, kv = qkv.split([1, 2], dim=1)
-        query = query.view(-1, self.num_heads, self.head_size)
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(qkv[:, 0])
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 start_seq,
                 end_seq,
@@ -161,16 +158,17 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 start_seq_q,
                 end_seq_q,
@@ -395,7 +393,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
             )
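
The decode branches above keep the same layer_past layout in both models: token slots along dim 0 and a size-2 dim holding key and value. A minimal sketch (cache size, batch size, and the index values are made up for illustration) of the write-then-read pattern the diff preserves, where the present key/value pair is scattered into the cache at past_present_indices and attention then reads layer_past[:, 0] and layer_past[:, 1]:

import torch

num_heads, head_size = 8, 16
cache_slots = 10  # hypothetical total number of cached token slots
layer_past = torch.zeros(cache_slots, 2, num_heads, head_size)

# One new token per sequence in a batch of 2; qkv[:, 1:] is its key/value pair.
qkv = torch.randn(2, 3, num_heads, head_size)
past_present_indices = torch.tensor([3, 7])  # hypothetical slots for this step

# Write the present key/value into the cache (as in the diff).
layer_past[past_present_indices] = qkv[:, 1:]

# Read the full key and value caches back for attention.
key_cache = layer_past[:, 0]    # (cache_slots, num_heads, head_size)
value_cache = layer_past[:, 1]  # same shape, fed to flash_attn_cuda.fwd

assert torch.equal(key_cache[3], qkv[0, 1])
assert torch.equal(value_cache[7], qkv[1, 2])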