fix: remove debug logs

This commit is contained in:
drbh 2024-01-19 00:12:58 +00:00
parent 43441cad42
commit 5db645a19a

View File

@ -26,11 +26,11 @@ class PhiConfig(PretrainedConfig):
hidden_size=2560, hidden_size=2560,
num_hidden_layers=32, num_hidden_layers=32,
num_attention_heads=32, num_attention_heads=32,
num_key_value_heads=None, num_key_value_heads=32,
hidden_act="gelu_fast", hidden_act="gelu_fast",
max_position_embeddings=2048, max_position_embeddings=2048,
initializer_range=0.02, initializer_range=0.02,
rms_norm_eps=1e-6, layer_norm_eps=1e-05,
use_cache=True, use_cache=True,
pad_token_id=0, pad_token_id=0,
bos_token_id=1, bos_token_id=1,
@ -47,15 +47,10 @@ class PhiConfig(PretrainedConfig):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act self.hidden_act = hidden_act
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps self.layer_norm_eps = layer_norm_eps
self.pretraining_tp = pretraining_tp self.pretraining_tp = pretraining_tp
self.use_cache = use_cache self.use_cache = use_cache
self.rope_scaling = rope_scaling self.rope_scaling = rope_scaling
@ -181,7 +176,6 @@ class FlashPhiAttention(torch.nn.Module):
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
# shape = torch.Size([4096, 7680])
query, kv = qkv.split( query, kv = qkv.split(
[ [
@ -190,8 +184,6 @@ class FlashPhiAttention(torch.nn.Module):
], ],
dim=1, dim=1,
) )
# query = torch.Size([4096, 2560])
# kv = torch.Size([4096, 5120])
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
@ -206,9 +198,6 @@ class FlashPhiAttention(torch.nn.Module):
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
print("🧢 flash attention")
print("cu_seqlen_prefill", cu_seqlen_prefill.shape)
# flash attention
flash_attn.attention( flash_attn.attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
@ -220,7 +209,6 @@ class FlashPhiAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
print("📗 paged attention")
paged_attention.attention( paged_attention.attention(
attn_output, attn_output,
query, query,
@ -233,10 +221,6 @@ class FlashPhiAttention(torch.nn.Module):
max_s, max_s,
) )
# TODO: remove this - only used to summarize attention weights
# get sum of the attention weights
my_sum = torch.sum(attn_output, dim=2)
print("my_sum", my_sum)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -270,7 +254,6 @@ class PhiMLP(nn.Module):
) )
def forward(self, hidden_states): def forward(self, hidden_states):
print("FORWARD MLP")
gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = self.gate_up_proj(hidden_states)
post_act = self.act(gate_up_states) post_act = self.act(gate_up_states)
return self.down_proj(post_act) return self.down_proj(post_act)
@ -284,9 +267,8 @@ class FlashPhiLayer(nn.Module):
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastLayerNorm.load(
self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
) )
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop) self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
@ -303,11 +285,7 @@ class FlashPhiLayer(nn.Module):
input_lengths, input_lengths,
max_s, max_s,
): ):
print("💧 FORWARD LAYER")
print("\tinput0", hidden_states[0][1])
hidden_states, res = self.input_layernorm(hidden_states, residual) hidden_states, res = self.input_layernorm(hidden_states, residual)
print("\tnormalized shape", hidden_states.shape)
# Self Attention # Self Attention
attn_output = self.self_attn( attn_output = self.self_attn(
hidden_states, hidden_states,
@ -358,7 +336,7 @@ class FlashPhiModel(torch.nn.Module):
self.ln = FastLayerNorm.load( self.ln = FastLayerNorm.load(
prefix="model.final_layernorm", prefix="model.final_layernorm",
weights=weights, weights=weights,
eps=config.rms_norm_eps, eps=config.layer_norm_eps,
) )
def forward( def forward(