diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index dd466145..e626648f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -24,7 +24,6 @@ class PhiConfig(PretrainedConfig):
         self,
         vocab_size=51200,
         hidden_size=2560,
-        intermediate_size=11008,
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=None,
@@ -40,12 +39,12 @@ class PhiConfig(PretrainedConfig):
         tie_word_embeddings=False,
         rope_scaling=None,
         rope_theta=10000.0,
+        resid_pdrop=0.1,
         **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
 
@@ -61,6 +60,7 @@ class PhiConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
+        self.resid_pdrop = resid_pdrop
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -81,7 +81,7 @@ def load_attention(config, prefix, weights):
             config,
             prefix=f"{prefix}.W_pack",
             weights=weights,
-            bias=False,
+            bias=True,
         )
     else:
         # should be here
@@ -90,7 +90,7 @@ def load_attention(config, prefix, weights):
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
-            bias=False,
+            bias=True,
         )
 
 
@@ -116,7 +116,7 @@ def _load_gqa(config, prefix: str, weights):
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
     return TensorParallelColumnLinear(
-        get_linear(weight, bias=None, quantize=config.quantize)
+        get_linear(weight, bias=True, quantize=config.quantize)
     )
 
 
@@ -130,6 +130,7 @@ class FlashPhiAttention(torch.nn.Module):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
+        # should be 80 = 2560 / 32
         self.head_size = self.hidden_size // self.num_heads
 
         # MAYBE (if not static)
@@ -160,7 +161,7 @@ class FlashPhiAttention(torch.nn.Module):
             config,
             prefix=f"{prefix}.dense",
             weights=weights,
-            bias=False,
+            bias=True,
         )
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
@@ -180,6 +181,8 @@ class FlashPhiAttention(torch.nn.Module):
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
+        # shape = torch.Size([4096, 7680])
+
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@@ -187,6 +190,8 @@ class FlashPhiAttention(torch.nn.Module):
             ],
             dim=1,
         )
+        # query = torch.Size([4096, 2560])
+        # kv = torch.Size([4096, 5120])
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
@@ -201,6 +206,8 @@ class FlashPhiAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
+            print("🧢 flash attention")
+            print("cu_seqlen_prefill", cu_seqlen_prefill.shape)
             # flash attention
             flash_attn.attention(
                 query,
@@ -213,6 +220,7 @@ class FlashPhiAttention(torch.nn.Module):
             )
         # Decode
         else:
+            print("📗 paged attention")
             paged_attention.attention(
                 attn_output,
                 query,
@@ -225,6 +233,11 @@ class FlashPhiAttention(torch.nn.Module):
                 max_s,
             )
 
+        # TODO: remove this - only used to summarize attention weights
+        # get sum of the attention weights
+        my_sum = torch.sum(attn_output, dim=2)
+        print("my_sum", my_sum)
+
         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -247,16 +260,17 @@ class PhiMLP(nn.Module):
             config,
             prefix=f"{prefix}.fc1",
             weights=weights,
-            bias=False,
+            bias=True,
         )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.fc2",
             weights=weights,
-            bias=False,
+            bias=True,
         )
 
     def forward(self, hidden_states):
+        print("FORWARD MLP")
         gate_up_states = self.gate_up_proj(hidden_states)
         post_act = self.act(gate_up_states)
         return self.down_proj(post_act)
@@ -274,6 +288,7 @@ class FlashPhiLayer(nn.Module):
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
+        self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
 
     def forward(
         self,
@@ -288,11 +303,14 @@ class FlashPhiLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+        print("💧 FORWARD LAYER")
+        print("\tinput0", hidden_states[0][1])
+        hidden_states, res = self.input_layernorm(hidden_states, residual)
+        print("\tnormalized shape", hidden_states.shape)
 
         # Self Attention
         attn_output = self.self_attn(
-            normed_hidden_states,
+            hidden_states,
             cos,
             sin,
             cu_seqlen_prefill,
@@ -303,12 +321,12 @@ class FlashPhiLayer(nn.Module):
             max_s,
         )
 
+        attn_output = self.resid_dropout(attn_output)
 
-        mlp_output = self.mlp(normed_hidden_states)
+        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
+        hidden_states = attn_output + feed_forward_hidden_states
 
-        result = attn_output + mlp_output + res
-
-        return result, res
+        return hidden_states, res
 
 
 class FlashPhiModel(torch.nn.Module):
@@ -337,6 +355,12 @@ class FlashPhiModel(torch.nn.Module):
 
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
+        self.ln = FastLayerNorm.load(
+            prefix="model.final_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
+
     def forward(
         self,
@@ -371,7 +395,8 @@ class FlashPhiModel(torch.nn.Module):
                 max_s,
             )
 
-        return hidden_states
+        normed_hidden_states, _ = self.ln(hidden_states, residual)
+        return normed_hidden_states
 
 
 class FlashPhiForCausalLM(torch.nn.Module):
@@ -379,41 +404,22 @@ class FlashPhiForCausalLM(torch.nn.Module):
         super().__init__()
 
         self.model = FlashPhiModel(config, weights)
-        # self.lm_head = TensorParallelHead.load(
-        #     config,
-        #     prefix="lm_head",
-        #     weights=weights,
-        # )
-
-        # TODO: prefer parallel head
-        self.linear = FastLinear.load(
+        self.lm_head = TensorParallelHead.load(
             config,
             prefix="lm_head",
             weights=weights,
-            bias=False,
-        )
-
-        # TODO: use in correct place
-        self.ln = FastLayerNorm.load(
-            prefix="model.final_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
         )
 
     def forward(
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        # 1000K and 10K
-        cu_seqlen_prefill: Optional[torch.Tensor],  # indexes for the items in the batch
+        cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        # paged attention related
-        block_tables: torch.Tensor,  # <- indexes into blocks
-        slots: torch.Tensor,  # <- indexes into mem
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
         input_lengths: torch.Tensor,
-        # both attentions
-        max_s: int,  # <- max sequence length (make kernals chose swap)
-        # small opt (only care about final)
+        max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(
@@ -429,7 +435,6 @@ class FlashPhiForCausalLM(torch.nn.Module):
 
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
 
-        normed_hidden_states, res = self.ln(hidden_states, None)
-        logits = self.linear(normed_hidden_states)
+        logits = self.lm_head(hidden_states)
         return logits
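
For reference, the shape comments added in FlashPhiAttention.forward follow directly from the config defaults above: hidden_size 2560 with 32 heads gives head_size 80, and since Phi uses no GQA (num_key_value_heads == num_attention_heads) the fused QKV projection is (32 + 2 * 32) * 80 = 7680 wide, splitting into 2560 query columns and 5120 packed key/value columns. A quick standalone check (plain Python; the 4096 rows in the comments are just an example token count):

# Sanity check for the shape comments in FlashPhiAttention.forward.
# Values come from the PhiConfig defaults in this diff.
hidden_size = 2560
num_heads = 32
num_key_value_heads = 32  # Phi has no GQA: same as num_heads
head_size = hidden_size // num_heads
assert head_size == 80  # "should be 80 = 2560 / 32"

qkv_width = (num_heads + 2 * num_key_value_heads) * head_size
assert qkv_width == 7680  # qkv:   torch.Size([4096, 7680])

query_width = head_size * num_heads
kv_width = 2 * head_size * num_key_value_heads
assert query_width == 2560  # query: torch.Size([4096, 2560])
assert kv_width == 5120     # kv:    torch.Size([4096, 5120])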
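
The FlashPhiLayer rewrite is the core of the diff: attention and the MLP both read the same pre-norm activations (the parallel form Phi shares with GPT-J), their dropout-regularized outputs are summed, and the skip path travels in `res`, which the next layer's fused norm adds back before normalizing; the final norm likewise moves into FlashPhiModel so the lm_head sees fully normalized states. Below is a minimal sketch of that dataflow only, not the repo's implementation: `ParallelResidualBlock` is a hypothetical name, plain `nn.LayerNorm`/`nn.Linear` stand in for the tensor-parallel and fused modules, and `nn.MultiheadAttention` stands in for the flash/paged kernels.

import torch
import torch.nn as nn

class ParallelResidualBlock(nn.Module):
    """Sketch of FlashPhiLayer's dataflow; names and submodules are stand-ins."""

    def __init__(self, hidden_size: int, num_heads: int, resid_pdrop: float = 0.1):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),  # fc1 stand-in
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),  # fc2 stand-in
        )
        self.resid_dropout = nn.Dropout(resid_pdrop)

    def forward(self, hidden_states, residual=None):
        # TGI's fused norms add the incoming residual *before* normalizing and
        # return that pre-norm sum as `res`; emulate that contract here.
        if residual is not None:
            hidden_states = hidden_states + residual
        res = hidden_states
        normed = self.input_layernorm(hidden_states)

        # Attention and MLP both consume the same normed activations
        # (parallel form), rather than running sequentially.
        attn_output, _ = self.attn(normed, normed, normed, need_weights=False)
        attn_output = self.resid_dropout(attn_output)
        feed_forward_hidden_states = self.resid_dropout(self.mlp(normed))

        hidden_states = attn_output + feed_forward_hidden_states
        # `res` is returned so the *next* norm (or the model's final one,
        # which receives `residual` in the diff) re-adds the skip path.
        return hidden_states, res

x = torch.randn(1, 16, 2560)
layer = ParallelResidualBlock(hidden_size=2560, num_heads=32)
out, res = layer(x)
final_input = out + res  # what the final layernorm sees before normalizing

This also explains why `return hidden_states, res` can drop the explicit `+ res` that the old `result = attn_output + mlp_output + res` performed: the addition now happens inside the next fused norm call instead of in the layer itself.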