mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Fix logic.
This commit is contained in:
parent
4e071bf2f1
commit
7fa79f02ca
@ -95,13 +95,6 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
rotary_ndims = int(self.head_size * rotary_pct)
|
||||
self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
|
||||
|
||||
dtype = weights.dtype
|
||||
weights.dtype = torch.float32
|
||||
self.rotary_emb.inv_freq = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
|
||||
)
|
||||
weights.dtype = dtype
|
||||
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
self.query_key_value = load_qkv(
|
||||
|
Loading…
Reference in New Issue
Block a user