Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
fix: prefer parallel attn load and small refactors

commit 2b43c5b0dd (parent 8204f23650)
@@ -125,7 +125,6 @@ class FlashPhiAttention(torch.nn.Module):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        # should be 80 = 2560 / 32
         self.head_size = self.hidden_size // self.num_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
@@ -149,6 +148,8 @@ class FlashPhiAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )

+        self.query_key_value = load_attention(config, prefix, weights)
+
         self.dense = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.dense",
@@ -161,25 +162,6 @@ class FlashPhiAttention(torch.nn.Module):
         ).repeat_interleave(self.num_groups)
         self.rotary_emb_dim = 32

-        # load attention directly from weights
-        weight = weights.get_tensor(f"{prefix}.q_proj.weight")
-        bias = weights.get_tensor(f"{prefix}.q_proj.bias")
-        self.q_proj = nn.Linear(2560, 2560)
-        self.q_proj.weight = torch.nn.Parameter(weight.contiguous())
-        self.q_proj.bias = torch.nn.Parameter(bias.contiguous())
-
-        self.k_proj = TensorParallelColumnLinear.load(
-            config,
-            prefix=f"{prefix}.k_proj",
-            weights=weights,
-            bias=True,
-        )
-        self.v_proj = TensorParallelColumnLinear.load(
-            config,
-            prefix=f"{prefix}.v_proj",
-            weights=weights,
-            bias=True,
-        )


     def forward(
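The removed block above loaded q_proj into a plain nn.Linear with hard-coded 2560x2560 dimensions while k_proj and v_proj went through TensorParallelColumnLinear; the new code instead defers to load_attention, which provides the single fused query_key_value module used in forward. As a rough illustration of why a fused load is equivalent to three separate projections, here is a minimal plain-PyTorch sketch (not the repo's TensorParallel implementation; the 2560 hidden size and 32 heads are the Phi-2 defaults mentioned in the removed comment):

# Hypothetical sketch: fuse separate q/k/v projection weights into one linear
# so a single matmul produces q, k and v together.
import torch
import torch.nn as nn

hidden_size, num_heads = 2560, 32
head_size = hidden_size // num_heads  # 80

# Stand-ins for the q_proj / k_proj / v_proj checkpoint tensors.
q_proj = nn.Linear(hidden_size, hidden_size)
k_proj = nn.Linear(hidden_size, hidden_size)
v_proj = nn.Linear(hidden_size, hidden_size)

# Concatenate weights and biases along the output dimension into one projection.
qkv = nn.Linear(hidden_size, 3 * hidden_size)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    qkv.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))

x = torch.randn(4, hidden_size)
q, k, v = qkv(x).split(hidden_size, dim=1)
assert torch.allclose(q, q_proj(x), atol=1e-5)
assert torch.allclose(k, k_proj(x), atol=1e-5)
assert torch.allclose(v, v_proj(x), atol=1e-5)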
@@ -194,15 +176,19 @@ class FlashPhiAttention(torch.nn.Module):
         input_lengths,
         max_s,
     ):
-        q_len, _ = hidden_states.size()
-
-        query = self.q_proj(hidden_states)
-        key = self.k_proj(hidden_states)
-        value = self.v_proj(hidden_states)
+        # Compute query, key, value and split
+        qkv = self.query_key_value(hidden_states)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )

-        query = query.view(q_len, 32, self.head_size)
-        # Pack key and value together
-        kv = torch.stack([key.view(q_len, 32, self.head_size), value.view(q_len, 32, self.head_size)], dim=1)
+        # Reshape query and key for rotary embeddings
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

         # Apply partial rotary embedding and store the end of the embedding
         query_pass = query[:, :, self.rotary_emb_dim:]
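On the forward side, one fused projection plus a split replaces three separate matmuls, and the hard-coded 32-head views give way to self.num_heads / self.num_key_value_heads. A small self-contained sketch of the shape bookkeeping (sizes are illustrative; in the model they come from the config and are divided by the tensor-parallel world size):

# Hypothetical standalone sketch of the fused-QKV split used in the new forward.
import torch

num_heads = 32
num_key_value_heads = 32   # Phi-2 uses plain multi-head attention, so this equals num_heads
head_size = 80
seq_len = 4

# Output of the fused query_key_value projection: q plus packed k and v.
qkv = torch.randn(seq_len, (num_heads + 2 * num_key_value_heads) * head_size)

# Split into the query part and the packed key/value part.
query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads],
    dim=1,
)

# Reshape for (partial) rotary embeddings and the attention kernels.
query = query.view(-1, num_heads, head_size)          # (seq_len, num_heads, head_size)
kv = kv.view(-1, 2, num_key_value_heads, head_size)   # (seq_len, 2, kv_heads, head_size)

print(query.shape, kv.shape)  # torch.Size([4, 32, 80]) torch.Size([4, 2, 32, 80])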
@@ -248,7 +234,7 @@ class FlashPhiAttention(torch.nn.Module):
                 max_s,
             )

-        return self.dense(attn_output.view(q_len, 32*self.head_size))
+        return self.dense(attn_output.view(-1, self.num_heads*self.head_size))

 class PhiMLP(nn.Module):
     def __init__(self, prefix, config, weights):