Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
fix: add inline comments to highlight differences with llama
commit 99392376e6 (parent c7ad2b61a1)
@@ -25,10 +25,10 @@ class PhiConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=32,
-        hidden_act="gelu_fast",
+        hidden_act="gelu_fast",  # llama uses silu
         max_position_embeddings=2048,
         initializer_range=0.02,
-        layer_norm_eps=1e-05,
+        layer_norm_eps=1e-05,  # rms in llama
         use_cache=True,
         pad_token_id=0,
         bos_token_id=1,
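To make the "llama uses silu" note concrete, here is a minimal activation comparison in plain PyTorch, assuming gelu_fast maps to the tanh-approximated GELU (the variable names are illustrative):

import torch

x = torch.randn(4)
phi_act = torch.nn.functional.gelu(x, approximate="tanh")  # "gelu_fast"
llama_act = torch.nn.functional.silu(x)                    # "silu"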
@@ -37,7 +37,7 @@ class PhiConfig(PretrainedConfig):
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
-        resid_pdrop=0.1,
+        resid_pdrop=0.1,  # llama doesn't have this
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -63,6 +63,7 @@ class PhiConfig(PretrainedConfig):
             **kwargs,
         )
 
+# this is the same as llama except for Phi uses bias=True
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
@@ -104,6 +105,7 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
+    # this is the same as llama except for Phi uses bias=True
     return TensorParallelColumnLinear(
         get_linear(weight, bias=True, quantize=config.quantize)
     )
@@ -142,6 +144,7 @@ class FlashPhiAttention(torch.nn.Module):
 
         self.query_key_value = load_attention(config, prefix, weights)
 
+        # in llama the dense layer is called "o_proj" and has bias=False
         self.dense = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.dense",
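Stripped of tensor parallelism, the naming and bias difference called out above amounts to the following (the hidden size is only illustrative):

import torch.nn as nn

hidden_size = 2560  # illustrative value
phi_dense = nn.Linear(hidden_size, hidden_size, bias=True)      # Phi: "dense", with bias
llama_o_proj = nn.Linear(hidden_size, hidden_size, bias=False)  # Llama: "o_proj", no bias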
@@ -152,7 +155,6 @@ class FlashPhiAttention(torch.nn.Module):
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
-        self.rotary_emb_dim = 32
 
     def forward(
         self,
@@ -180,9 +182,13 @@ class FlashPhiAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
+        # NOTE: this is the main difference between Llama and Phi
+        # in llama the rotary embeddings are applied to the whole query and key.
+        # Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
+        #
         # Apply partial positional embeddings in place
         self.rotary_emb(
-            query[:, :, :self.rotary_emb_dim], kv[:, 0, :, :self.rotary_emb_dim],
+            query[:, :, :self.num_heads], kv[:, 0, :, :self.num_heads],
             cos, sin
         )
 
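As a reference for the partial-rotary comments above, a minimal out-of-place sketch, assuming GPT-NeoX-style half rotation and a rotary dimension of 32; the function name and shapes are illustrative and this is not the in-place TGI rotary kernel:

import torch

def apply_partial_rotary(query, key, cos, sin, rotary_dim=32):
    # query/key: [num_tokens, num_heads, head_size]
    # cos/sin:   [num_tokens, rotary_dim // 2]
    def rotate(x):
        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        x1, x2 = x_rot[..., : rotary_dim // 2], x_rot[..., rotary_dim // 2 :]
        c, s = cos[:, None, :], sin[:, None, :]
        rotated = torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1)
        # everything past rotary_dim passes through unrotated
        return torch.cat((rotated, x_pass), dim=-1)

    return rotate(query), rotate(key)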
@@ -236,6 +242,7 @@ class PhiMLP(nn.Module):
             )
         )
 
+        # llama weights are up_proj and down_proj and bias=False
         self.gate_up_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.fc1",
@@ -250,6 +257,8 @@ class PhiMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
+        # NOTE: Llama requires the gate up states to an intermediate size
+        # Phi does not and we can avoid the `view` operation
         gate_up_states = self.gate_up_proj(hidden_states)
         post_act = self.act(gate_up_states)
         return self.down_proj(post_act)
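The `view` mentioned in the comments above is Llama's split of its fused gate/up projection. A simplified contrast, with the projections passed in as plain callables (names are illustrative, not the actual module code):

import torch

def llama_style_mlp(hidden_states, gate_up_proj, down_proj, act, intermediate_size):
    gate_up = gate_up_proj(hidden_states)
    # the fused projection has to be reshaped into gate and up halves
    gate_up = gate_up.view(-1, 2, intermediate_size)
    return down_proj(act(gate_up[:, 0]) * gate_up[:, 1])

def phi_style_mlp(hidden_states, fc1, fc2, act):
    # a single projection feeds the activation directly, no view needed
    return fc2(act(fc1(hidden_states)))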
@@ -328,7 +337,7 @@ class FlashPhiModel(torch.nn.Module):
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
-        self.ln = FastLayerNorm.load(
+        self.norm = FastLayerNorm.load(
             prefix="model.final_layernorm",
             weights=weights,
             eps=config.layer_norm_eps,
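The final norm loaded above is a standard LayerNorm; the "rms in llama" note in the config hunk refers to Llama using RMSNorm instead. In plain tensor ops (a sketch, not the fused FastLayerNorm kernel used here):

import torch

def layer_norm(x, weight, bias, eps=1e-5):
    # Phi: mean-centering plus learned scale and bias
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias

def rms_norm(x, weight, eps=1e-6):
    # Llama: rescale by the root-mean-square only, no mean or bias
    rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x / rms * weight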
@@ -368,8 +377,9 @@ class FlashPhiModel(torch.nn.Module):
             max_s,
         )
 
-        normed_hidden_states, _ = self.ln(hidden_states, residual)
-        return normed_hidden_states
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        return hidden_states
 
 class FlashPhiForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):