pre-compute

OlivierDehaene 2023-03-23 13:10:31 +01:00
parent cdc70f4c23
commit 19a04f22dd

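In short: the diff replaces the per-layer `prefill` boolean with two tensors computed once per forward pass and shared by all layers: `layer_past_present_indices` (the cache slot that receives each sequence's new key/value pair during decode) and `cu_seqlens_q` (the cumulative query lengths when each sequence contributes exactly one query token). A minimal sketch of the pattern, using a hypothetical helper name but the same logic as the diff:

import torch

def precompute_decode_tensors(cu_seqlens: torch.Tensor, past_key_values):
    # Build the decode-phase index tensors once, instead of once per layer.
    if past_key_values is None:
        # Prefill: the full prompt is processed and layers fill the cache directly.
        return None, None
    # Decode: each sequence appends one token, stored at the last slot of its
    # region in the packed key/value cache.
    layer_past_present_indices = cu_seqlens[1:] - 1
    # One query token per sequence, so the query offsets are 0..batch_size.
    cu_seqlens_q = torch.arange(
        len(cu_seqlens), dtype=torch.int32, device=cu_seqlens.device
    )
    return layer_past_present_indices, cu_seqlens_q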

@@ -18,13 +18,13 @@ from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_
 class TensorParallelColumnLinear(nn.Linear):
     def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        bias=True,
-        device=None,
-        dtype=None,
+            self,
+            in_features,
+            out_features,
+            process_group: torch.distributed.ProcessGroup,
+            bias=True,
+            device=None,
+            dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -49,13 +49,13 @@ class TensorParallelColumnLinear(nn.Linear):
 class TensorParallelRowLinear(nn.Linear):
     def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        bias=True,
-        device=None,
-        dtype=None,
+            self,
+            in_features,
+            out_features,
+            process_group: torch.distributed.ProcessGroup,
+            bias=True,
+            device=None,
+            dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -83,18 +83,18 @@ class TensorParallelRowLinear(nn.Linear):
 class TensorParallelEmbedding(nn.Embedding):
     def __init__(
-        self,
-        num_embeddings,
-        embedding_dim,
-        process_group: torch.distributed.ProcessGroup,
-        padding_idx=None,
-        max_norm=None,
-        norm_type=2.0,
-        scale_grad_by_freq=False,
-        sparse=False,
-        _weight=None,
-        device=None,
-        dtype=None,
+            self,
+            num_embeddings,
+            embedding_dim,
+            process_group: torch.distributed.ProcessGroup,
+            padding_idx=None,
+            max_norm=None,
+            norm_type=2.0,
+            scale_grad_by_freq=False,
+            sparse=False,
+            _weight=None,
+            device=None,
+            dtype=None,
     ):
         self.process_group = process_group
         self.tp_rank = process_group.rank()
@@ -125,7 +125,7 @@ class TensorParallelEmbedding(nn.Embedding):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # Sanity check
         if torch.any(
-            torch.logical_or(0 > input, input >= self.original_num_embeddings)
+                torch.logical_or(0 > input, input >= self.original_num_embeddings)
         ):
             raise IndexError(
                 f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
@@ -148,9 +148,9 @@ class PositionRotaryEmbedding(RotaryEmbedding):
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
         ):
             self._seq_len_cached = seqlen
             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
@@ -162,11 +162,11 @@ class PositionRotaryEmbedding(RotaryEmbedding):
             self._sin_cached = torch.sin(freqs).to(dtype)
         else:
             power = (
-                torch.arange(
-                    seqlen, dtype=self.scale.dtype, device=self.scale.device
-                )
-                - seqlen // 2
-            ) / self.scale_base
+                    torch.arange(
+                        seqlen, dtype=self.scale.dtype, device=self.scale.device
+                    )
+                    - seqlen // 2
+                ) / self.scale_base
             scale = self.scale.to(device=power.device) ** power.unsqueeze(1)
             # We want the multiplication by scale to happen in fp32
             self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
@@ -187,23 +187,23 @@ class PositionRotaryEmbedding(RotaryEmbedding):
 @torch.jit.script
 def _prepare_rotary(
-    qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor
+        qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor
 ):
     cos = torch.index_select(cos, 0, position_ids)
     sin = torch.index_select(sin, 0, position_ids)
     rotary_dim = cos.shape[-1]

     q1 = qkv[:, 0, :, :rotary_dim]
-    q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
+    q2 = qkv[:, 0, :, rotary_dim: 2 * rotary_dim]
     k1 = qkv[:, 1, :, :rotary_dim]
-    k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
+    k2 = qkv[:, 1, :, rotary_dim: 2 * rotary_dim]

     return q1, q2, k1, k2, cos.unsqueeze(1), sin.unsqueeze(1)


 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
-        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
+            self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -247,7 +247,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.swap_dims = True

     def forward(
-        self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
+            self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
     ):
         if not self.swap_dims:
             self._swap_dims()
@@ -256,7 +256,7 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
         qkv_rot = self.rotary_emb(qkv, position_ids, max_s)

-        if prefill:
+        if layer_past_present_indices is None:
             layer_past[...] = qkv_rot[:, 1:]

             attn_output = torch.empty_like(qkv[:, 0])
@@ -279,7 +279,7 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         else:
             query = qkv_rot[:, 0]
-            layer_past[cu_seqlens[1:] - 1] = qkv_rot[:, 1:]
+            layer_past[layer_past_present_indices] = qkv_rot[:, 1:]

             attn_output = torch.empty_like(query)
             flash_attn_cuda.fwd(
@@ -287,9 +287,9 @@ class FlashNeoxAttention(torch.nn.Module):
                 layer_past[:, 0],
                 layer_past[:, 1],
                 attn_output,
-                torch.arange(len(cu_seqlens), dtype=torch.int32).to(query.device),
+                cu_seqlens_q,
                 cu_seqlens,
-                torch.tensor(1, dtype=torch.int32).to(query.device),
+                1,
                 max_s,
                 0.0,
                 self.softmax_scale,
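Before this hunk, every attention layer re-built `cu_seqlens_q` with `torch.arange(...).to(query.device)` and materialized the constant max query length as `torch.tensor(1, ...).to(query.device)` on each decode step, paying two allocations plus host-to-device traffic per layer. A small illustration of what is hoisted (hypothetical values):

import torch

cu_seqlens = torch.tensor([0, 2, 5, 6], dtype=torch.int32)  # packed batch of 3
device = "cuda" if torch.cuda.is_available() else "cpu"

# Old per-layer cost: two fresh device tensors on every attention call.
cu_seqlens_q = torch.arange(len(cu_seqlens), dtype=torch.int32).to(device)
max_seqlen_q = torch.tensor(1, dtype=torch.int32).to(device)

# After the change, cu_seqlens_q is built once in the model forward and shared
# by all layers, and the max query length is passed as the plain Python int 1.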
@@ -348,16 +348,16 @@ class FlashMLP(nn.Module):
 class FlashNeoXLayer(nn.Module):
     def __init__(
-        self,
-        num_heads,
-        act,
-        hidden_size,
-        intermediate_size,
-        rotary_pct,
-        rotary_emb_base,
-        layer_norm_eps,
-        use_parallel_residual,
-        process_group=None,
+            self,
+            num_heads,
+            act,
+            hidden_size,
+            intermediate_size,
+            rotary_pct,
+            rotary_emb_base,
+            layer_norm_eps,
+            use_parallel_residual,
+            process_group=None,
     ):
         super().__init__()
         self.use_parallel_residual = use_parallel_residual
@@ -369,14 +369,15 @@ class FlashNeoXLayer(nn.Module):
         self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)

     def forward(
-        self,
-        hidden_states,
-        residual,
-        position_ids,
-        cu_seqlens,
-        max_s,
-        layer_past,
-        prefill,
+            self,
+            hidden_states,
+            residual,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            layer_past,
+            layer_past_present_indices,
+            cu_seqlens_q,
     ):
         if self.use_parallel_residual:
             ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@@ -398,7 +399,7 @@ class FlashNeoXLayer(nn.Module):
             )

             attn_output = self.attention(
-                ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
+                ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
             )

             ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@@ -441,7 +442,7 @@ class FlashNeoXLayer(nn.Module):
             )

             hidden_states = self.attention(
-                hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
+                hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
             )

             hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@@ -519,16 +520,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.num_heads = self.layers[0].attention.num_heads

     def forward(
-        self,
-        input_ids,
-        position_ids,
-        cu_seqlens,
-        max_s,
-        past_key_values=None,
+            self,
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values=None,
     ):
         hidden_states = self.embed_in(input_ids)

-        prefill = False
         if past_key_values is None:
             past_key_values = hidden_states.new_empty(
                 (
@@ -539,7 +539,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                     self.head_size,
                 )
             )
-            prefill = True
+            layer_past_present_indices = None
+            cu_seqlens_q = None
+        else:
+            layer_past_present_indices = cu_seqlens[1:] - 1
+            cu_seqlens_q = torch.arange(len(cu_seqlens), dtype=torch.int32, device=hidden_states.device)

         residual = None
         for i, layer in enumerate(self.layers):
@@ -550,7 +554,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 cu_seqlens,
                 max_s,
                 past_key_values[i],
-                prefill,
+                layer_past_present_indices,
+                cu_seqlens_q
             )

         hidden_states = self.final_layer_norm(hidden_states)
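For concreteness, a toy demonstration of where `layer_past_present_indices` points during decode (hypothetical shapes and sequence lengths):

import torch

# Toy packed cache: 6 token slots x 2 (key/value) x 1 head x head_size 4.
layer_past = torch.zeros(6, 2, 1, 4)

# Packed batch of 3 sequences with lengths 2, 3, and 1.
cu_seqlens = torch.tensor([0, 2, 5, 6], dtype=torch.int32)

# Last slot of each sequence's region: tensor([1, 4, 5]).
# (.long() only for portability of the toy across torch versions.)
layer_past_present_indices = (cu_seqlens[1:] - 1).long()

# One new key/value pair per sequence is scattered into those slots, which is
# what layer_past[layer_past_present_indices] = qkv_rot[:, 1:] does in the diff.
new_kv = torch.randn(3, 2, 1, 4)
layer_past[layer_past_present_indices] = new_kv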
@@ -581,12 +586,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         )

     def forward(
-        self,
-        input_ids,
-        position_ids,
-        cu_seqlens,
-        max_s,
-        past_key_values=None,
+            self,
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values=None,
     ):
         hidden_states, present = self.gpt_neox(
             input_ids, position_ids, cu_seqlens, max_s, past_key_values