diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 6bed3d78..79e1cc16 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -128,10 +128,9 @@ class FlashPhiAttention(torch.nn.Module):
         # should be 80 = 2560 / 32
         self.head_size = self.hidden_size // self.num_heads
 
-        # MAYBE (if not static)
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
-            dim=self.head_size,
+            dim=self.num_heads,
             base=config.rope_theta,
             device=weights.device,
         )
@@ -150,8 +149,6 @@ class FlashPhiAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
 
-        self.query_key_value = load_attention(config, prefix, weights)
-
         self.dense = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.dense",
@@ -162,6 +159,28 @@ class FlashPhiAttention(torch.nn.Module):
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
+        self.rotary_emb_dim = 32
+
+        # load attention directly from weights
+        weight = weights.get_tensor(f"{prefix}.q_proj.weight")
+        bias = weights.get_tensor(f"{prefix}.q_proj.bias")
+        self.q_proj = nn.Linear(2560, 2560)
+        self.q_proj.weight = torch.nn.Parameter(weight.contiguous())
+        self.q_proj.bias = torch.nn.Parameter(bias.contiguous())
+
+        self.k_proj = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.k_proj",
+            weights=weights,
+            bias=True,
+        )
+        self.v_proj = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.v_proj",
+            weights=weights,
+            bias=True,
+        )
+
 
     def forward(
         self,
@@ -175,20 +194,28 @@ class FlashPhiAttention(torch.nn.Module):
         input_lengths,
         max_s,
     ):
-        qkv = self.query_key_value(hidden_states)
+        q_len, _ = hidden_states.size()
 
-        query, kv = qkv.split(
-            [
-                self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
-            ],
-            dim=1,
-        )
-        query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+        query = query.view(q_len, 32, self.head_size)
 
+        # Pack key and value together
+        kv = torch.stack([key.view(q_len, 32, self.head_size), value.view(q_len, 32, self.head_size)], dim=1)
+
+        # Apply partial rotary embedding and store the end of the embedding
+        query_pass = query[:, :, self.rotary_emb_dim:]
+        key_pass = torch.select(kv, dim=1, index=0)[:, :, self.rotary_emb_dim:]
+
+        # Apply in place positional rotary embeddings
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
+        # Restore the query and key from the partial rotary embedding
+        kv[:, 0, :, self.rotary_emb_dim:] = key_pass
+        query[:, :, self.rotary_emb_dim:] = query_pass
+
+        # Reshape key and value and cache
        paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
@@ -221,9 +248,7 @@ class FlashPhiAttention(torch.nn.Module):
             max_s,
         )
 
-
-        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
-
+        return self.dense(attn_output.view(q_len, 32*self.head_size))
 
 class PhiMLP(nn.Module):
     def __init__(self, prefix, config, weights):
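
Note on the query_pass/key_pass stash-and-restore in the forward pass: it implements a partial rotary embedding, where only the first rotary_emb_dim channels of each head end up rotated and the remaining channels are written back unchanged after the in-place rotary call. The following is a minimal, self-contained sketch of that idea in plain PyTorch; apply_rope_full, the rotate-half convention, and the toy shapes are illustrative assumptions, not the TGI rotary API.

import torch

def apply_rope_full(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Standard "rotate half" rotary embedding over the full last dimension.
    # cos/sin have half the width of x's last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

def apply_rope_partial(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rotary_dim: int) -> torch.Tensor:
    # Rotate only the first `rotary_dim` channels; the tail passes through
    # unchanged, which is what the patch emulates by stashing query_pass/key_pass
    # and writing them back after the in-place rotation.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    return torch.cat([apply_rope_full(x_rot, cos, sin), x_pass], dim=-1)

# Toy shapes: 4 tokens, 32 heads, head_size 80, rotary_dim 32 (hypothetical values).
q = torch.randn(4, 32, 80)
cos = torch.randn(4, 1, 16)  # broadcast over heads; half of rotary_dim
sin = torch.randn(4, 1, 16)
q_rot = apply_rope_partial(q, cos, sin, rotary_dim=32)
assert torch.equal(q_rot[..., 32:], q[..., 32:])  # pass-through channels untouched

In the patch itself the rotary helper rotates the tensors in place, so the tail of each head has to be stashed before the call and restored afterwards rather than concatenated as in the sketch.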