fix: add inline comments to highlight differences with llama

This commit is contained in:
drbh 2024-01-24 00:54:47 +00:00
parent c7ad2b61a1
commit 99392376e6

View File

@ -25,10 +25,10 @@ class PhiConfig(PretrainedConfig):
num_hidden_layers=32, num_hidden_layers=32,
num_attention_heads=32, num_attention_heads=32,
num_key_value_heads=32, num_key_value_heads=32,
hidden_act="gelu_fast", hidden_act="gelu_fast", # llama uses silu
max_position_embeddings=2048, max_position_embeddings=2048,
initializer_range=0.02, initializer_range=0.02,
layer_norm_eps=1e-05, layer_norm_eps=1e-05, # rms in llama
use_cache=True, use_cache=True,
pad_token_id=0, pad_token_id=0,
bos_token_id=1, bos_token_id=1,
@ -37,7 +37,7 @@ class PhiConfig(PretrainedConfig):
tie_word_embeddings=False, tie_word_embeddings=False,
rope_scaling=None, rope_scaling=None,
rope_theta=10000.0, rope_theta=10000.0,
resid_pdrop=0.1, resid_pdrop=0.1, # llama doesn't have this
**kwargs, **kwargs,
): ):
self.vocab_size = vocab_size self.vocab_size = vocab_size
@ -63,6 +63,7 @@ class PhiConfig(PretrainedConfig):
**kwargs, **kwargs,
) )
# this is the same as llama except for Phi uses bias=True
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) return _load_gqa(config, prefix, weights)
@ -104,6 +105,7 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size, config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
# this is the same as llama except for Phi uses bias=True
return TensorParallelColumnLinear( return TensorParallelColumnLinear(
get_linear(weight, bias=True, quantize=config.quantize) get_linear(weight, bias=True, quantize=config.quantize)
) )
@ -142,6 +144,7 @@ class FlashPhiAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
# in llama the dense layer is called "o_proj" and has bias=False
self.dense = TensorParallelRowLinear.load( self.dense = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.dense", prefix=f"{prefix}.dense",
@ -152,7 +155,6 @@ class FlashPhiAttention(torch.nn.Module):
self.kv_head_mapping = torch.arange( self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups) ).repeat_interleave(self.num_groups)
self.rotary_emb_dim = 32
def forward( def forward(
self, self,
@ -180,9 +182,13 @@ class FlashPhiAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
# NOTE: this is the main difference between Llama and Phi
# in llama the rotary embeddings are applied to the whole query and key.
# Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
#
# Apply partial positional embeddings in place # Apply partial positional embeddings in place
self.rotary_emb( self.rotary_emb(
query[:, :, :self.rotary_emb_dim], kv[:, 0, :, :self.rotary_emb_dim], query[:, :, :self.num_heads], kv[:, 0, :, :self.num_heads],
cos, sin cos, sin
) )
@ -236,6 +242,7 @@ class PhiMLP(nn.Module):
) )
) )
# llama weights are up_proj and down_proj and bias=False
self.gate_up_proj = TensorParallelRowLinear.load( self.gate_up_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.fc1", prefix=f"{prefix}.fc1",
@ -250,6 +257,8 @@ class PhiMLP(nn.Module):
) )
def forward(self, hidden_states): def forward(self, hidden_states):
# NOTE: Llama requires the gate up states to an intermediate size
# Phi does not and we can avoid the `view` operation
gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = self.gate_up_proj(hidden_states)
post_act = self.act(gate_up_states) post_act = self.act(gate_up_states)
return self.down_proj(post_act) return self.down_proj(post_act)
@ -328,7 +337,7 @@ class FlashPhiModel(torch.nn.Module):
self.num_heads = self.layers[0].self_attn.num_heads self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
self.ln = FastLayerNorm.load( self.norm = FastLayerNorm.load(
prefix="model.final_layernorm", prefix="model.final_layernorm",
weights=weights, weights=weights,
eps=config.layer_norm_eps, eps=config.layer_norm_eps,
@ -368,8 +377,9 @@ class FlashPhiModel(torch.nn.Module):
max_s, max_s,
) )
normed_hidden_states, _ = self.ln(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return normed_hidden_states
return hidden_states
class FlashPhiForCausalLM(torch.nn.Module): class FlashPhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, config, weights):