commit 8f1de30b0f
parent b3e9a13e27
Author: Felix Marty
Date:   2024-06-13 07:31:11 +00:00


@@ -298,7 +298,7 @@ class GPT2MLP(nn.Module):


 class FlashGPT2Layer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_idx):
         super().__init__()
         self.self_attn = FlashGPT2Attention(
             prefix=f"{prefix}.attn", config=config, weights=weights
@@ -313,6 +313,7 @@ class FlashGPT2Layer(nn.Module):
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
+        self.layer_idx = layer_idx

     def forward(
         self,
@@ -324,10 +325,14 @@ class FlashGPT2Layer(nn.Module):
         slots,
         input_lengths,
         max_s,
+        step,
     ):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

+        if self.layer_idx < 5:
+            torch.save(hidden_states, f"hidden_states_bef_attn_step{step}_layer{self.layer_idx}.pt")
+
         # Self Attention
         attn_output = self.self_attn(
             hidden_states,
@@ -339,6 +344,9 @@ class FlashGPT2Layer(nn.Module):
             max_s,
         )

+        if self.layer_idx < 5:
+            torch.save(attn_output, f"attn_output_step{step}_layer{self.layer_idx}.pt")
+
         hidden_states = attn_output + residual
         residual = hidden_states
@@ -346,6 +354,9 @@ class FlashGPT2Layer(nn.Module):
         mlp_output = self.mlp(hidden_states)

+        if self.layer_idx < 5:
+            torch.save(mlp_output, f"mlp_output_step{step}_layer{self.layer_idx}.pt")
+
         return residual + mlp_output, residual
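The three dumps above share one naming scheme, `<name>_step{step}_layer{layer_idx}.pt`, restricted to the first five layers. A minimal sketch of a loader that gathers those files back into a per-step, per-layer dict; the function name, directory argument, and pathlib-based parsing are illustrative assumptions, not part of this commit:

    from pathlib import Path

    import torch

    def load_debug_dumps(dump_dir=".", max_layers=5):
        # Gather the files written above into dumps[step][layer][name] = tensor,
        # e.g. dumps[0][3]["attn_output"]. Names are parsed from the filename
        # pattern used by the torch.save calls in the diff.
        dumps = {}
        for path in Path(dump_dir).glob("*_step*_layer*.pt"):
            stem = path.stem                      # e.g. "attn_output_step0_layer3"
            rest, _, layer = stem.rpartition("_layer")
            name, _, step = rest.rpartition("_step")
            if int(layer) >= max_layers:
                continue
            dumps.setdefault(int(step), {}).setdefault(int(layer), {})[name] = (
                torch.load(path, map_location="cpu")
            )
        return dumps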
@@ -364,6 +375,7 @@ class FlashGPT2Model(torch.nn.Module):
                 ),
                 config=config,
                 weights=weights,
+                layer_idx=layer_id
             )
             for layer_id in range(config.num_hidden_layers)
         ]
@@ -379,6 +391,7 @@ class FlashGPT2Model(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
+        self.step = 0

     def forward(
         self,
@@ -406,7 +419,9 @@ class FlashGPT2Model(torch.nn.Module):
                 slots,
                 input_lengths,
                 max_s,
+                self.step,
             )
+        self.step += 1

         hidden_states = self.norm(hidden_states)
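Since self.step is bumped once per FlashGPT2Model.forward call, dumps from successive forward passes land in distinct files, so the same tensor can be diffed across two runs (say, a reference eager implementation and this flash-attention one). A hedged sketch of such a comparison; the directory names, tolerance, and compare_dump helper are hypothetical, not part of this commit:

    import torch

    def compare_dump(file_a, file_b, atol=1e-3):
        # Load the same dump from two runs and report the largest absolute deviation.
        a = torch.load(file_a, map_location="cpu").float()
        b = torch.load(file_b, map_location="cpu").float()
        max_diff = (a - b).abs().max().item()
        print(f"{file_a} vs {file_b}: max abs diff = {max_diff:.3e}")
        return max_diff <= atol

    # Hypothetical usage: pre-attention hidden states of layer 0 at step 0,
    # dumped by two runs into "ref/" and "flash/" respectively.
    # compare_dump("ref/hidden_states_bef_attn_step0_layer0.pt",
    #              "flash/hidden_states_bef_attn_step0_layer0.pt")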