From 2b43c5b0dd1648ba8e06cccb2e03be3a8d96c919 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 23 Jan 2024 00:14:22 +0000
Subject: [PATCH] fix: prefer parallel attn load and small refactors

---
 .../custom_modeling/flash_phi_modeling.py | 44 +++++++------------
 1 file changed, 15 insertions(+), 29 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 79e1cc16..b49d0985 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -125,7 +125,6 @@ class FlashPhiAttention(torch.nn.Module):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        # should be 80 = 2560 / 32
         self.head_size = self.hidden_size // self.num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
@@ -149,6 +148,8 @@ class FlashPhiAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
 
+        self.query_key_value = load_attention(config, prefix, weights)
+
         self.dense = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.dense",
@@ -161,25 +162,6 @@ class FlashPhiAttention(torch.nn.Module):
         ).repeat_interleave(self.num_groups)
         self.rotary_emb_dim = 32
 
-        # load attention directly from weights
-        weight = weights.get_tensor(f"{prefix}.q_proj.weight")
-        bias = weights.get_tensor(f"{prefix}.q_proj.bias")
-        self.q_proj = nn.Linear(2560, 2560)
-        self.q_proj.weight = torch.nn.Parameter(weight.contiguous())
-        self.q_proj.bias = torch.nn.Parameter(bias.contiguous())
-
-        self.k_proj = TensorParallelColumnLinear.load(
-            config,
-            prefix=f"{prefix}.k_proj",
-            weights=weights,
-            bias=True,
-        )
-        self.v_proj = TensorParallelColumnLinear.load(
-            config,
-            prefix=f"{prefix}.v_proj",
-            weights=weights,
-            bias=True,
-        )
 
 
     def forward(
@@ -194,15 +176,19 @@ class FlashPhiAttention(torch.nn.Module):
         input_lengths,
         max_s,
     ):
-        q_len, _ = hidden_states.size()
+        # Compute query, key, value and split
+        qkv = self.query_key_value(hidden_states)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )
 
-        query = self.q_proj(hidden_states)
-        key = self.k_proj(hidden_states)
-        value = self.v_proj(hidden_states)
-
-        query = query.view(q_len, 32, self.head_size)
-        # Pack key and value together
-        kv = torch.stack([key.view(q_len, 32, self.head_size), value.view(q_len, 32, self.head_size)], dim=1)
+        # Reshape query and key for rotary embeddings
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
         # Apply partial rotary embedding and store the end of the embedding
         query_pass = query[:, :, self.rotary_emb_dim:]
@@ -248,7 +234,7 @@ class FlashPhiAttention(torch.nn.Module):
             max_s,
         )
-        return self.dense(attn_output.view(q_len, 32*self.head_size))
+        return self.dense(attn_output.view(-1, self.num_heads*self.head_size))
 
 
 class PhiMLP(nn.Module):
     def __init__(self, prefix, config, weights):
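
The patch replaces the separate q_proj/k_proj/v_proj loads with a single fused query_key_value projection (via load_attention) whose output is split and reshaped inside forward(). The snippet below is a minimal, standalone sketch of that split/reshape shape math under assumed Phi-2-like dimensions (hidden_size 2560, 32 heads, head_size 80, taken from the removed "should be 80 = 2560 / 32" comment); the random qkv tensor merely stands in for the fused projection output and is not TGI code.

import torch

# Assumed Phi-2-like dimensions; num_key_value_heads equals num_heads here
# (no grouped-query attention).
hidden_size = 2560
num_heads = 32
num_key_value_heads = 32
head_size = hidden_size // num_heads  # 80
num_tokens = 5

# Stand-in for the fused q/k/v projection output: one matmul instead of
# three separate q_proj/k_proj/v_proj calls.
qkv = torch.randn(num_tokens, (num_heads + 2 * num_key_value_heads) * head_size)

# Split along the feature dimension, mirroring the patched forward().
query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads],
    dim=1,
)

# Reshape for the rotary embedding and attention kernels.
query = query.view(-1, num_heads, head_size)
kv = kv.view(-1, 2, num_key_value_heads, head_size)

print(query.shape)  # torch.Size([5, 32, 80])
print(kv.shape)     # torch.Size([5, 2, 32, 80])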