From 9bcd21a0b027d1fd62782cefa650394b668c8a64 Mon Sep 17 00:00:00 2001
From: drbh
Date: Wed, 24 Jan 2024 12:23:04 -0500
Subject: [PATCH] fix: adjust model config vars and other refactors

---
 .../custom_modeling/flash_phi_modeling.py     | 38 +++++++------------
 1 file changed, 14 insertions(+), 24 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 43eced57..9f33143f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -38,6 +38,7 @@ class PhiConfig(PretrainedConfig):
         rope_scaling=None,
         rope_theta=10000.0,
         resid_pdrop=0.1,  # llama doesn't have this
+        partial_rotary_factor=0.5,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -54,6 +55,7 @@ class PhiConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
         self.resid_pdrop = resid_pdrop
+        self.partial_rotary_factor = partial_rotary_factor
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -68,21 +70,13 @@ def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
     else:
-        if config.model_type == "baichuan":
-            return TensorParallelColumnLinear.load_qkv(
-                config,
-                prefix=f"{prefix}.W_pack",
-                weights=weights,
-                bias=True,
-            )
-        else:
-            return TensorParallelColumnLinear.load_multi(
-                config,
-                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-                dim=0,
-                weights=weights,
-                bias=True,
-            )
+        return TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
 
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
@@ -130,6 +124,7 @@ class FlashPhiAttention(torch.nn.Module):
         )
 
         self.softmax_scale = self.head_size**-0.5
+        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
 
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
@@ -188,7 +183,7 @@ class FlashPhiAttention(torch.nn.Module):
         #
         # Apply partial positional embeddings in place
         self.rotary_emb(
-            query[:, :, :self.num_heads], kv[:, 0, :, :self.num_heads],
+            query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
             cos, sin
         )
 
@@ -243,7 +238,7 @@ class PhiMLP(nn.Module):
         )
 
         # llama weights are up_proj and down_proj and bias=False
-        self.gate_up_proj = TensorParallelRowLinear.load(
+        self.up_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.fc1",
             weights=weights,
@@ -259,9 +254,7 @@ class PhiMLP(nn.Module):
     def forward(self, hidden_states):
         # NOTE: Llama requires the gate up states to an intermediate size
        # Phi does not and we can avoid the `view` operation
-        gate_up_states = self.gate_up_proj(hidden_states)
-        post_act = self.act(gate_up_states)
-        return self.down_proj(post_act)
+        return self.down_proj(self.act(self.up_proj(hidden_states)))
 
 
 class FlashPhiLayer(nn.Module):
@@ -304,10 +297,7 @@ class FlashPhiLayer(nn.Module):
             max_s,
         )
 
-        attn_output = self.resid_dropout(attn_output)
-
-        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
-        hidden_states = attn_output + feed_forward_hidden_states
+        hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))
 
         return hidden_states, res
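
Note on the partial rotary change above: `partial_rotary_factor` controls how many channels of each attention head receive rotary position embeddings; for example, with a head size of 64 and the default factor of 0.5, only the first 32 channels are rotated and the remaining channels pass through unchanged, which is what the new `rotary_dim` slice expresses. The sketch below is a minimal, standalone illustration of that slicing under the GPT-NeoX "rotate half" convention, not TGI's flash-attn rotary kernel; the function names (`rotary_cos_sin`, `apply_partial_rotary`) and toy sizes are hypothetical.

# Minimal sketch of partial rotary embeddings (illustrative only; names and
# sizes are hypothetical, and this is not TGI's rotary implementation).
import torch

def rotary_cos_sin(positions, rotary_dim, theta=10000.0):
    # GPT-NeoX-style inverse frequencies, computed over the rotary slice only.
    inv_freq = 1.0 / (theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(positions.float(), inv_freq)      # [seq, rotary_dim / 2]
    emb = torch.cat((freqs, freqs), dim=-1)               # [seq, rotary_dim]
    return emb.cos()[:, None, :], emb.sin()[:, None, :]   # broadcast over heads

def apply_partial_rotary(x, cos, sin, rotary_dim):
    # Rotate only the first `rotary_dim` channels of each head; the rest pass through.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot[..., : rotary_dim // 2], x_rot[..., rotary_dim // 2 :]
    rotated = torch.cat((-x2, x1), dim=-1)                # "rotate half" convention
    return torch.cat((x_rot * cos + rotated * sin, x_pass), dim=-1)

# Toy usage mirroring `rotary_dim = int(partial_rotary_factor * head_size)` above.
seq_len, num_heads, head_size, partial_rotary_factor = 6, 4, 64, 0.5
rotary_dim = int(partial_rotary_factor * head_size)       # 32
query = torch.randn(seq_len, num_heads, head_size)
cos, sin = rotary_cos_sin(torch.arange(seq_len), rotary_dim)
query_rot = apply_partial_rotary(query, cos, sin, rotary_dim)
assert torch.equal(query_rot[..., rotary_dim:], query[..., rotary_dim:])  # pass-through half untouched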