From e605c2a43e693844cb2c5ba879f41392faf64793 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 24 Aug 2023 18:54:47 +0200
Subject: [PATCH] Supporting code llama. (#918)

# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 .../models/custom_modeling/flash_llama_modeling.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index d0185ede..f0e1236d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_scaling=None,
+        rope_theta=10000.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
         self.pretraining_tp = pretraining_tp
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
+        self.rope_theta = rope_theta
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -189,7 +191,7 @@ class FlashLlamaAttention(torch.nn.Module):
         #     config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         # )
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
         )
         self.softmax_scale = self.head_size**-0.5
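
For context (not part of the patch itself): the patch makes the rotary base configurable because Code Llama checkpoints ship a much larger `rope_theta` (1e6) than Llama's default of 1e4, and hard-coding `base=10000.0` would load them with the wrong position frequencies. The sketch below, assuming the standard RoPE inverse-frequency formulation, shows how the base enters the computation; `rope_inv_freq` is a hypothetical helper for illustration only, since in the server this is handled inside `PositionRotaryEmbedding.static`.

```python
# Illustrative sketch, not part of this patch: how the rotary base ("rope_theta")
# determines the standard RoPE frequencies, inv_freq[i] = 1 / base^(2i / dim).
import torch


def rope_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Even indices 0, 2, ..., dim-2 give one frequency per rotation pair.
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    return 1.0 / (base ** exponents)


# Llama default base vs. the larger base used by Code Llama (rope_theta=1e6);
# a larger base makes low-frequency components rotate more slowly.
llama_freqs = rope_inv_freq(128, base=10000.0)
code_llama_freqs = rope_inv_freq(128, base=1000000.0)
```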