revert some changes

OlivierDehaene 2023-06-05 18:54:23 +02:00
parent afdfe43346
commit 92a74ea036
2 changed files with 24 additions and 28 deletions


@@ -140,25 +140,22 @@ class FlashLlamaAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
-        query, kv = qkv.split([1, 2], dim=1)
-        query = query.view(-1, self.num_heads, self.head_size)
 
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(qkv[:, 0])
 
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 start_seq,
                 end_seq,
@@ -176,16 +173,17 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 start_seq_q,
                 end_seq_q,
@@ -386,7 +384,7 @@ class FlashLlamaModel(torch.nn.Module):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
             )
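
The pattern this revert restores relies on basic indexing returning views: qkv[:, 0] aliases the packed projection's storage, so the in-place rotary write lands in qkv itself, and qkv[:, 1:] is the key/value pair that gets copied into the cache. A minimal sketch of that aliasing, with made-up shapes and a plain in-place multiply standing in for self.rotary_emb:

import torch

num_tokens, num_heads, head_size = 4, 2, 8

# Fused QKV projection output, reshaped as in the diff: (tokens, 3, heads, dim)
qkv = torch.randn(num_tokens, 3 * num_heads * head_size)
qkv = qkv.view(-1, 3, num_heads, head_size)

# qkv[:, 0] is a view, so an in-place op on it (a stand-in for the rotary
# embedding) mutates the packed tensor directly.
query = qkv[:, 0]
query.mul_(2.0)
assert torch.equal(query, qkv[:, 0])

# The trailing slice is the key/value pair written to layer_past.
kv = qkv[:, 1:]
assert kv.shape == (num_tokens, 2, num_heads, head_size)

The reverted split([1, 2], dim=1) form also produced views over the same storage, so both variants should be functionally equivalent; the indexed form simply drops the extra reshape of query.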


@@ -125,25 +125,22 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
-        query, kv = qkv.split([1, 2], dim=1)
-        query = query.view(-1, self.num_heads, self.head_size)
 
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(qkv[:, 0])
 
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 start_seq,
                 end_seq,
@@ -161,16 +158,17 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 start_seq_q,
                 end_seq_q,
@@ -395,7 +393,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
            )
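
In the decode branches of both files, the query comes from the packed projection while keys and values are read back for every cached position. A rough sketch of that cache access, assuming layer_past is laid out as (slots, 2, heads, head_dim) (an inference from the indexing above, not something the diff states):

import torch

cache_slots, num_heads, head_size = 16, 2, 8
layer_past = torch.zeros(cache_slots, 2, num_heads, head_size)

# Write the new tokens' K/V pairs at their assigned cache slots ...
past_present_indices = torch.tensor([3, 7])
new_kv = torch.randn(2, 2, num_heads, head_size)
layer_past[past_present_indices] = new_kv

# ... then attend against the full cache, split into keys and values the
# same way the fwd call uses layer_past[:, 0] and layer_past[:, 1].
keys, values = layer_past[:, 0], layer_past[:, 1]
assert keys.shape == (cache_slots, num_heads, head_size)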