Mirror of https://github.com/huggingface/text-generation-inference.git

commit 8f1de30b0f
parent b3e9a13e27

    debug
@@ -298,7 +298,7 @@ class GPT2MLP(nn.Module):
 
 
 class FlashGPT2Layer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_idx):
         super().__init__()
         self.self_attn = FlashGPT2Attention(
             prefix=f"{prefix}.attn", config=config, weights=weights
@@ -313,6 +313,7 @@ class FlashGPT2Layer(nn.Module):
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
+        self.layer_idx = layer_idx
 
     def forward(
         self,
@@ -324,10 +325,14 @@ class FlashGPT2Layer(nn.Module):
         slots,
         input_lengths,
         max_s,
+        step,
     ):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
+        if self.layer_idx < 5:
+            torch.save(hidden_states, f"hidden_states_bef_attn_step{step}_layer{self.layer_idx}.pt")
+
         # Self Attention
         attn_output = self.self_attn(
             hidden_states,
@@ -339,6 +344,9 @@ class FlashGPT2Layer(nn.Module):
             max_s,
         )
 
+        if self.layer_idx < 5:
+            torch.save(attn_output, f"attn_output_step{step}_layer{self.layer_idx}.pt")
+
         hidden_states = attn_output + residual
         residual = hidden_states
 
@@ -346,6 +354,9 @@ class FlashGPT2Layer(nn.Module):
 
         mlp_output = self.mlp(hidden_states)
 
+        if self.layer_idx < 5:
+            torch.save(mlp_output, f"mlp_output_step{step}_layer{self.layer_idx}.pt")
+
         return residual + mlp_output, residual
 
 
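The three torch.save calls above dump, for the first five layers only, the pre-attention hidden states, the attention output, and the MLP output of each forward pass, with the step and layer index baked into the file name. Since the paths are relative, the files land in whatever directory the server process runs in. A minimal sketch for reading the dumps back (nothing below is part of the commit):

import torch

step, layer = 0, 0  # first forward pass, first layer
hs = torch.load(f"hidden_states_bef_attn_step{step}_layer{layer}.pt")
attn = torch.load(f"attn_output_step{step}_layer{layer}.pt")
mlp = torch.load(f"mlp_output_step{step}_layer{layer}.pt")
print(hs.shape, attn.shape, mlp.shape)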
@@ -364,6 +375,7 @@ class FlashGPT2Model(torch.nn.Module):
                 ),
                 config=config,
                 weights=weights,
+                layer_idx=layer_id
             )
             for layer_id in range(config.num_hidden_layers)
         ]
@@ -379,6 +391,7 @@ class FlashGPT2Model(torch.nn.Module):
 
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
+        self.step = 0
 
     def forward(
         self,
@@ -406,7 +419,9 @@ class FlashGPT2Model(torch.nn.Module):
                 slots,
                 input_lengths,
                 max_s,
+                self.step,
             )
+            self.step += 1
 
         hidden_states = self.norm(hidden_states)
 
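self.step increments once per FlashGPT2Model.forward call, so step 0 would correspond to the prefill pass and later steps to decode passes (an assumption about the serving flow, not stated in the diff). A plausible way to use the dumps is to diff two runs layer by layer; the good/ and bad/ directories here are hypothetical, not part of the commit:

import torch

for step in range(2):          # first two forward passes
    for layer in range(5):     # the layers instrumented by the commit
        name = f"attn_output_step{step}_layer{layer}.pt"
        a = torch.load(f"good/{name}")
        b = torch.load(f"bad/{name}")
        max_diff = (a - b).abs().max().item()
        ok = torch.allclose(a, b, atol=1e-3)
        print(name, "close" if ok else f"max diff {max_diff}")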