fix: improve model initalization

This commit is contained in:
drbh 2024-01-18 19:36:50 +00:00
parent 77ee1f18fa
commit 43441cad42

View File

@ -24,7 +24,6 @@ class PhiConfig(PretrainedConfig):
self, self,
vocab_size=51200, vocab_size=51200,
hidden_size=2560, hidden_size=2560,
intermediate_size=11008,
num_hidden_layers=32, num_hidden_layers=32,
num_attention_heads=32, num_attention_heads=32,
num_key_value_heads=None, num_key_value_heads=None,
@ -40,12 +39,12 @@ class PhiConfig(PretrainedConfig):
tie_word_embeddings=False, tie_word_embeddings=False,
rope_scaling=None, rope_scaling=None,
rope_theta=10000.0, rope_theta=10000.0,
resid_pdrop=0.1,
**kwargs, **kwargs,
): ):
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
@ -61,6 +60,7 @@ class PhiConfig(PretrainedConfig):
self.use_cache = use_cache self.use_cache = use_cache
self.rope_scaling = rope_scaling self.rope_scaling = rope_scaling
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.resid_pdrop = resid_pdrop
super().__init__( super().__init__(
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
@ -81,7 +81,7 @@ def load_attention(config, prefix, weights):
config, config,
prefix=f"{prefix}.W_pack", prefix=f"{prefix}.W_pack",
weights=weights, weights=weights,
bias=False, bias=True,
) )
else: else:
# should be here # should be here
@ -90,7 +90,7 @@ def load_attention(config, prefix, weights):
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0, dim=0,
weights=weights, weights=weights,
bias=False, bias=True,
) )
@ -116,7 +116,7 @@ def _load_gqa(config, prefix: str, weights):
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear( return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize) get_linear(weight, bias=True, quantize=config.quantize)
) )
@ -130,6 +130,7 @@ class FlashPhiAttention(torch.nn.Module):
super().__init__() super().__init__()
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# should be 80 = 2560 / 32
self.head_size = self.hidden_size // self.num_heads self.head_size = self.hidden_size // self.num_heads
# MAYBE (if not static) # MAYBE (if not static)
@ -160,7 +161,7 @@ class FlashPhiAttention(torch.nn.Module):
config, config,
prefix=f"{prefix}.dense", prefix=f"{prefix}.dense",
weights=weights, weights=weights,
bias=False, bias=True,
) )
self.num_groups = self.num_heads // self.num_key_value_heads self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange( self.kv_head_mapping = torch.arange(
@ -180,6 +181,8 @@ class FlashPhiAttention(torch.nn.Module):
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
# shape = torch.Size([4096, 7680])
query, kv = qkv.split( query, kv = qkv.split(
[ [
self.head_size * self.num_heads, self.head_size * self.num_heads,
@ -187,6 +190,8 @@ class FlashPhiAttention(torch.nn.Module):
], ],
dim=1, dim=1,
) )
# query = torch.Size([4096, 2560])
# kv = torch.Size([4096, 5120])
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
@ -201,6 +206,8 @@ class FlashPhiAttention(torch.nn.Module):
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
print("🧢 flash attention")
print("cu_seqlen_prefill", cu_seqlen_prefill.shape)
# flash attention # flash attention
flash_attn.attention( flash_attn.attention(
query, query,
@ -213,6 +220,7 @@ class FlashPhiAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
print("📗 paged attention")
paged_attention.attention( paged_attention.attention(
attn_output, attn_output,
query, query,
@ -225,6 +233,11 @@ class FlashPhiAttention(torch.nn.Module):
max_s, max_s,
) )
# TODO: remove this - only used to summarize attention weights
# get sum of the attention weights
my_sum = torch.sum(attn_output, dim=2)
print("my_sum", my_sum)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -247,16 +260,17 @@ class PhiMLP(nn.Module):
config, config,
prefix=f"{prefix}.fc1", prefix=f"{prefix}.fc1",
weights=weights, weights=weights,
bias=False, bias=True,
) )
self.down_proj = TensorParallelRowLinear.load( self.down_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.fc2", prefix=f"{prefix}.fc2",
weights=weights, weights=weights,
bias=False, bias=True,
) )
def forward(self, hidden_states): def forward(self, hidden_states):
print("FORWARD MLP")
gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = self.gate_up_proj(hidden_states)
post_act = self.act(gate_up_states) post_act = self.act(gate_up_states)
return self.down_proj(post_act) return self.down_proj(post_act)
@ -274,6 +288,7 @@ class FlashPhiLayer(nn.Module):
self.input_layernorm = FastRMSNorm.load( self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
) )
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
def forward( def forward(
self, self,
@ -288,11 +303,14 @@ class FlashPhiLayer(nn.Module):
input_lengths, input_lengths,
max_s, max_s,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) print("💧 FORWARD LAYER")
print("\tinput0", hidden_states[0][1])
hidden_states, res = self.input_layernorm(hidden_states, residual)
print("\tnormalized shape", hidden_states.shape)
# Self Attention # Self Attention
attn_output = self.self_attn( attn_output = self.self_attn(
normed_hidden_states, hidden_states,
cos, cos,
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
@ -303,12 +321,12 @@ class FlashPhiLayer(nn.Module):
max_s, max_s,
) )
attn_output = self.resid_dropout(attn_output)
mlp_output = self.mlp(normed_hidden_states) feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
hidden_states = attn_output + feed_forward_hidden_states
result = attn_output + mlp_output + res return hidden_states, res
return result, res
class FlashPhiModel(torch.nn.Module): class FlashPhiModel(torch.nn.Module):
@ -337,6 +355,12 @@ class FlashPhiModel(torch.nn.Module):
self.num_heads = self.layers[0].self_attn.num_heads self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
self.ln = FastLayerNorm.load(
prefix="model.final_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
@ -371,7 +395,8 @@ class FlashPhiModel(torch.nn.Module):
max_s, max_s,
) )
return hidden_states normed_hidden_states, _ = self.ln(hidden_states, residual)
return normed_hidden_states
class FlashPhiForCausalLM(torch.nn.Module): class FlashPhiForCausalLM(torch.nn.Module):
@ -379,41 +404,22 @@ class FlashPhiForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = FlashPhiModel(config, weights) self.model = FlashPhiModel(config, weights)
# self.lm_head = TensorParallelHead.load( self.lm_head = TensorParallelHead.load(
# config,
# prefix="lm_head",
# weights=weights,
# )
# TODO: prefer parallel head
self.linear = FastLinear.load(
config, config,
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,
bias=False,
)
# TODO: use in correct place
self.ln = FastLayerNorm.load(
prefix="model.final_layernorm",
weights=weights,
eps=config.rms_norm_eps,
) )
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
# 1000K and 10K cu_seqlen_prefill: Optional[torch.Tensor],
cu_seqlen_prefill: Optional[torch.Tensor], # indexes for the items in the batch
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
# paged attention related block_tables: torch.Tensor,
block_tables: torch.Tensor, # <- indexes into blocks slots: torch.Tensor,
slots: torch.Tensor, # <- indexes into mem
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
# both attentions max_s: int,
max_s: int, # <- max sequence length (make kernals chose swap)
# small opt (only care about final)
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model( hidden_states = self.model(
@ -429,7 +435,6 @@ class FlashPhiForCausalLM(torch.nn.Module):
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
normed_hidden_states, res = self.ln(hidden_states, None) logits = self.lm_head(hidden_states)
logits = self.linear(normed_hidden_states)
return logits return logits