Fix PositionalRotary loads.

This commit is contained in:
Ubuntu 2023-05-25 09:34:31 +00:00 committed by Nicolas Patry
parent 165bb4b6c0
commit 4e071bf2f1
3 changed files with 32 additions and 3 deletions

View File

@ -103,7 +103,8 @@ class FlashLlamaAttention(torch.nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
self.softmax_scale = self.head_size ** (-0.5)
self.num_heads = self.num_heads // weights.process_group.size()

View File

@ -93,7 +93,15 @@ class FlashNeoxAttention(torch.nn.Module):
self.num_heads = self.num_heads // weights.process_group.size()
rotary_ndims = int(self.head_size * rotary_pct)
self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
dtype = weights.dtype
weights.dtype = torch.float32
self.rotary_emb.inv_freq = nn.Parameter(
weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
)
weights.dtype = dtype
self.softmax_scale = self.head_size ** (-0.5)
self.query_key_value = load_qkv(

View File

@ -297,7 +297,27 @@ try:
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
class PositionRotaryEmbedding(RotaryEmbedding):
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq):
super().__init__()
self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
@staticmethod
def load(prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
return PositionRotaryEmbedding(inv_freq)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)