From 767d21cbf8814c8ac3976fcec95033300a593e89 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:09:21 +0100 Subject: [PATCH] add integration tests --- integration-tests/conftest.py | 3 + .../test_flash_gemma/test_flash_gemma.json | 89 +++++ .../test_flash_gemma_all_params.json | 89 +++++ .../test_flash_gemma_load.json | 358 ++++++++++++++++++ integration-tests/models/test_flash_gemma.py | 58 +++ .../custom_modeling/flash_gemma_modeling.py | 8 +- 6 files changed, 599 insertions(+), 6 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json create mode 100644 integration-tests/models/test_flash_gemma.py diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index e0228894..80457bc2 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -40,6 +40,9 @@ class ResponseComparator(JSONSnapshotExtension): exclude=None, matcher=None, ): + if isinstance(data, Response): + data = data.dict() + if isinstance(data, List): data = [d.dict() for d in data] diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json new file mode 100644 index 00000000..80f0d053 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.8671875, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.4375, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8203125, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23242188, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.08544922, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.9375, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.671875, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.40429688, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.1875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json new file mode 100644 index 00000000..8253dc96 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 7539, + "logprob": -0.73046875, + "special": false, + "text": " forms" + }, + { + "id": 708, + "logprob": 0.0, + "special": false, + "text": " are" + }, + { + "id": 671, + "logprob": -1.703125, + "special": false, + "text": " an" + }, + { + "id": 8727, + "logprob": 0.0, + "special": false, + "text": " essential" + }, + { + "id": 1702, + "logprob": 0.0, + "special": false, + "text": " part" + }, + { + "id": 576, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 573, + "logprob": 0.0, + "special": false, + "text": " the" + }, + { + "id": 11859, + "logprob": -1.6953125, + "special": false, + "text": " lab" + }, + { + "id": 2185, + "logprob": -1.3125, + "special": false, + "text": " process" + }, + { + "id": 578, + "logprob": -1.5, + "special": false, + "text": " and" + } + ], + "top_tokens": null + }, + "generated_text": "Test request forms are an essential part of the lab process and" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json new file mode 100644 index 00000000..e69ee25d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + } +] diff --git a/integration-tests/models/test_flash_gemma.py b/integration-tests/models/test_flash_gemma.py new file mode 100644 index 00000000..94be7de5 --- /dev/null +++ b/integration-tests/models/test_flash_gemma.py @@ -0,0 +1,58 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_gemma_handle(launcher): + with launcher("gg-hf/gemma-2b", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_gemma(flash_gemma_handle): + await flash_gemma_handle.health(300) + return flash_gemma_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma(flash_gemma, response_snapshot): + response = await flash_gemma.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_all_params(flash_gemma, response_snapshot): + response = await flash_gemma.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot): + responses = await generate_load(flash_gemma, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index bb55f5d5..4a08bc2a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -68,8 +68,6 @@ answers should not include any harmful, unethical, racist, sexist, toxic, danger If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ correct. If you don't know the answer to a question, please don't share false information.""" - - # fmt: on @@ -131,8 +129,8 @@ class GemmaTokenizerFast(PreTrainedTokenizerFast): if eos is None and self.add_eos_token: raise ValueError("add_eos_token = True but eos_token = None") - single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}" - pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}" + single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" + pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" special_tokens = [] if self.add_bos_token: @@ -185,12 +183,10 @@ class GemmaTokenizerFast(PreTrainedTokenizerFast): return (out_vocab_file,) @property - # Copied from transformers.models.llama.tokenization_llama.GemmaTokenizer.default_chat_template def default_chat_template(self): raise NotImplementedError # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers - # Copied from transformers.models.llama.tokenization_llama.GemmaTokenizer.build_inputs_with_special_tokens def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): bos_token_id = [self.bos_token_id] if self.add_bos_token else [] eos_token_id = [self.eos_token_id] if self.add_eos_token else []