From be4d0be8c8607da3922691ff45b09359f1c042e2 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 8 Aug 2023 11:59:34 +0200 Subject: [PATCH] Llama change. Reflecting https://github.com/huggingface/transformers/pull/24998 Current status wants to make sure integration tests *are* broken with this. --- .../models/custom_modeling/flash_llama_modeling.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 2c22ea46..99ebd425 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -185,8 +185,11 @@ class FlashLlamaAttention(torch.nn.Module): self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = PositionRotaryEmbedding.load( - config=config, prefix=f"{prefix}.rotary_emb", weights=weights + # self.rotary_emb = PositionRotaryEmbedding.load( + # config=config, prefix=f"{prefix}.rotary_emb", weights=weights + # ) + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, dim=config.head_size, base=10000.0, device=weights.device ) self.softmax_scale = self.head_size**-0.5