Fix logic.

Authored by Ubuntu on 2023-05-25 09:42:59 +00:00, committed by Nicolas Patry
parent 4e071bf2f1
commit 7fa79f02ca


@@ -95,13 +95,6 @@ class FlashNeoxAttention(torch.nn.Module):
         rotary_ndims = int(self.head_size * rotary_pct)
         self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
-        dtype = weights.dtype
-        weights.dtype = torch.float32
-        self.rotary_emb.inv_freq = nn.Parameter(
-            weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
-        )
-        weights.dtype = dtype
         self.softmax_scale = self.head_size ** (-0.5)
         self.query_key_value = load_qkv(
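
For context: the removed block implemented a dtype-override pattern, temporarily forcing the weight loader to float32 so the rotary inverse frequencies are read at full precision rather than truncated to the model dtype, then restoring the original dtype. Below is a minimal, self-contained sketch of that pattern; the `Weights` class here is a hypothetical stand-in for the server's loader (only its mutable `dtype` attribute matters), not the actual implementation.

import torch

class Weights:
    # Hypothetical stand-in for the server's weight loader.
    def __init__(self, tensors, dtype):
        self._tensors = tensors
        self.dtype = dtype

    def get_tensor(self, name):
        # Cast to the loader's currently configured dtype.
        return self._tensors[name].to(self.dtype)

def load_inv_freq_fp32(weights, prefix):
    # The pattern from the removed block: force float32 for this one
    # read, then restore whatever dtype the loader had before.
    dtype = weights.dtype
    weights.dtype = torch.float32
    try:
        return weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
    finally:
        weights.dtype = dtype

# Usage: an inv_freq stored in float16 comes back as float32.
w = Weights(
    {"layers.0.attention.rotary_emb.inv_freq": torch.ones(32, dtype=torch.float16)},
    dtype=torch.float16,
)
assert load_inv_freq_fp32(w, "layers.0.attention").dtype == torch.float32

Two likely reasons the block was removed: it duplicates precision handling that `PositionRotaryEmbedding.load` can perform internally, and it re-assigned `self.rotary_emb.inv_freq` as an `nn.Parameter`, turning what is normally a non-trainable buffer into a registered parameter. Both points are inferred from the diff, not stated in the commit message.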