diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index e626648f..6bed3d78 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -26,11 +26,11 @@ class PhiConfig(PretrainedConfig):
         hidden_size=2560,
         num_hidden_layers=32,
         num_attention_heads=32,
-        num_key_value_heads=None,
+        num_key_value_heads=32,
         hidden_act="gelu_fast",
         max_position_embeddings=2048,
         initializer_range=0.02,
-        rms_norm_eps=1e-6,
+        layer_norm_eps=1e-05,
         use_cache=True,
         pad_token_id=0,
         bos_token_id=1,
@@ -47,15 +47,10 @@ class PhiConfig(PretrainedConfig):
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
         self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
+        self.layer_norm_eps = layer_norm_eps
         self.pretraining_tp = pretraining_tp
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
@@ -181,7 +176,6 @@ class FlashPhiAttention(torch.nn.Module):
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
-        # shape = torch.Size([4096, 7680])
 
         query, kv = qkv.split(
             [
@@ -190,8 +184,6 @@ class FlashPhiAttention(torch.nn.Module):
             ],
             dim=1,
         )
-        # query = torch.Size([4096, 2560])
-        # kv = torch.Size([4096, 5120])
 
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
@@ -206,9 +198,6 @@ class FlashPhiAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
-            print("🧢 flash attention")
-            print("cu_seqlen_prefill", cu_seqlen_prefill.shape)
-
             # flash attention
             flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
@@ -220,7 +209,6 @@ class FlashPhiAttention(torch.nn.Module):
             )
         # Decode
         else:
-            print("📗 paged attention")
             paged_attention.attention(
                 attn_output,
                 query,
@@ -233,10 +221,6 @@ class FlashPhiAttention(torch.nn.Module):
                 max_s,
             )
 
-        # TODO: remove this - only used to summarize attention weights
-        # get sum of the attention weights
-        my_sum = torch.sum(attn_output, dim=2)
-        print("my_sum", my_sum)
         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
 
 
@@ -270,7 +254,6 @@ class PhiMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
-        print("FORWARD MLP")
         gate_up_states = self.gate_up_proj(hidden_states)
         post_act = self.act(gate_up_states)
         return self.down_proj(post_act)
@@ -284,9 +267,8 @@ class FlashPhiLayer(nn.Module):
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
         self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
-
-        self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+        self.input_layernorm = FastLayerNorm.load(
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps
         )
         self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
 
@@ -303,11 +285,7 @@ class FlashPhiLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        print("💧 FORWARD LAYER")
-        print("\tinput0", hidden_states[0][1])
         hidden_states, res = self.input_layernorm(hidden_states, residual)
-        print("\tnormalized shape", hidden_states.shape)
-
         # Self Attention
         attn_output = self.self_attn(
             hidden_states,
@@ -358,7 +336,7 @@ class FlashPhiModel(torch.nn.Module):
         self.ln = FastLayerNorm.load(
             prefix="model.final_layernorm",
             weights=weights,
-            eps=config.rms_norm_eps,
+            eps=config.layer_norm_eps,
         )
 
     def forward(