diff --git a/.gitignore b/.gitignore
index 19604d42..4f8f7b87 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 .idea
 target
 router/tokenizer.json
+.*__pycache__.*
diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py
index ff9b9763..1076126b 100644
--- a/integration-tests/models/test_flash_neox.py
+++ b/integration-tests/models/test_flash_neox.py
@@ -37,8 +37,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
     generated_texts = [r.generated_text for r in responses]

     assert len(generated_texts) == 4
-    assert generated_texts, all(
+    assert all(
         [text == generated_texts[0] for text in generated_texts]
-    )
+    ), generated_texts

     assert responses == response_snapshot
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index d60fb848..64bd3a40 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -94,9 +94,6 @@ class FlashNeoxAttention(torch.nn.Module):
         rotary_ndims = int(self.head_size * rotary_pct)
         self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
-        self.rotary_emb.inv_freq = nn.Parameter(
-            weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
-        )

         self.softmax_scale = self.head_size ** (-0.5)

         self.query_key_value = load_qkv(
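
Note on the test change: the old line `assert generated_texts, all(...)` only asserted that the list was non-empty and passed the boolean result of `all(...)` as the failure message, so a mismatch between responses could never fail the test. The fix asserts the `all(...)` condition itself and attaches `generated_texts` as the message, so pytest prints the actual outputs on failure. A minimal standalone sketch of the corrected idiom (the sample values are hypothetical):

    # Condition first, diagnostic payload second: if any text differs,
    # pytest's failure output includes the full list of generated texts.
    generated_texts = ["hello", "hello", "hello", "hello"]  # hypothetical outputs

    assert len(generated_texts) == 4
    assert all(
        text == generated_texts[0] for text in generated_texts
    ), generated_texts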
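
Note on the modeling change: the deleted lines overwrote the rotary embedding's inverse frequencies with the `rotary_emb.inv_freq` tensor stored in the checkpoint; after the change, `PositionRotaryEmbedding` presumably derives them from `base` and the rotary dimension alone. A hedged sketch of the standard RoPE inverse-frequency computation, assuming the module follows the usual formulation (the function name here is illustrative, not the library's API):

    import torch

    def rope_inv_freq(rotary_ndims: int, base: float = 10000.0) -> torch.Tensor:
        # Standard RoPE inverse frequencies, one per pair of rotated dims:
        #   inv_freq[i] = 1 / base^(2i / rotary_ndims)
        return 1.0 / (
            base ** (torch.arange(0, rotary_ndims, 2, dtype=torch.float32) / rotary_ndims)
        )

Since this tensor is fully determined by `base` and `rotary_ndims`, recomputing it avoids depending on checkpoints shipping an `inv_freq` tensor at all.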