Mirror of https://github.com/huggingface/text-generation-inference.git
Fix PositionalRotary loads.
commit 4e071bf2f1
parent 165bb4b6c0
@@ -103,7 +103,8 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)

         self.softmax_scale = self.head_size ** (-0.5)

         self.num_heads = self.num_heads // weights.process_group.size()
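Side note (not part of the diff): the old constructor built the inverse frequencies from head_size and a hard-coded base of 10000, while the new code loads them from the checkpoint's rotary_emb.inv_freq buffer; the two should agree, since that buffer normally holds the standard RoPE frequencies. A minimal sketch of the reference computation, assuming the usual RoPE definition (reference_inv_freq is a hypothetical helper, not part of the repository):

import torch

def reference_inv_freq(head_size: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE inverse frequencies: 1 / base^(2i / d) for i = 0 .. d/2 - 1,
    # computed in float32 to avoid precision loss.
    return 1.0 / (
        base ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size)
    )

# For a 128-dim head this yields a (64,) float32 tensor; a checkpointed
# rotary_emb.inv_freq loaded via PositionRotaryEmbedding.load is expected to
# match it up to serialization precision.
print(reference_inv_freq(128).shape)  # torch.Size([64])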
@@ -93,7 +93,15 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = self.num_heads // weights.process_group.size()

         rotary_ndims = int(self.head_size * rotary_pct)
-        self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
+        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+
+        dtype = weights.dtype
+        weights.dtype = torch.float32
+        self.rotary_emb.inv_freq = nn.Parameter(
+            weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
+        )
+        weights.dtype = dtype
+
         self.softmax_scale = self.head_size ** (-0.5)

         self.query_key_value = load_qkv(
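The FlashNeoxAttention change also shows the pattern for forcing a single tensor to load in full precision: the loader's target dtype is switched to float32 around the get_tensor call and then restored. A minimal, self-contained sketch of that pattern, using a hypothetical TinyWeights stand-in (the real Weights helper lives in the server package and does more; the diff only relies on get_tensor honoring weights.dtype):

import torch

class TinyWeights:
    # Hypothetical stand-in for the loader used in the diff: get_tensor()
    # casts every tensor to the currently configured target dtype.
    def __init__(self, tensors, dtype):
        self._tensors = tensors
        self.dtype = dtype

    def get_tensor(self, name):
        return self._tensors[name].to(self.dtype)

weights = TinyWeights({"rotary_emb.inv_freq": torch.rand(64)}, dtype=torch.float16)

# Temporarily force float32 so the buffer keeps full precision,
# then restore the original dtype for every other weight.
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor("rotary_emb.inv_freq")
weights.dtype = dtype

assert inv_freq.dtype == torch.float32 and weights.dtype == torch.float16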
@@ -297,7 +297,27 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb

-    class PositionRotaryEmbedding(RotaryEmbedding):
+    class PositionRotaryEmbedding(nn.Module):
+        def __init__(self, inv_freq):
+            super().__init__()
+
+            self.register_buffer("inv_freq", inv_freq)
+            self._seq_len_cached = 0
+            self._cos_cached = None
+            self._sin_cached = None
+            self._cos_k_cached = None
+            self._sin_k_cached = None
+
+        @staticmethod
+        def load(prefix, weights):
+            # XXX: Always load this in float32 !
+            dtype = weights.dtype
+            weights.dtype = torch.float32
+            inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
+            weights.dtype = dtype
+            return PositionRotaryEmbedding(inv_freq)
+
+
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
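The last hunk is cut off at _update_cos_sin_cache. For context, here is a sketch of how such a cos/sin cache is typically refreshed from inv_freq (modelled on the flash-attn rotary cache; the exact body is not shown in this diff, so treat the details as an assumption):

import torch

def _update_cos_sin_cache(self, dtype, device, seqlen):
    # Reset the tables if the sequence length has changed,
    # or if we're on a new device (possibly due to tracing for instance)
    if (
        self._cos_cached is None
        or seqlen > self._seq_len_cached
        or self._cos_cached.device != device
        or self._cos_cached.dtype != dtype
    ):
        self._seq_len_cached = seqlen
        # Positions are taken in the dtype of inv_freq (float32 thanks to load()),
        # and only the final cos/sin tables are cast to the runtime dtype.
        t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq.to(device=t.device))
        self._cos_cached = torch.cos(freqs).to(dtype)
        self._sin_cached = torch.sin(freqs).to(dtype)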