This commit is contained in:
OlivierDehaene 2023-03-23 13:33:32 +01:00
parent 19a04f22dd
commit e5e22993e7
2 changed files with 122 additions and 99 deletions

View File

@ -338,7 +338,6 @@ class FlashNeoX(Model):
# Create final next batch tensors # Create final next batch tensors
device = out.device device = out.device
next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
next_batch_position_ids = torch.tensor( next_batch_position_ids = torch.tensor(
next_batch_position_ids, dtype=torch.int32, device=device next_batch_position_ids, dtype=torch.int32, device=device
) )
@ -346,8 +345,10 @@ class FlashNeoX(Model):
next_batch_cu_seqlens, dtype=torch.int32, device=device next_batch_cu_seqlens, dtype=torch.int32, device=device
) )
if len(next_batch_keep_indices) > 1: if len(next_batch_keep_indices) > 1:
next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
next_batch_past_key_values = torch.concat(next_batch_past_key_values) next_batch_past_key_values = torch.concat(next_batch_past_key_values)
else: else:
next_batch_input_ids = next_batch_input_ids[0]
next_batch_past_key_values = next_batch_past_key_values[0] next_batch_past_key_values = next_batch_past_key_values[0]
next_batch = FlashNeoXBatch( next_batch = FlashNeoXBatch(

View File

@ -174,31 +174,23 @@ class PositionRotaryEmbedding(RotaryEmbedding):
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor, max_s: int): def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
self._update_cos_sin_cache(qkv.dtype, qkv.device, max_s) self._update_cos_sin_cache(dtype, position_ids.device, max_s)
q1, q2, k1, k2, cos, sin = _prepare_rotary( cos = torch.index_select(self._cos_cached, 0, position_ids)
qkv, self._cos_cached, self._sin_cached, position_ids sin = torch.index_select(self._sin_cached, 0, position_ids)
) return cos.unsqueeze(1), sin.unsqueeze(1)
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
return qkv
@torch.jit.script
def _prepare_rotary(
qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor
):
cos = torch.index_select(cos, 0, position_ids)
sin = torch.index_select(sin, 0, position_ids)
def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1] rotary_dim = cos.shape[-1]
q1 = qkv[:, 0, :, :rotary_dim] q1 = qkv[:, 0, :, :rotary_dim]
q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
k1 = qkv[:, 1, :, :rotary_dim] k1 = qkv[:, 1, :, :rotary_dim]
k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
return q1, q2, k1, k2, cos.unsqueeze(1), sin.unsqueeze(1) rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
return qkv
class FlashNeoxAttention(torch.nn.Module): class FlashNeoxAttention(torch.nn.Module):
@ -229,7 +221,7 @@ class FlashNeoxAttention(torch.nn.Module):
hidden_size, hidden_size,
process_group=process_group, process_group=process_group,
) )
self.swap_dims = False self.swap_dims = True
def _swap_dims(self): def _swap_dims(self):
self.query_key_value.weight = torch.nn.Parameter( self.query_key_value.weight = torch.nn.Parameter(
@ -244,17 +236,25 @@ class FlashNeoxAttention(torch.nn.Module):
.permute(1, 0, 2) .permute(1, 0, 2)
.reshape(-1) .reshape(-1)
) )
self.swap_dims = True self.swap_dims = False
def forward( def forward(
self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q self,
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
): ):
if not self.swap_dims: if self.swap_dims:
self._swap_dims() self._swap_dims()
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
qkv_rot = self.rotary_emb(qkv, position_ids, max_s) qkv_rot = self.rotary_emb(qkv, cos, sin)
if layer_past_present_indices is None: if layer_past_present_indices is None:
layer_past[...] = qkv_rot[:, 1:] layer_past[...] = qkv_rot[:, 1:]
@ -372,7 +372,8 @@ class FlashNeoXLayer(nn.Module):
self, self,
hidden_states, hidden_states,
residual, residual,
position_ids, cos,
sin,
cu_seqlens, cu_seqlens,
max_s, max_s,
layer_past, layer_past,
@ -399,7 +400,14 @@ class FlashNeoXLayer(nn.Module):
) )
attn_output = self.attention( attn_output = self.attention(
ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q ln1_hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
) )
ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@ -442,7 +450,14 @@ class FlashNeoXLayer(nn.Module):
) )
hidden_states = self.attention( hidden_states = self.attention(
hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
) )
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@ -543,19 +558,26 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
cu_seqlens_q = None cu_seqlens_q = None
else: else:
layer_past_present_indices = cu_seqlens[1:] - 1 layer_past_present_indices = cu_seqlens[1:] - 1
cu_seqlens_q = torch.arange(len(cu_seqlens), dtype=torch.int32, device=hidden_states.device) cu_seqlens_q = torch.arange(
len(cu_seqlens), dtype=torch.int32, device=hidden_states.device
)
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
hidden_states, hidden_states,
residual, residual,
position_ids, cos,
sin,
cu_seqlens, cu_seqlens,
max_s, max_s,
past_key_values[i], past_key_values[i],
layer_past_present_indices, layer_past_present_indices,
cu_seqlens_q cu_seqlens_q,
) )
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)