mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
Fix logic.
This commit is contained in:
parent
4e071bf2f1
commit
7fa79f02ca
@@ -95,13 +95,6 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
rotary_ndims = int(self.head_size * rotary_pct)
|
rotary_ndims = int(self.head_size * rotary_pct)
|
||||||
self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
|
self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
|
||||||
|
|
||||||
dtype = weights.dtype
|
|
||||||
weights.dtype = torch.float32
|
|
||||||
self.rotary_emb.inv_freq = nn.Parameter(
|
|
||||||
weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
|
|
||||||
)
|
|
||||||
weights.dtype = dtype
|
|
||||||
|
|
||||||
self.softmax_scale = self.head_size ** (-0.5)
|
self.softmax_scale = self.head_size ** (-0.5)
|
||||||
|
|
||||||
self.query_key_value = load_qkv(
|
self.query_key_value = load_qkv(
|
||||||
|
Loading…
Reference in New Issue
Block a user