Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
revert some changes

This commit is contained in:
parent afdfe43346
commit 92a74ea036
@@ -140,25 +140,22 @@ class FlashLlamaAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
-        query, kv = qkv.split([1, 2], dim=1)
-        query = query.view(-1, self.num_heads, self.head_size)
-
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(qkv[:, 0])
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 start_seq,
                 end_seq,
@@ -176,16 +173,17 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 start_seq_q,
                 end_seq_q,
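
The decode branch depends on the cache being updated before the kernel runs: `layer_past[past_present_indices] = qkv[:, 1:]` scatter-writes the new token's key/value pair into its cache slot, and attention then reads the full history from `layer_past[:, 0]` and `layer_past[:, 1]`. A rough sketch of that bookkeeping with plain tensors; the shapes and slot indexing here are hypothetical stand-ins for the repo's real cache layout:

import torch

cache_slots, num_heads, head_size = 16, 2, 8

# KV cache: one (key, value) pair per slot.
layer_past = torch.zeros(cache_slots, 2, num_heads, head_size)

# One freshly decoded token for each of two sequences,
# destined for cache slots 3 and 7.
qkv = torch.randn(2, 3, num_heads, head_size)
past_present_indices = torch.tensor([3, 7])

# Scatter the new (key, value) pairs into their slots.
layer_past[past_present_indices] = qkv[:, 1:]

# The kernel then sees keys and values for all cached tokens.
keys, values = layer_past[:, 0], layer_past[:, 1]

assert torch.equal(layer_past[3, 0], qkv[0, 1])  # key of sequence 0
assert torch.equal(layer_past[7, 1], qkv[1, 2])  # value of sequence 1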
@@ -386,7 +384,7 @@ class FlashLlamaModel(torch.nn.Module):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
             )
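
The model-level hunk is purely cosmetic: `past_key_values[:, i]` and `torch.select(past_key_values, dim=1, index=i)` return the same view, so the layer receives an identical tensor either way. A quick check with hypothetical shapes:

import torch

# (slots, layers, kv, heads, dim), sizes chosen arbitrarily
past_key_values = torch.randn(16, 4, 2, 2, 8)
i = 2

a = past_key_values[:, i]
b = torch.select(past_key_values, dim=1, index=i)
assert torch.equal(a, b)

# Both are views: a write through one is visible through the other.
a.zero_()
assert b.abs().sum() == 0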
@@ -125,25 +125,22 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
-        query, kv = qkv.split([1, 2], dim=1)
-        query = query.view(-1, self.num_heads, self.head_size)
-
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(qkv[:, 0])
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 start_seq,
                 end_seq,
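
The FlashNeoxAttention hunks mirror the Llama ones line for line. For intuition about what the prefill call computes, here is a naive PyTorch stand-in for the fused varlen kernel, assuming `start_seq`/`end_seq` mark each packed sequence's [start, end) token offsets (an assumption; the real kernel's arguments and numerics differ):

import math
import torch

def naive_varlen_attention(qkv, start_seq, end_seq):
    # Causal attention computed independently per packed sequence.
    # Illustrative only; the fused kernel avoids this Python loop.
    out = torch.empty_like(qkv[:, 0])
    head_size = qkv.shape[-1]
    scale = 1.0 / math.sqrt(head_size)
    for s, e in zip(start_seq.tolist(), end_seq.tolist()):
        q, k, v = qkv[s:e, 0], qkv[s:e, 1], qkv[s:e, 2]
        scores = torch.einsum("qhd,khd->hqk", q, k) * scale
        causal = torch.ones(e - s, e - s, dtype=torch.bool).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
        out[s:e] = torch.einsum("hqk,khd->qhd", scores.softmax(-1), v)
    return out

# Two packed sequences occupying tokens [0, 4) and [4, 6).
qkv = torch.randn(6, 3, 2, 8)
out = naive_varlen_attention(qkv, torch.tensor([0, 4]), torch.tensor([4, 6]))
assert out.shape == (6, 2, 8)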
@@ -161,16 +158,17 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 start_seq_q,
                 end_seq_q,
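
Both models also apply the rotary embedding in place on views of the fused tensor (`self.rotary_emb(qkv[:, 0], cos, sin)` and likewise for `qkv[:, 1]`). A toy in-place rotary, with a hypothetical half-split pairing that need not match the repo's kernel, showing why operating on the views is enough:

import torch

def rotary_emb_(x, cos, sin):
    # In-place RoPE: rotate (first half, second half) dimension pairs.
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d].clone(), x[..., d:].clone()
    x[..., :d] = x1 * cos - x2 * sin
    x[..., d:] = x2 * cos + x1 * sin

num_tokens, num_heads, head_size = 4, 2, 8
qkv = torch.randn(num_tokens, 3, num_heads, head_size)

# Per-position cos/sin tables, broadcast over heads.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_size, 2) / head_size))
angles = torch.arange(num_tokens)[:, None] * inv_freq[None, :]
cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]

before = qkv.clone()
rotary_emb_(qkv[:, 0], cos, sin)  # mutates qkv through the view
rotary_emb_(qkv[:, 1], cos, sin)
assert not torch.equal(qkv[:, 0], before[:, 0])
assert torch.equal(qkv[:, 2], before[:, 2])  # value is untouched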
@@ -395,7 +393,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
             )