Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
Fix PositionalRotary loads.
commit 4e071bf2f1 (parent 165bb4b6c0)
@@ -103,7 +103,8 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)

         self.softmax_scale = self.head_size ** (-0.5)

         self.num_heads = self.num_heads // weights.process_group.size()
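For context on this hunk: the old constructor call derived the rotary inverse frequencies from head_size and a fixed base of 10000, while the new PositionRotaryEmbedding.load reads the checkpoint's own rotary_emb.inv_freq tensor through the weights helper. The snippet below is only an illustrative sketch of the conventional RoPE formula that the constructor-based path computes; it is not code from this repository, and rope_inv_freq is a hypothetical name.

import torch

def rope_inv_freq(head_size: int, base: float = 10000.0) -> torch.Tensor:
    # Conventional RoPE inverse frequencies: one frequency per pair of dimensions.
    return 1.0 / (base ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))

# Roughly what PositionRotaryEmbedding(self.head_size, base=10000) used to derive,
# whereas .load() now takes inv_freq straight from the checkpoint.
inv_freq = rope_inv_freq(128)
print(inv_freq.shape, inv_freq.dtype)  # torch.Size([64]) torch.float32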
@@ -93,7 +93,15 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = self.num_heads // weights.process_group.size()

         rotary_ndims = int(self.head_size * rotary_pct)
-        self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
+        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+
+        dtype = weights.dtype
+        weights.dtype = torch.float32
+        self.rotary_emb.inv_freq = nn.Parameter(
+            weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
+        )
+        weights.dtype = dtype
+
         self.softmax_scale = self.head_size ** (-0.5)

         self.query_key_value = load_qkv(
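The added lines show the pattern this commit relies on: flip weights.dtype to float32, fetch the tensor, then restore the original dtype, so inv_freq is never truncated to half precision. The Weights class itself is not part of this diff, so the sketch below uses a stand-in FakeWeights whose get_tensor is assumed to cast to the currently configured dtype; the tensor name is likewise hypothetical.

import torch

class FakeWeights:
    # Stand-in for the repo's weights helper (not shown in this diff): returns
    # tensors cast to whatever dtype is currently configured on the object.
    def __init__(self, tensors, dtype=torch.float16):
        self._tensors = tensors
        self.dtype = dtype

    def get_tensor(self, name):
        return self._tensors[name].to(self.dtype)

weights = FakeWeights({"layers.0.attention.rotary_emb.inv_freq": torch.rand(16)})

# Same round-trip as the added lines: temporarily force float32, then restore.
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor("layers.0.attention.rotary_emb.inv_freq")
weights.dtype = dtype

print(inv_freq.dtype, weights.dtype)  # torch.float32 torch.float16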
@@ -297,7 +297,27 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb

-    class PositionRotaryEmbedding(RotaryEmbedding):
+    class PositionRotaryEmbedding(nn.Module):
+        def __init__(self, inv_freq):
+            super().__init__()
+
+            self.register_buffer("inv_freq", inv_freq)
+            self._seq_len_cached = 0
+            self._cos_cached = None
+            self._sin_cached = None
+            self._cos_k_cached = None
+            self._sin_k_cached = None
+
+        @staticmethod
+        def load(prefix, weights):
+            # XXX: Always load this in float32 !
+            dtype = weights.dtype
+            weights.dtype = torch.float32
+            inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
+            weights.dtype = dtype
+            return PositionRotaryEmbedding(inv_freq)
+
+
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
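The hunk is cut off inside _update_cos_sin_cache, so only its first two comment lines are visible. As a rough, generic sketch of how such a cache is typically rebuilt from inv_freq in RoPE implementations (not the exact body from this file):

import torch

def _update_cos_sin_cache(self, dtype, device, seqlen):
    # Rebuild the tables only when the cached length is too short or the cache
    # lives on the wrong device/dtype.
    if (
        seqlen > self._seq_len_cached
        or self._cos_cached is None
        or self._cos_cached.device != device
        or self._cos_cached.dtype != dtype
    ):
        self._seq_len_cached = seqlen
        t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
        # Outer product of positions and inverse frequencies: [seqlen, rotary_dim // 2]
        freqs = torch.outer(t, self.inv_freq.to(device))
        self._cos_cached = torch.cos(freqs).to(dtype)
        self._sin_cached = torch.sin(freqs).to(dtype)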