diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 7c7096f7..43eced57 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -25,10 +25,10 @@ class PhiConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=32,
-        hidden_act="gelu_fast",
+        hidden_act="gelu_fast",  # llama uses silu
         max_position_embeddings=2048,
         initializer_range=0.02,
-        layer_norm_eps=1e-05,
+        layer_norm_eps=1e-05,  # llama uses rms norm
         use_cache=True,
         pad_token_id=0,
         bos_token_id=1,
@@ -37,7 +37,7 @@
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
-        resid_pdrop=0.1,
+        resid_pdrop=0.1,  # llama doesn't have this
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -63,6 +63,7 @@ class PhiConfig(PretrainedConfig):
             **kwargs,
         )
 
+# this is the same as llama except that Phi uses bias=True
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
@@ -104,6 +105,7 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
+    # this is the same as llama except that Phi uses bias=True
     return TensorParallelColumnLinear(
         get_linear(weight, bias=True, quantize=config.quantize)
     )
@@ -142,6 +144,7 @@ class FlashPhiAttention(torch.nn.Module):
 
         self.query_key_value = load_attention(config, prefix, weights)
 
+        # in llama the dense layer is called "o_proj" and has bias=False
         self.dense = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.dense",
@@ -152,7 +155,6 @@ class FlashPhiAttention(torch.nn.Module):
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
-        self.rotary_emb_dim = 32
 
     def forward(
         self,
@@ -180,9 +182,13 @@ class FlashPhiAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
+        # NOTE: this is the main difference between Llama and Phi
+        # in llama the rotary embeddings are applied to the whole query and key.
+        # Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions of each head
+        #
         # Apply partial positional embeddings in place
         self.rotary_emb(
-            query[:, :, :self.rotary_emb_dim], kv[:, 0, :, :self.rotary_emb_dim],
+            query[:, :, :self.num_heads], kv[:, 0, :, :self.num_heads],
             cos, sin
         )
 
@@ -236,6 +242,7 @@ class PhiMLP(nn.Module):
             )
         )
 
+        # llama weights are up_proj and down_proj and bias=False
         self.gate_up_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.fc1",
@@ -250,6 +257,8 @@ class PhiMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
+        # NOTE: Llama has to `view` the gate_up states into two intermediate-size halves
+        # Phi has no gate, so we can avoid the `view` operation
         gate_up_states = self.gate_up_proj(hidden_states)
         post_act = self.act(gate_up_states)
         return self.down_proj(post_act)
@@ -328,7 +337,7 @@ class FlashPhiModel(torch.nn.Module):
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
-        self.ln = FastLayerNorm.load(
+        self.norm = FastLayerNorm.load(
             prefix="model.final_layernorm",
             weights=weights,
             eps=config.layer_norm_eps,
@@ -368,8 +377,9 @@ class FlashPhiModel(torch.nn.Module):
             max_s,
         )
 
-        normed_hidden_states, _ = self.ln(hidden_states, residual)
-        return normed_hidden_states
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        return hidden_states
 
 class FlashPhiForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
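For reference, here is a minimal, self-contained sketch of what the partial rotary embedding does, since that is the main difference from Llama called out in the diff. This is not TGI's `PositionRotaryEmbedding` (which applies the rotation in place via a fused kernel); the interleaved rotate-every-two formulation, the helper names, and the phi-2-style shapes (32 heads, head size 80, rotary over the first 32 dims of each head) are assumptions for illustration only.

```python
# Minimal sketch of PARTIAL rotary embeddings (not TGI's PositionRotaryEmbedding).
# Assumed phi-2-style shapes: 32 heads, head_size 80, rotary over the first 32 dims.
import torch


def rope_cos_sin(positions: torch.Tensor, rotary_dim: int, base: float = 10000.0):
    # Standard RoPE frequencies over `rotary_dim` dimensions.
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = positions.float()[:, None] * inv_freq[None, :]  # [seq, rotary_dim // 2]
    return freqs.cos(), freqs.sin()


def apply_partial_rotary(x, cos, sin, rotary_dim):
    # x: [seq, num_heads, head_size]. Only the first `rotary_dim` dims of each head
    # are rotated; the remaining head_size - rotary_dim dims pass through unchanged.
    rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = rot[..., 0::2], rot[..., 1::2]
    cos, sin = cos[:, None, :], sin[:, None, :]  # broadcast over heads
    rotated = torch.stack(
        (x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1
    ).flatten(-2)
    return torch.cat((rotated, rest), dim=-1)


seq_len, num_heads, head_size, rotary_dim = 4, 32, 80, 32
query = torch.randn(seq_len, num_heads, head_size)
cos, sin = rope_cos_sin(torch.arange(seq_len), rotary_dim)
out = apply_partial_rotary(query, cos, sin, rotary_dim)
# Llama-style full rotary would instead use rotary_dim == head_size.
assert torch.equal(out[..., rotary_dim:], query[..., rotary_dim:])  # tail untouched
```

In the module above the rotation is applied in place on slices of `query` and `kv[:, 0]`, so no concatenation is needed; the sketch concatenates only to make the untouched tail of each head explicit.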
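Similarly, the `view`-avoidance note in `PhiMLP.forward` is easier to see side by side. The sketch below is illustrative only: plain `nn.Linear` stands in for the tensor-parallel layers, the class names and sizes are made up, and PyTorch's tanh-approximate gelu stands in for `gelu_fast`.

```python
# Contrast of the two MLP forward passes (plain nn.Linear standing in for the
# tensor-parallel layers; sizes are illustrative, not the real configs).
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 64, 256


class LlamaStyleMLP(nn.Module):
    # gate_proj and up_proj are fused into one matmul, so the output has to be
    # viewed/split back into two intermediate_size halves before silu(gate) * up.
    def __init__(self):
        super().__init__()
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        gate, up = gate_up.view(*x.shape[:-1], 2, intermediate_size).unbind(dim=-2)
        return self.down_proj(F.silu(gate) * up)


class PhiStyleMLP(nn.Module):
    # No gating: fc1 -> gelu -> fc2, so no view/split is needed (and bias=True).
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.fc2 = nn.Linear(intermediate_size, hidden_size, bias=True)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x), approximate="tanh"))


x = torch.randn(2, 8, hidden_size)
print(LlamaStyleMLP()(x).shape, PhiStyleMLP()(x).shape)  # both (2, 8, 64)
```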