Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 11:54:52 +00:00

Commit 92a74ea036 (parent afdfe43346): revert some changes
@@ -140,25 +140,22 @@ class FlashLlamaAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
-        query, kv = qkv.split([1, 2], dim=1)
-        query = query.view(-1, self.num_heads, self.head_size)
-
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(qkv[:, 0])
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 start_seq,
                 end_seq,
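
Note on the hunk above: `qkv[:, 0]` and the views returned by `qkv.split([1, 2], dim=1)` address the same storage, so the in-place rotary works identically through either form; the revert only changes how the views are spelled. A minimal sketch with toy shapes (illustrative only, not the TGI module):

import torch

num_tokens, num_heads, head_size = 4, 2, 8  # toy shapes
qkv = torch.randn(num_tokens, 3, num_heads, head_size)

# Both forms are zero-copy views into the packed qkv projection.
query_a = qkv[:, 0]
query_b, kv = qkv.split([1, 2], dim=1)
query_b = query_b.view(-1, num_heads, head_size)
assert query_a.data_ptr() == query_b.data_ptr() == qkv.data_ptr()

# An in-place op through either view mutates the packed tensor,
# which is what makes the "Inplace rotary" on qkv[:, 0] work.
query_a.zero_()
assert torch.equal(query_b, torch.zeros_like(query_b))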
@@ -176,16 +173,17 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 start_seq_q,
                 end_seq_q,
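
The decode branch writes the single new key/value per sequence into the preallocated cache at `past_present_indices`, then attends the one-token query against the whole cache. A toy scatter-then-attend illustration (dummy shapes; a real kernel would also mask each sequence to its own filled slots):

import torch

cache_slots, num_heads, head_size = 8, 2, 4
layer_past = torch.zeros(cache_slots, 2, num_heads, head_size)  # (slots, k/v, heads, dim)

# One decode step: a single new token's packed qkv projection.
qkv = torch.randn(1, 3, num_heads, head_size)
past_present_indices = torch.tensor([5])  # cache slot for this token

# Scatter the fresh key/value into the cache, as in the diff.
layer_past[past_present_indices] = qkv[:, 1:]

# Attend the single query against everything cached so far.
query = qkv[:, 0]                                # (1, heads, dim)
keys, values = layer_past[:, 0], layer_past[:, 1]
scores = torch.einsum("qhd,khd->hqk", query, keys) / head_size**0.5
attn_output = torch.einsum("hqk,khd->qhd", scores.softmax(-1), values)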
@@ -386,7 +384,7 @@ class FlashLlamaModel(torch.nn.Module):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
             )
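
The model-level change is cosmetic: `torch.select(t, dim=1, index=i)` and `t[:, i]` return the same zero-copy view, so indexing the per-layer cache this way copies nothing. A quick check:

import torch

past_key_values = torch.randn(16, 4, 2, 2, 8)  # toy (slots, layers, k/v, heads, dim)
i = 1
a = torch.select(past_key_values, dim=1, index=i)
b = past_key_values[:, i]
assert torch.equal(a, b) and a.data_ptr() == b.data_ptr()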
@@ -125,25 +125,22 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
-        query, kv = qkv.split([1, 2], dim=1)
-        query = query.view(-1, self.num_heads, self.head_size)
-
         # Inplace rotary
-        self.rotary_emb(query, cos, sin)
-        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(query)
+            attn_output = torch.empty_like(qkv[:, 0])
             # flash attention
             flash_attn_cuda.fwd(
-                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 start_seq,
                 end_seq,
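
As in the Llama hunk, the rotary embedding mutates the query view `qkv[:, 0]` and the key view `qkv[:, 1]` in place. A sketch of one common rotate-half rotary formulation, applied in place the same way (an assumption for illustration; not necessarily the exact kernel `self.rotary_emb` dispatches to):

import torch

def apply_rotary_(x, cos, sin):
    # In-place rotate-half update; x: (tokens, heads, head_size).
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half].clone(), x[..., half:].clone()
    x[..., :half] = x1 * cos - x2 * sin
    x[..., half:] = x2 * cos + x1 * sin

tokens, heads, head_size = 4, 2, 8
qkv = torch.randn(tokens, 3, heads, head_size)

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_size, 2).float() / head_size))
freqs = torch.arange(tokens).float()[:, None] * inv_freq[None, :]  # (tokens, half)
cos, sin = freqs.cos()[:, None, :], freqs.sin()[:, None, :]        # broadcast over heads

# Mirrors the diff: rotate the query and key views of the packed tensor in place.
apply_rotary_(qkv[:, 0], cos, sin)
apply_rotary_(qkv[:, 1], cos, sin)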
@@ -161,16 +158,17 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[past_present_indices] = kv
+            layer_past[past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                layer_past[:, 0],
+                layer_past[:, 1],
                 attn_output,
                 start_seq_q,
                 end_seq_q,
@@ -395,7 +393,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                torch.select(past_key_values, dim=1, index=i),
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
             )
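
Across both models, prefill hands flash attention a packed batch: every sequence's tokens are concatenated along dim 0, and `start_seq`/`end_seq` mark per-sequence boundaries so attention never crosses sequences. A pure-PyTorch reference of that packed-varlen pattern with causal masking (a readable stand-in for what the fused `flash_attn_cuda.fwd` kernel computes, not its API):

import torch

def varlen_attention(query, key, value, start_seq, end_seq):
    # query/key/value: (total_tokens, heads, head_size), sequences packed on dim 0.
    attn_output = torch.empty_like(query)
    scale = query.shape[-1] ** -0.5
    for s, e in zip(start_seq.tolist(), end_seq.tolist()):
        q, k, v = query[s:e], key[s:e], value[s:e]
        scores = torch.einsum("qhd,khd->hqk", q, k) * scale
        causal = torch.ones(e - s, e - s, dtype=torch.bool).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
        attn_output[s:e] = torch.einsum("hqk,khd->qhd", scores.softmax(-1), v)
    return attn_output

# Two sequences of lengths 3 and 2, packed into 5 tokens.
qkv = torch.randn(5, 3, 2, 8)
start_seq, end_seq = torch.tensor([0, 3]), torch.tensor([3, 5])
out = varlen_attention(qkv[:, 0], qkv[:, 1], qkv[:, 2], start_seq, end_seq)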