Supporting Code Llama.

Nicolas Patry 2023-08-24 18:52:20 +02:00
parent c4422e5678
commit 3db59bfd00


@@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_scaling=None,
+        rope_theta=10000.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
         self.pretraining_tp = pretraining_tp
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
+        self.rope_theta = rope_theta

         super().__init__(
             pad_token_id=pad_token_id,
@@ -189,7 +191,7 @@ class FlashLlamaAttention(torch.nn.Module):
         #     config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         # )
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
         )
         self.softmax_scale = self.head_size**-0.5
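For context, rope_theta is the base of the rotary position embedding (RoPE) frequency progression; Code Llama configs ship a larger value (1,000,000) than Llama 2's default of 10,000, which is why the hard-coded base above becomes configurable. Below is a minimal sketch of how the base enters the computation; it mirrors the standard RoPE formula, not the actual PositionRotaryEmbedding internals, and the head dimension of 128 is only an example.

import torch

def rotary_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE inverse frequencies: base ** (-2i / dim) for i in [0, dim // 2).
    # A larger base stretches the rotation wavelengths, which is what Code Llama
    # relies on for its longer context.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Llama 2 default base vs. the rope_theta found in Code Llama configs.
inv_freq_llama2 = rotary_inv_freq(dim=128, base=10000.0)
inv_freq_code_llama = rotary_inv_freq(dim=128, base=1_000_000.0)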