From 742ef9b8e57d24a3a7778d94d1dbd7dc2321ddd6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Fri, 24 May 2024 15:34:42 +0000
Subject: [PATCH] Fix (flash) Gemma prefix and enable tests

---
 integration-tests/models/test_flash_gemma.py         | 5 +----
 .../models/custom_modeling/flash_gemma_modeling.py   | 2 +-
 server/text_generation_server/models/flash_gemma.py  | 2 +-
 3 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/integration-tests/models/test_flash_gemma.py b/integration-tests/models/test_flash_gemma.py
index 2822b5e2..7ab43111 100644
--- a/integration-tests/models/test_flash_gemma.py
+++ b/integration-tests/models/test_flash_gemma.py
@@ -3,7 +3,7 @@ import pytest
 
 @pytest.fixture(scope="module")
 def flash_gemma_handle(launcher):
-    with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
+    with launcher("google/gemma-2b", num_shard=1) as handle:
         yield handle
 
 
@@ -13,7 +13,6 @@ async def flash_gemma(flash_gemma_handle):
     return flash_gemma_handle.client
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma(flash_gemma, response_snapshot):
@@ -25,7 +24,6 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
     assert response == response_snapshot
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
@@ -49,7 +47,6 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
     assert response == response_snapshot
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index ac6fd0e6..cff4b5d5 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -423,7 +423,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         super().__init__()
 
         embed_norm = config.hidden_size**0.5
-        if prefix is None:
+        if not prefix:
             prefix = "model"
         else:
             prefix = f"{prefix}.model"
diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py
index 53bfd064..358883e6 100644
--- a/server/text_generation_server/models/flash_gemma.py
+++ b/server/text_generation_server/models/flash_gemma.py
@@ -57,7 +57,7 @@ class FlashGemma(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         # TODO hardcoded
-        prefix = "language_model"
+        prefix = ""
         model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
 
         torch.distributed.barrier(group=self.process_group)
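
Reviewer note (not part of the commit): the modeling change makes `FlashGemmaForCausalLM` treat an empty prefix the same as `None`, so the standalone `FlashGemma` model can now pass `prefix = ""` and still load weights under the `model.*` namespace, while a wrapper model passing e.g. `"language_model"` resolves to `language_model.model.*`. A minimal sketch of that resolution, using a hypothetical standalone helper `resolve_prefix` (not in the codebase) that mirrors the `if not prefix` branch from the patch:

```python
# Hypothetical helper mirroring the prefix logic in
# FlashGemmaForCausalLM.__init__ after this patch; not part of the codebase.
def resolve_prefix(prefix):
    # `not prefix` is true for both None and "" (the old check,
    # `prefix is None`, missed the empty string).
    if not prefix:
        return "model"
    return f"{prefix}.model"

assert resolve_prefix(None) == "model"
assert resolve_prefix("") == "model"  # new: FlashGemma now passes ""
assert resolve_prefix("language_model") == "language_model.model"
```

Under this reading, the old hardcoded `prefix = "language_model"` would have made standalone Gemma checkpoints look up `language_model.model.*` weight names that do not exist in them, which is what the empty prefix fixes.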