Llama change.

Reflecting
https://github.com/huggingface/transformers/pull/24998

The current state of this change is meant to confirm that the integration tests *do* break with it.
Nicolas Patry 2023-08-08 11:59:34 +02:00
parent 1fdc88ee90
commit be4d0be8c8


@@ -185,8 +185,11 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
-        self.rotary_emb = PositionRotaryEmbedding.load(
-            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
+        # self.rotary_emb = PositionRotaryEmbedding.load(
+        #     config=config, prefix=f"{prefix}.rotary_emb", weights=weights
+        # )
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config, dim=config.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size**-0.5
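
For context, switching from `PositionRotaryEmbedding.load(...)` to `PositionRotaryEmbedding.static(...)` means the rotary embedding is no longer read from a `{prefix}.rotary_emb` tensor in the checkpoint; instead its inverse frequencies are recomputed from the head dimension and the base. The sketch below is only illustrative of that "static" construction under those assumptions; `StaticRotary` is a hypothetical helper, not TGI's actual `PositionRotaryEmbedding` class.

    # Illustrative sketch: derive rotary inverse frequencies from (dim, base)
    # instead of loading them from checkpoint weights.
    import torch


    class StaticRotary(torch.nn.Module):
        def __init__(self, dim: int, base: float = 10000.0, device=None):
            super().__init__()
            # inv_freq depends only on dim and base, so no rotary tensor is
            # required in the model's weight files.
            inv_freq = 1.0 / (
                base
                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        def cos_sin(self, position_ids: torch.Tensor):
            # Angles are the outer product of positions and inverse frequencies;
            # their cos/sin are used to rotate the query/key halves.
            freqs = torch.outer(position_ids.float(), self.inv_freq)
            return torch.cos(freqs), torch.sin(freqs)

This mirrors the `PositionRotaryEmbedding.static(config=..., dim=..., base=10000.0, device=...)` call in the diff, which is why the old `.load(...)` path that pulled `{prefix}.rotary_emb` from the weights can be commented out.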