From 8f1de30b0ffc03c49ea4f2765fe68ba740675f85 Mon Sep 17 00:00:00 2001
From: Felix Marty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 13 Jun 2024 07:31:11 +0000
Subject: [PATCH] debug

---
 .../custom_modeling/flash_gpt2_modeling.py   | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index 0c01f56a..a09af3fe 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -298,7 +298,7 @@ class GPT2MLP(nn.Module):
 
 class FlashGPT2Layer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_idx):
         super().__init__()
         self.self_attn = FlashGPT2Attention(
             prefix=f"{prefix}.attn", config=config, weights=weights
         )
@@ -313,6 +313,7 @@ class FlashGPT2Layer(nn.Module):
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
+        self.layer_idx = layer_idx
 
     def forward(
         self,
@@ -324,10 +325,14 @@ class FlashGPT2Layer(nn.Module):
         slots,
         input_lengths,
         max_s,
+        step,
     ):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
+        if self.layer_idx < 5:
+            torch.save(hidden_states, f"hidden_states_bef_attn_step{step}_layer{self.layer_idx}.pt")
+
         # Self Attention
         attn_output = self.self_attn(
             hidden_states,
@@ -339,6 +344,9 @@ class FlashGPT2Layer(nn.Module):
             max_s,
         )
 
+        if self.layer_idx < 5:
+            torch.save(attn_output, f"attn_output_step{step}_layer{self.layer_idx}.pt")
+
         hidden_states = attn_output + residual
         residual = hidden_states
 
@@ -346,6 +354,9 @@ class FlashGPT2Layer(nn.Module):
 
         mlp_output = self.mlp(hidden_states)
 
+        if self.layer_idx < 5:
+            torch.save(mlp_output, f"mlp_output_step{step}_layer{self.layer_idx}.pt")
+
         return residual + mlp_output, residual
 
 
@@ -364,6 +375,7 @@ class FlashGPT2Model(torch.nn.Module):
                 ),
                 config=config,
                 weights=weights,
+                layer_idx=layer_id
             )
             for layer_id in range(config.num_hidden_layers)
         ]
@@ -379,6 +391,7 @@ class FlashGPT2Model(torch.nn.Module):
 
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
+        self.step = 0
 
     def forward(
         self,
@@ -406,7 +419,9 @@ class FlashGPT2Model(torch.nn.Module):
                 slots,
                 input_lengths,
                 max_s,
+                self.step,
             )
+        self.step += 1
 
         hidden_states = self.norm(hidden_states)
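
Note: the hunks above dump the first five layers' pre-attention hidden states, attention outputs and MLP outputs to .pt files, keyed by a per-forward step counter and the layer index. A minimal sketch of how such dumps could be compared between two runs follows; the directory names, step count, tolerance and the compare_dumps helper are assumptions for illustration, not part of the patch (the patch writes the files into the current working directory).

import torch

# Tensor kinds written by the debug patch, per step and per layer (layer_idx < 5).
KINDS = ["hidden_states_bef_attn", "attn_output", "mlp_output"]

def compare_dumps(ref_dir, test_dir, steps, layers=5, atol=1e-3):
    # Hypothetical helper: load the .pt files saved by the patch from two runs and
    # report the largest absolute difference per (kind, step, layer).
    for step in range(steps):
        for layer in range(layers):
            for kind in KINDS:
                name = f"{kind}_step{step}_layer{layer}.pt"
                ref = torch.load(f"{ref_dir}/{name}", map_location="cpu")
                test = torch.load(f"{test_dir}/{name}", map_location="cpu")
                max_diff = (ref.float() - test.float()).abs().max().item()
                flag = "" if max_diff <= atol else "  <-- diverges"
                print(f"step {step} layer {layer} {kind}: max abs diff {max_diff:.3e}{flag}")

if __name__ == "__main__":
    # Example: compare a reference run against the run under investigation (assumed folder names).
    compare_dumps("run_reference", "run_debug", steps=3)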