From 31b8cc43865a9afb00564571fefeb3d23918e595 Mon Sep 17 00:00:00 2001
From: Felix Marty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 13 Jun 2024 07:41:46 +0000
Subject: [PATCH] debug

---
 .../models/custom_modeling/flash_gpt2_modeling.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index a09af3fe..ff25861e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -175,6 +175,7 @@ class FlashGPT2Attention(torch.nn.Module):
         prefix: str,
         config,
         weights,
+        layer_idx
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -189,6 +190,7 @@ class FlashGPT2Attention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
+        self.layer_idx = layer_idx
 
         self.query_key_value = load_qkv(
             config,
@@ -218,6 +220,7 @@ class FlashGPT2Attention(torch.nn.Module):
         slots,
         input_lengths,
         max_s,
+        step
     ):
         query, key, value = self.query_key_value(hidden_states).split(
             self.head_size * self.num_heads, dim=1
@@ -225,6 +228,11 @@ class FlashGPT2Attention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_heads, self.head_size)
         value = value.view(-1, self.num_heads, self.head_size)
+
+        if self.layer_idx < 5:
+            torch.save(query, f"query_step{step}_layer{self.layer_idx}.pt")
+            torch.save(key, f"key_step{step}_layer{self.layer_idx}.pt")
+            torch.save(value, f"value_step{step}_layer{self.layer_idx}.pt")
 
         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
@@ -256,6 +264,9 @@ class FlashGPT2Attention(torch.nn.Module):
                 input_lengths,
                 max_s,
             )
+
+        if self.layer_idx < 5:
+            torch.save(attn_output, f"flash_attn_out_step{step}_layer{self.layer_idx}.pt")
 
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
@@ -301,7 +312,7 @@ class FlashGPT2Layer(nn.Module):
     def __init__(self, prefix, config, weights, layer_idx):
         super().__init__()
         self.self_attn = FlashGPT2Attention(
-            prefix=f"{prefix}.attn", config=config, weights=weights
+            prefix=f"{prefix}.attn", config=config, weights=weights, layer_idx=layer_idx
        )
         self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
@@ -342,6 +353,7 @@ class FlashGPT2Layer(nn.Module):
             slots,
             input_lengths,
             max_s,
+            step
         )
 
         if self.layer_idx < 5:
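
For reference, a minimal sketch of how the tensors dumped by this patch could be inspected against a reference run afterwards (not part of the patch). The file-name pattern matches the torch.save calls above; the two run directories, the number of captured steps, and the use of max-abs-diff as the comparison are assumptions for illustration. The patch saves with relative paths, so the .pt files land in the server's working directory and would need to be collected into one directory per run first.

import torch

RUN_A = "run_a"      # assumed directory holding one run's *.pt dumps
RUN_B = "run_b"      # assumed directory holding the reference run's dumps
NUM_STEPS = 3        # assumed number of forward steps that were captured
NUM_LAYERS = 5       # the patch only dumps layers 0..4

for step in range(NUM_STEPS):
    for layer in range(NUM_LAYERS):
        for name in ("query", "key", "value", "flash_attn_out"):
            fname = f"{name}_step{step}_layer{layer}.pt"
            a = torch.load(f"{RUN_A}/{fname}", map_location="cpu")
            b = torch.load(f"{RUN_B}/{fname}", map_location="cpu")
            # Compare in float32 to avoid dtype mismatches between runs.
            max_diff = (a.float() - b.float()).abs().max().item()
            print(f"step {step} layer {layer} {name:>14}: max abs diff {max_diff:.3e}")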