Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
fix: load attn weights to align with flash attn
Commit 8204f23650 (parent: 5db645a19a)
@@ -128,10 +128,9 @@ class FlashPhiAttention(torch.nn.Module):
         # should be 80 = 2560 / 32
         self.head_size = self.hidden_size // self.num_heads
 
-        # MAYBE (if not static)
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
-            dim=self.head_size,
+            dim=self.num_heads,
             base=config.rope_theta,
             device=weights.device,
         )
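Note (not part of the commit): the switch from dim=self.head_size to dim=self.num_heads only works because Phi-2's sizes happen to coincide. A minimal sketch of the arithmetic, assuming the published Phi-2 hyperparameters (hidden_size 2560, 32 heads, partial_rotary_factor 0.4); all names below are illustrative:

# Illustrative check of the dimensions behind this hunk (assumed Phi-2 config values).
hidden_size = 2560
num_heads = 32
head_size = hidden_size // num_heads                   # 80, i.e. "should be 80 = 2560 / 32"
partial_rotary_factor = 0.4                            # assumed from the Phi-2 config
rotary_dim = int(partial_rotary_factor * head_size)    # 32 dims of each head are rotated

assert head_size == 80
assert rotary_dim == 32 == num_heads                   # why dim=self.num_heads lines up here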
@@ -150,8 +149,6 @@ class FlashPhiAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
 
-        self.query_key_value = load_attention(config, prefix, weights)
-
         self.dense = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.dense",
@@ -162,6 +159,28 @@ class FlashPhiAttention(torch.nn.Module):
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
+        self.rotary_emb_dim = 32
+
+        # load attention directly from weights
+        weight = weights.get_tensor(f"{prefix}.q_proj.weight")
+        bias = weights.get_tensor(f"{prefix}.q_proj.bias")
+        self.q_proj = nn.Linear(2560, 2560)
+        self.q_proj.weight = torch.nn.Parameter(weight.contiguous())
+        self.q_proj.bias = torch.nn.Parameter(bias.contiguous())
+
+        self.k_proj = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.k_proj",
+            weights=weights,
+            bias=True,
+        )
+        self.v_proj = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.v_proj",
+            weights=weights,
+            bias=True,
+        )
+
 
     def forward(
         self,
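The hunk above drops the fused query_key_value projection and instead loads q_proj straight from the checkpoint tensors (k_proj and v_proj still go through TensorParallelColumnLinear). A minimal sketch of the same idea without the hard-coded 2560, assuming only the weights.get_tensor accessor used in the diff; the helper name is hypothetical:

from torch import nn

def linear_from_checkpoint(weights, prefix: str) -> nn.Linear:
    # Build an nn.Linear directly from a checkpoint weight/bias pair,
    # e.g. prefix = f"{layer_prefix}.q_proj".
    weight = weights.get_tensor(f"{prefix}.weight")  # (out_features, in_features)
    bias = weights.get_tensor(f"{prefix}.bias")      # (out_features,)
    out_features, in_features = weight.shape
    proj = nn.Linear(in_features, out_features)
    proj.weight = nn.Parameter(weight.contiguous())
    proj.bias = nn.Parameter(bias.contiguous())
    return proj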
@@ -175,20 +194,28 @@ class FlashPhiAttention(torch.nn.Module):
         input_lengths,
         max_s,
     ):
-        qkv = self.query_key_value(hidden_states)
+        q_len, _ = hidden_states.size()
 
-        query, kv = qkv.split(
-            [
-                self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
-            ],
-            dim=1,
-        )
-        query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
-
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+
+        query = query.view(q_len, 32, self.head_size)
+        # Pack key and value together
+        kv = torch.stack([key.view(q_len, 32, self.head_size), value.view(q_len, 32, self.head_size)], dim=1)
+
+        # Apply partial rotary embedding and store the end of the embedding
+        query_pass = query[:, :, self.rotary_emb_dim:]
+        key_pass = torch.select(kv, dim=1, index=0)[:, :, self.rotary_emb_dim:]
+
+        # Apply in place positional rotary embeddings
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
+        # Restore the query and key from the partial rotary embedding
+        kv[:, 0, :, self.rotary_emb_dim:] = key_pass
+        query[:, :, self.rotary_emb_dim:] = query_pass
+
+        # Reshape key and value and cache
         paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
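The forward changes emulate Phi's partial rotary embedding: the tail of each head (everything past rotary_emb_dim) is saved, the in-place rotary kernel runs over the whole head, and the saved tail is copied back. A self-contained sketch of partial RoPE in plain PyTorch (independent of TGI's PositionRotaryEmbedding; shapes and names are illustrative):

import torch

def apply_partial_rope(x, cos, sin, rotary_dim):
    # x: (seq_len, num_heads, head_size); cos/sin: (seq_len, 1, rotary_dim // 2).
    # Only the first rotary_dim features of each head are rotated; the rest pass through.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot[..., : rotary_dim // 2], x_rot[..., rotary_dim // 2 :]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)

# Toy usage with Phi-2-like shapes: 80-dim heads, 32 of which get rotated.
seq_len, num_heads, head_size, rotary_dim = 4, 32, 80, 32
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
freqs = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)  # (seq_len, rotary_dim // 2)
cos, sin = freqs.cos()[:, None, :], freqs.sin()[:, None, :]
query = apply_partial_rope(torch.randn(seq_len, num_heads, head_size), cos, sin, rotary_dim)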
@@ -221,9 +248,7 @@ class FlashPhiAttention(torch.nn.Module):
                 max_s,
             )
 
-        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
-
-
+        return self.dense(attn_output.view(q_len, 32*self.head_size))
 
 class PhiMLP(nn.Module):
     def __init__(self, prefix, config, weights):
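Before the dense projection, the attention output is flattened back to (q_len, hidden_size). Hard-coding 32 in the view ties it to Phi-2's head count, whereas the replaced (-1, self.num_heads * self.head_size) form was shape-agnostic. A tiny shape check under assumed Phi-2 sizes:

import torch

# Illustrative only: the flattened attention output matches the dense layer's input width.
q_len, num_heads, head_size = 5, 32, 80
attn_output = torch.randn(q_len, num_heads, head_size)
assert attn_output.view(q_len, num_heads * head_size).shape == (q_len, 2560)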