mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Supporting code llama.
This commit is contained in:
parent
c4422e5678
commit
3db59bfd00
@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
|
|||||||
pretraining_tp=1,
|
pretraining_tp=1,
|
||||||
tie_word_embeddings=False,
|
tie_word_embeddings=False,
|
||||||
rope_scaling=None,
|
rope_scaling=None,
|
||||||
|
rope_theta=10000.0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
|
|||||||
self.pretraining_tp = pretraining_tp
|
self.pretraining_tp = pretraining_tp
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.rope_scaling = rope_scaling
|
self.rope_scaling = rope_scaling
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
@ -189,7 +191,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
|
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
|
||||||
# )
|
# )
|
||||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
|
||||||
)
|
)
|
||||||
|
|
||||||
self.softmax_scale = self.head_size**-0.5
|
self.softmax_scale = self.head_size**-0.5
|
||||||
|
Loading…
Reference in New Issue
Block a user