fix: adjust model config vars and other refactors

drbh 2024-01-24 12:23:04 -05:00
parent 99392376e6
commit 9bcd21a0b0


@@ -38,6 +38,7 @@ class PhiConfig(PretrainedConfig):
         rope_scaling=None,
         rope_theta=10000.0,
         resid_pdrop=0.1,  # llama doesn't have this
+        partial_rotary_factor=0.5,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -54,6 +55,7 @@ class PhiConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
         self.resid_pdrop = resid_pdrop
+        self.partial_rotary_factor = partial_rotary_factor
 
         super().__init__(
             pad_token_id=pad_token_id,
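Note: the config now exposes partial_rotary_factor (default 0.5), the fraction of each attention head's dimensions that receive rotary position embeddings. A minimal sketch of how the value is consumed, using illustrative sizes (hidden_size=2048 and 32 heads are assumptions, not values from this diff):

    # Illustrative only: how partial_rotary_factor maps to a per-head rotary width.
    hidden_size = 2048            # assumed example value
    num_attention_heads = 32      # assumed example value
    partial_rotary_factor = 0.5   # default introduced in this commit

    head_size = hidden_size // num_attention_heads       # 64
    rotary_dim = int(partial_rotary_factor * head_size)  # 32: only these dims get RoPE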
@@ -68,21 +70,13 @@ def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
-        if config.model_type == "baichuan":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.W_pack",
-                weights=weights,
-                bias=True,
-            )
-        else:
-            return TensorParallelColumnLinear.load_multi(
-                config,
-                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-                dim=0,
-                weights=weights,
-                bias=True,
-            )
+        return TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
 
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
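Note: the baichuan-specific W_pack branch, which appears to be leftover from the llama file this module was adapted from, is dropped; the q/k/v projections are always loaded with load_multi and fused along the output dimension. A conceptual sketch of that fusion in plain torch (not TGI's TensorParallelColumnLinear):

    import torch

    def fuse_qkv(q_w: torch.Tensor, k_w: torch.Tensor, v_w: torch.Tensor) -> torch.Tensor:
        # Each weight is (out_features, hidden_size); concatenating along dim 0
        # yields one (3 * out_features, hidden_size) matrix so QKV is a single matmul.
        return torch.cat([q_w, k_w, v_w], dim=0)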
@@ -130,6 +124,7 @@ class FlashPhiAttention(torch.nn.Module):
         )
         self.softmax_scale = self.head_size**-0.5
+        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
 
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
@@ -188,7 +183,7 @@ class FlashPhiAttention(torch.nn.Module):
         #
         # Apply partial positional embeddings in place
         self.rotary_emb(
-            query[:, :, :self.num_heads], kv[:, 0, :, :self.num_heads],
+            query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
             cos, sin
         )
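Note: this fixes the slice used for partial rotary embeddings. The old code sliced the head dimension by num_heads, which only matched by coincidence when num_heads equaled the rotary width; it now slices by rotary_dim, derived from partial_rotary_factor. A simplified out-of-place sketch of partial RoPE in rotate-half form (the real code uses TGI's in-place rotary kernel):

    import torch

    def apply_partial_rope(x, cos, sin, rotary_dim):
        # x: (seq_len, num_heads, head_size); cos/sin broadcastable to (seq_len, 1, rotary_dim).
        # Only the first rotary_dim channels of each head are rotated; the rest pass through.
        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        half = rotary_dim // 2
        x1, x2 = x_rot[..., :half], x_rot[..., half:]
        rotated = torch.cat([-x2, x1], dim=-1)
        return torch.cat([x_rot * cos + rotated * sin, x_pass], dim=-1)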
@@ -243,7 +238,7 @@ class PhiMLP(nn.Module):
         )
 
         # llama weights are up_proj and down_proj and bias=False
-        self.gate_up_proj = TensorParallelRowLinear.load(
+        self.up_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.fc1",
             weights=weights,
@@ -259,9 +254,7 @@ class PhiMLP(nn.Module):
     def forward(self, hidden_states):
         # NOTE: Llama requires the gate up states to an intermediate size
         # Phi does not and we can avoid the `view` operation
-        gate_up_states = self.gate_up_proj(hidden_states)
-        post_act = self.act(gate_up_states)
-        return self.down_proj(post_act)
+        return self.down_proj(self.act(self.up_proj(hidden_states)))
 
 
 class FlashPhiLayer(nn.Module):
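Note: the weight formerly named gate_up_proj is renamed to up_proj, since Phi uses a plain two-layer MLP (fc1 -> activation -> fc2) rather than llama's gated variant, and the forward pass collapses to a single expression. A standalone sketch of the equivalent module in plain torch (sizes and activation are illustrative assumptions, not taken from this diff):

    import torch
    from torch import nn

    class PhiMLPSketch(nn.Module):
        def __init__(self, hidden_size: int = 2048, intermediate_size: int = 8192):
            super().__init__()
            self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
            self.act = nn.GELU(approximate="tanh")
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # No gating and no view: project up, activate, project back down.
            return self.down_proj(self.act(self.up_proj(hidden_states)))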
@@ -304,10 +297,7 @@ class FlashPhiLayer(nn.Module):
             max_s,
         )
 
-        attn_output = self.resid_dropout(attn_output)
-        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
-        hidden_states = attn_output + feed_forward_hidden_states
+        hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))
 
         return hidden_states, res
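Note: the layer keeps its parallel-residual layout; the attention and MLP outputs are each passed through resid_dropout and summed in one expression. Tensor.add here is just the out-of-place +, as this small check illustrates (dropout p=0 is an assumption so the comparison is deterministic):

    import torch

    drop = torch.nn.Dropout(p=0.0)   # p=0 so both forms produce identical tensors
    attn_output = torch.randn(4, 2048)
    hidden_states = torch.randn(4, 2048)

    fused = drop(attn_output).add(drop(hidden_states))
    unfused = drop(attn_output) + drop(hidden_states)
    assert torch.equal(fused, unfused)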