mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
pre-compute
This commit is contained in:
parent
cdc70f4c23
commit
19a04f22dd
@ -247,7 +247,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
self.swap_dims = True
|
self.swap_dims = True
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
|
self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
|
||||||
):
|
):
|
||||||
if not self.swap_dims:
|
if not self.swap_dims:
|
||||||
self._swap_dims()
|
self._swap_dims()
|
||||||
@ -256,7 +256,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||||
qkv_rot = self.rotary_emb(qkv, position_ids, max_s)
|
qkv_rot = self.rotary_emb(qkv, position_ids, max_s)
|
||||||
|
|
||||||
if prefill:
|
if layer_past_present_indices is None:
|
||||||
layer_past[...] = qkv_rot[:, 1:]
|
layer_past[...] = qkv_rot[:, 1:]
|
||||||
|
|
||||||
attn_output = torch.empty_like(qkv[:, 0])
|
attn_output = torch.empty_like(qkv[:, 0])
|
||||||
@ -279,7 +279,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
query = qkv_rot[:, 0]
|
query = qkv_rot[:, 0]
|
||||||
layer_past[cu_seqlens[1:] - 1] = qkv_rot[:, 1:]
|
layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
|
||||||
|
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
flash_attn_cuda.fwd(
|
flash_attn_cuda.fwd(
|
||||||
@ -287,9 +287,9 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
layer_past[:, 0],
|
layer_past[:, 0],
|
||||||
layer_past[:, 1],
|
layer_past[:, 1],
|
||||||
attn_output,
|
attn_output,
|
||||||
torch.arange(len(cu_seqlens), dtype=torch.int32).to(query.device),
|
cu_seqlens_q,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
torch.tensor(1, dtype=torch.int32).to(query.device),
|
1,
|
||||||
max_s,
|
max_s,
|
||||||
0.0,
|
0.0,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -376,7 +376,8 @@ class FlashNeoXLayer(nn.Module):
|
|||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
layer_past,
|
layer_past,
|
||||||
prefill,
|
layer_past_present_indices,
|
||||||
|
cu_seqlens_q,
|
||||||
):
|
):
|
||||||
if self.use_parallel_residual:
|
if self.use_parallel_residual:
|
||||||
ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||||
@ -398,7 +399,7 @@ class FlashNeoXLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
attn_output = self.attention(
|
attn_output = self.attention(
|
||||||
ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
|
ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
|
||||||
)
|
)
|
||||||
|
|
||||||
ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||||
@ -441,7 +442,7 @@ class FlashNeoXLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.attention(
|
hidden_states = self.attention(
|
||||||
hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
|
hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||||
@ -528,7 +529,6 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||||||
):
|
):
|
||||||
hidden_states = self.embed_in(input_ids)
|
hidden_states = self.embed_in(input_ids)
|
||||||
|
|
||||||
prefill = False
|
|
||||||
if past_key_values is None:
|
if past_key_values is None:
|
||||||
past_key_values = hidden_states.new_empty(
|
past_key_values = hidden_states.new_empty(
|
||||||
(
|
(
|
||||||
@ -539,7 +539,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||||||
self.head_size,
|
self.head_size,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
prefill = True
|
layer_past_present_indices = None
|
||||||
|
cu_seqlens_q = None
|
||||||
|
else:
|
||||||
|
layer_past_present_indices = cu_seqlens[1:] - 1
|
||||||
|
cu_seqlens_q = torch.arange(len(cu_seqlens), dtype=torch.int32, device=hidden_states.device)
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
@ -550,7 +554,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
past_key_values[i],
|
past_key_values[i],
|
||||||
prefill,
|
layer_past_present_indices,
|
||||||
|
cu_seqlens_q
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.final_layer_norm(hidden_states)
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
|
Loading…
Reference in New Issue
Block a user