commit 31b8cc4386
parent 8f1de30b0f
Author: Felix Marty
Date:   2024-06-13 07:41:46 +00:00

@@ -175,6 +175,7 @@ class FlashGPT2Attention(torch.nn.Module):
         prefix: str,
         config,
         weights,
+        layer_idx
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -189,6 +190,7 @@ class FlashGPT2Attention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
+        self.layer_idx = layer_idx

         self.query_key_value = load_qkv(
             config,
@@ -218,6 +220,7 @@ class FlashGPT2Attention(torch.nn.Module):
         slots,
         input_lengths,
         max_s,
+        step
     ):
         query, key, value = self.query_key_value(hidden_states).split(
             self.head_size * self.num_heads, dim=1
@@ -226,6 +229,11 @@ class FlashGPT2Attention(torch.nn.Module):
         key = key.view(-1, self.num_heads, self.head_size)
         value = value.view(-1, self.num_heads, self.head_size)

+        if self.layer_idx < 5:
+            torch.save(query, f"query_step{step}_layer{self.layer_idx}.pt")
+            torch.save(key, f"key_step{step}_layer{self.layer_idx}.pt")
+            torch.save(value, f"value_step{step}_layer{self.layer_idx}.pt")
+
         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)

         # output tensor
@@ -257,6 +265,9 @@ class FlashGPT2Attention(torch.nn.Module):
                 max_s,
             )

+        if self.layer_idx < 5:
+            torch.save(attn_output, f"flash_attn_out_step{step}_layer{self.layer_idx}.pt")
+
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -301,7 +312,7 @@ class FlashGPT2Layer(nn.Module):
     def __init__(self, prefix, config, weights, layer_idx):
         super().__init__()
         self.self_attn = FlashGPT2Attention(
-            prefix=f"{prefix}.attn", config=config, weights=weights
+            prefix=f"{prefix}.attn", config=config, weights=weights, layer_idx=layer_idx
         )
         self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -342,6 +353,7 @@ class FlashGPT2Layer(nn.Module):
             slots,
             input_lengths,
             max_s,
+            step
         )

         if self.layer_idx < 5:
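
The debug dumps above are written only for the first five layers, with each file keyed by decoding step and layer index. The commit itself does not include a consumer for these files, but as a minimal sketch of how they could be used, the helper below (hypothetical: the name compare_dumps, the atol threshold, and the two-directory layout are assumptions, on the further assumption that two server runs each wrote their *.pt files into their own working directory) loads matching files from a reference run and a candidate run and reports the largest absolute difference per tensor.

# compare_dumps.py -- hypothetical helper, not part of the commit.
# Assumes two runs with the torch.save() calls above, each run started
# from its own working directory so the *.pt files do not collide.
import sys
from pathlib import Path

import torch


def compare_dumps(ref_dir: str, new_dir: str, atol: float = 1e-3) -> None:
    ref_dir, new_dir = Path(ref_dir), Path(new_dir)
    for ref_file in sorted(ref_dir.glob("*.pt")):
        new_file = new_dir / ref_file.name
        if not new_file.exists():
            print(f"{ref_file.name}: missing in {new_dir}")
            continue
        # Load on CPU and upcast so fp16/bf16 dumps compare cleanly.
        ref = torch.load(ref_file, map_location="cpu").float()
        new = torch.load(new_file, map_location="cpu").float()
        if ref.shape != new.shape:
            print(f"{ref_file.name}: shape mismatch {ref.shape} vs {new.shape}")
            continue
        max_diff = (ref - new).abs().max().item()
        status = "OK" if max_diff <= atol else "DIFF"
        print(f"{ref_file.name}: max abs diff {max_diff:.3e} [{status}]")


if __name__ == "__main__":
    compare_dumps(sys.argv[1], sys.argv[2])

Checking the query/key/value dumps before the flash_attn_out dumps narrows a divergence at a given step and layer down to either the QKV projection or the attention kernel.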