This commit is contained in:
Nicolas Patry 2023-08-08 10:43:34 +00:00
parent be4d0be8c8
commit fc7221369e

View File

@ -189,7 +189,7 @@ class FlashLlamaAttention(torch.nn.Module):
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
# )
self.rotary_emb = PositionRotaryEmbedding.static(
config=config, dim=config.head_size, base=10000.0, device=weights.device
config=config, dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size**-0.5