fix: adjust model config vars and other refactors

This commit is contained in:
drbh 2024-01-24 12:23:04 -05:00
parent 99392376e6
commit 9bcd21a0b0

View File

@ -38,6 +38,7 @@ class PhiConfig(PretrainedConfig):
rope_scaling=None,
rope_theta=10000.0,
resid_pdrop=0.1, # llama doesn't have this
partial_rotary_factor=0.5,
**kwargs,
):
self.vocab_size = vocab_size
@ -54,6 +55,7 @@ class PhiConfig(PretrainedConfig):
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
self.resid_pdrop = resid_pdrop
self.partial_rotary_factor = partial_rotary_factor
super().__init__(
pad_token_id=pad_token_id,
@ -67,14 +69,6 @@ class PhiConfig(PretrainedConfig):
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
if config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.W_pack",
weights=weights,
bias=True,
)
else:
return TensorParallelColumnLinear.load_multi(
config,
@ -130,6 +124,7 @@ class FlashPhiAttention(torch.nn.Module):
)
self.softmax_scale = self.head_size**-0.5
self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
@ -188,7 +183,7 @@ class FlashPhiAttention(torch.nn.Module):
#
# Apply partial positional embeddings in place
self.rotary_emb(
query[:, :, :self.num_heads], kv[:, 0, :, :self.num_heads],
query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
cos, sin
)
@ -243,7 +238,7 @@ class PhiMLP(nn.Module):
)
# llama weights are up_proj and down_proj and bias=False
self.gate_up_proj = TensorParallelRowLinear.load(
self.up_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.fc1",
weights=weights,
@ -259,9 +254,7 @@ class PhiMLP(nn.Module):
def forward(self, hidden_states):
# NOTE: Llama requires the gate up states to an intermediate size
# Phi does not and we can avoid the `view` operation
gate_up_states = self.gate_up_proj(hidden_states)
post_act = self.act(gate_up_states)
return self.down_proj(post_act)
return self.down_proj(self.act(self.up_proj(hidden_states)))
class FlashPhiLayer(nn.Module):
@ -304,10 +297,7 @@ class FlashPhiLayer(nn.Module):
max_s,
)
attn_output = self.resid_dropout(attn_output)
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
hidden_states = attn_output + feed_forward_hidden_states
hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))
return hidden_states, res