mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
fix: add inline comments to highlight differences with llama
This commit is contained in:
parent
c7ad2b61a1
commit
99392376e6
@ -25,10 +25,10 @@ class PhiConfig(PretrainedConfig):
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=32,
|
||||
hidden_act="gelu_fast",
|
||||
hidden_act="gelu_fast", # llama uses silu
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-05,
|
||||
layer_norm_eps=1e-05, # rms in llama
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
@ -37,7 +37,7 @@ class PhiConfig(PretrainedConfig):
|
||||
tie_word_embeddings=False,
|
||||
rope_scaling=None,
|
||||
rope_theta=10000.0,
|
||||
resid_pdrop=0.1,
|
||||
resid_pdrop=0.1, # llama doesn't have this
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
@ -63,6 +63,7 @@ class PhiConfig(PretrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# this is the same as llama except for Phi uses bias=True
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
@ -104,6 +105,7 @@ def _load_gqa(config, prefix: str, weights):
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
# this is the same as llama except for Phi uses bias=True
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=True, quantize=config.quantize)
|
||||
)
|
||||
@ -142,6 +144,7 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
# in llama the dense layer is called "o_proj" and has bias=False
|
||||
self.dense = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.dense",
|
||||
@ -152,7 +155,6 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
self.rotary_emb_dim = 32
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -180,9 +182,13 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
# NOTE: this is the main difference between Llama and Phi
|
||||
# in llama the rotary embeddings are applied to the whole query and key.
|
||||
# Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
|
||||
#
|
||||
# Apply partial positional embeddings in place
|
||||
self.rotary_emb(
|
||||
query[:, :, :self.rotary_emb_dim], kv[:, 0, :, :self.rotary_emb_dim],
|
||||
query[:, :, :self.num_heads], kv[:, 0, :, :self.num_heads],
|
||||
cos, sin
|
||||
)
|
||||
|
||||
@ -236,6 +242,7 @@ class PhiMLP(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
# llama weights are up_proj and down_proj and bias=False
|
||||
self.gate_up_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
@ -250,6 +257,8 @@ class PhiMLP(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# NOTE: Llama requires the gate up states to an intermediate size
|
||||
# Phi does not and we can avoid the `view` operation
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
post_act = self.act(gate_up_states)
|
||||
return self.down_proj(post_act)
|
||||
@ -328,7 +337,7 @@ class FlashPhiModel(torch.nn.Module):
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
self.ln = FastLayerNorm.load(
|
||||
self.norm = FastLayerNorm.load(
|
||||
prefix="model.final_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
@ -368,8 +377,9 @@ class FlashPhiModel(torch.nn.Module):
|
||||
max_s,
|
||||
)
|
||||
|
||||
normed_hidden_states, _ = self.ln(hidden_states, residual)
|
||||
return normed_hidden_states
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FlashPhiForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
|
Loading…
Reference in New Issue
Block a user