This commit is contained in:
OlivierDehaene 2023-03-23 13:33:32 +01:00
parent 19a04f22dd
commit e5e22993e7
2 changed files with 122 additions and 99 deletions

View File

@ -338,7 +338,6 @@ class FlashNeoX(Model):
# Create final next batch tensors # Create final next batch tensors
device = out.device device = out.device
next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
next_batch_position_ids = torch.tensor( next_batch_position_ids = torch.tensor(
next_batch_position_ids, dtype=torch.int32, device=device next_batch_position_ids, dtype=torch.int32, device=device
) )
@ -346,8 +345,10 @@ class FlashNeoX(Model):
next_batch_cu_seqlens, dtype=torch.int32, device=device next_batch_cu_seqlens, dtype=torch.int32, device=device
) )
if len(next_batch_keep_indices) > 1: if len(next_batch_keep_indices) > 1:
next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
next_batch_past_key_values = torch.concat(next_batch_past_key_values) next_batch_past_key_values = torch.concat(next_batch_past_key_values)
else: else:
next_batch_input_ids = next_batch_input_ids[0]
next_batch_past_key_values = next_batch_past_key_values[0] next_batch_past_key_values = next_batch_past_key_values[0]
next_batch = FlashNeoXBatch( next_batch = FlashNeoXBatch(

View File

@ -18,13 +18,13 @@ from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_
class TensorParallelColumnLinear(nn.Linear): class TensorParallelColumnLinear(nn.Linear):
def __init__( def __init__(
self, self,
in_features, in_features,
out_features, out_features,
process_group: torch.distributed.ProcessGroup, process_group: torch.distributed.ProcessGroup,
bias=True, bias=True,
device=None, device=None,
dtype=None, dtype=None,
): ):
self.process_group = process_group self.process_group = process_group
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
@ -49,13 +49,13 @@ class TensorParallelColumnLinear(nn.Linear):
class TensorParallelRowLinear(nn.Linear): class TensorParallelRowLinear(nn.Linear):
def __init__( def __init__(
self, self,
in_features, in_features,
out_features, out_features,
process_group: torch.distributed.ProcessGroup, process_group: torch.distributed.ProcessGroup,
bias=True, bias=True,
device=None, device=None,
dtype=None, dtype=None,
): ):
self.process_group = process_group self.process_group = process_group
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
@ -83,18 +83,18 @@ class TensorParallelRowLinear(nn.Linear):
class TensorParallelEmbedding(nn.Embedding): class TensorParallelEmbedding(nn.Embedding):
def __init__( def __init__(
self, self,
num_embeddings, num_embeddings,
embedding_dim, embedding_dim,
process_group: torch.distributed.ProcessGroup, process_group: torch.distributed.ProcessGroup,
padding_idx=None, padding_idx=None,
max_norm=None, max_norm=None,
norm_type=2.0, norm_type=2.0,
scale_grad_by_freq=False, scale_grad_by_freq=False,
sparse=False, sparse=False,
_weight=None, _weight=None,
device=None, device=None,
dtype=None, dtype=None,
): ):
self.process_group = process_group self.process_group = process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
@ -125,7 +125,7 @@ class TensorParallelEmbedding(nn.Embedding):
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
# Sanity check # Sanity check
if torch.any( if torch.any(
torch.logical_or(0 > input, input >= self.original_num_embeddings) torch.logical_or(0 > input, input >= self.original_num_embeddings)
): ):
raise IndexError( raise IndexError(
f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}" f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
@ -148,9 +148,9 @@ class PositionRotaryEmbedding(RotaryEmbedding):
# Reset the tables if the sequence length has changed, # Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance) # or if we're on a new device (possibly due to tracing for instance)
if ( if (
seqlen > self._seq_len_cached seqlen > self._seq_len_cached
or self._cos_cached.device != device or self._cos_cached.device != device
or self._cos_cached.dtype != dtype or self._cos_cached.dtype != dtype
): ):
self._seq_len_cached = seqlen self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
@ -162,11 +162,11 @@ class PositionRotaryEmbedding(RotaryEmbedding):
self._sin_cached = torch.sin(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype)
else: else:
power = ( power = (
torch.arange( torch.arange(
seqlen, dtype=self.scale.dtype, device=self.scale.device seqlen, dtype=self.scale.dtype, device=self.scale.device
) )
- seqlen // 2 - seqlen // 2
) / self.scale_base ) / self.scale_base
scale = self.scale.to(device=power.device) ** power.unsqueeze(1) scale = self.scale.to(device=power.device) ** power.unsqueeze(1)
# We want the multiplication by scale to happen in fp32 # We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype) self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
@ -174,36 +174,28 @@ class PositionRotaryEmbedding(RotaryEmbedding):
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor, max_s: int): def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
self._update_cos_sin_cache(qkv.dtype, qkv.device, max_s) self._update_cos_sin_cache(dtype, position_ids.device, max_s)
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
return cos.unsqueeze(1), sin.unsqueeze(1)
def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
q1 = qkv[:, 0, :, :rotary_dim]
q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
k1 = qkv[:, 1, :, :rotary_dim]
k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
q1, q2, k1, k2, cos, sin = _prepare_rotary(
qkv, self._cos_cached, self._sin_cached, position_ids
)
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
return qkv return qkv
@torch.jit.script
def _prepare_rotary(
qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor
):
cos = torch.index_select(cos, 0, position_ids)
sin = torch.index_select(sin, 0, position_ids)
rotary_dim = cos.shape[-1]
q1 = qkv[:, 0, :, :rotary_dim]
q2 = qkv[:, 0, :, rotary_dim: 2 * rotary_dim]
k1 = qkv[:, 1, :, :rotary_dim]
k2 = qkv[:, 1, :, rotary_dim: 2 * rotary_dim]
return q1, q2, k1, k2, cos.unsqueeze(1), sin.unsqueeze(1)
class FlashNeoxAttention(torch.nn.Module): class FlashNeoxAttention(torch.nn.Module):
def __init__( def __init__(
self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
): ):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
@ -229,7 +221,7 @@ class FlashNeoxAttention(torch.nn.Module):
hidden_size, hidden_size,
process_group=process_group, process_group=process_group,
) )
self.swap_dims = False self.swap_dims = True
def _swap_dims(self): def _swap_dims(self):
self.query_key_value.weight = torch.nn.Parameter( self.query_key_value.weight = torch.nn.Parameter(
@ -244,17 +236,25 @@ class FlashNeoxAttention(torch.nn.Module):
.permute(1, 0, 2) .permute(1, 0, 2)
.reshape(-1) .reshape(-1)
) )
self.swap_dims = True self.swap_dims = False
def forward( def forward(
self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q self,
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
): ):
if not self.swap_dims: if self.swap_dims:
self._swap_dims() self._swap_dims()
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
qkv_rot = self.rotary_emb(qkv, position_ids, max_s) qkv_rot = self.rotary_emb(qkv, cos, sin)
if layer_past_present_indices is None: if layer_past_present_indices is None:
layer_past[...] = qkv_rot[:, 1:] layer_past[...] = qkv_rot[:, 1:]
@ -348,16 +348,16 @@ class FlashMLP(nn.Module):
class FlashNeoXLayer(nn.Module): class FlashNeoXLayer(nn.Module):
def __init__( def __init__(
self, self,
num_heads, num_heads,
act, act,
hidden_size, hidden_size,
intermediate_size, intermediate_size,
rotary_pct, rotary_pct,
rotary_emb_base, rotary_emb_base,
layer_norm_eps, layer_norm_eps,
use_parallel_residual, use_parallel_residual,
process_group=None, process_group=None,
): ):
super().__init__() super().__init__()
self.use_parallel_residual = use_parallel_residual self.use_parallel_residual = use_parallel_residual
@ -369,15 +369,16 @@ class FlashNeoXLayer(nn.Module):
self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group) self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
def forward( def forward(
self, self,
hidden_states, hidden_states,
residual, residual,
position_ids, cos,
cu_seqlens, sin,
max_s, cu_seqlens,
layer_past, max_s,
layer_past_present_indices, layer_past,
cu_seqlens_q, layer_past_present_indices,
cu_seqlens_q,
): ):
if self.use_parallel_residual: if self.use_parallel_residual:
ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@ -399,7 +400,14 @@ class FlashNeoXLayer(nn.Module):
) )
attn_output = self.attention( attn_output = self.attention(
ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q ln1_hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
) )
ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@ -442,7 +450,14 @@ class FlashNeoXLayer(nn.Module):
) )
hidden_states = self.attention( hidden_states = self.attention(
hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
) )
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@ -520,12 +535,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self.num_heads = self.layers[0].attention.num_heads self.num_heads = self.layers[0].attention.num_heads
def forward( def forward(
self, self,
input_ids, input_ids,
position_ids, position_ids,
cu_seqlens, cu_seqlens,
max_s, max_s,
past_key_values=None, past_key_values=None,
): ):
hidden_states = self.embed_in(input_ids) hidden_states = self.embed_in(input_ids)
@ -543,19 +558,26 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
cu_seqlens_q = None cu_seqlens_q = None
else: else:
layer_past_present_indices = cu_seqlens[1:] - 1 layer_past_present_indices = cu_seqlens[1:] - 1
cu_seqlens_q = torch.arange(len(cu_seqlens), dtype=torch.int32, device=hidden_states.device) cu_seqlens_q = torch.arange(
len(cu_seqlens), dtype=torch.int32, device=hidden_states.device
)
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
residual, residual,
position_ids, cos,
sin,
cu_seqlens, cu_seqlens,
max_s, max_s,
past_key_values[i], past_key_values[i],
layer_past_present_indices, layer_past_present_indices,
cu_seqlens_q cu_seqlens_q,
) )
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
@ -586,12 +608,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids, input_ids,
position_ids, position_ids,
cu_seqlens, cu_seqlens,
max_s, max_s,
past_key_values=None, past_key_values=None,
): ):
hidden_states, present = self.gpt_neox( hidden_states, present = self.gpt_neox(
input_ids, position_ids, cu_seqlens, max_s, past_key_values input_ids, position_ids, cu_seqlens, max_s, past_key_values