diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 2610d1ca..c2aba160 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -217,7 +217,7 @@ jobs: run: | export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }} export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} - make integration-tests + pytest -s -vv integration-tests stop-runner: name: Stop self-hosted EC2 runner diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 521c9a0a..e9c51c37 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -7,6 +7,7 @@ import docker from docker.errors import NotFound from typing import Optional, List +from syrupy.filters import props from text_generation import AsyncClient from text_generation.types import Response @@ -16,6 +17,11 @@ HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") +@pytest.fixture +def snapshot_test(snapshot): + return lambda value: value == snapshot(exclude=props("logprob")) + + @pytest.fixture(scope="module") def event_loop(): loop = asyncio.get_event_loop() @@ -135,6 +141,6 @@ def generate_load(): ] results = await asyncio.gather(*futures) - return [r.generated_text for r in results] + return [r.dict() for r in results] return generate_load_inner diff --git a/integration-tests/models/__snapshots__/test_bloom_560m.ambr b/integration-tests/models/__snapshots__/test_bloom_560m.ambr index 9aa212e8..1067513d 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m.ambr +++ b/integration-tests/models/__snapshots__/test_bloom_560m.ambr @@ -8,57 +8,46 @@ 'prefill': list([ dict({ 'id': 17934, - 'logprob': None, 'text': 'Pour', }), dict({ 'id': 49833, - 'logprob': -10.5625, 'text': ' dég', }), dict({ 'id': 21543, - 'logprob': -0.14770508, 'text': 'uster', }), dict({ 'id': 447, - 'logprob': -1.9287109, 'text': ' un', }), dict({ 'id': 46341, - 'logprob': -15.4609375, 'text': ' ort', }), dict({ 'id': 35567, - 'logprob': -7.5585938, 'text': 'olan', }), dict({ 'id': 15, - 'logprob': -1.4003906, 'text': ',', }), dict({ 'id': 1669, - 'logprob': -1.5673828, 'text': ' il', }), dict({ 'id': 11580, - 'logprob': -0.94628906, 'text': ' faut', }), dict({ 'id': 3913, - 'logprob': -3.703125, 'text': ' tout', }), dict({ 'id': 39261, - 'logprob': -1.5732422, 'text': " d'abord", }), ]), @@ -66,61 +55,51 @@ 'tokens': list([ dict({ 'id': 578, - 'logprob': -1.6591797, 'special': False, 'text': ' le', }), dict({ 'id': 5608, - 'logprob': -2.4492188, 'special': False, 'text': ' faire', }), dict({ 'id': 159570, - 'logprob': -6.6835938, 'special': False, 'text': ' réch', }), dict({ 'id': 810, - 'logprob': 0.0, 'special': False, 'text': 'au', }), dict({ 'id': 12736, - 'logprob': 0.0, 'special': False, 'text': 'ffer', }), dict({ 'id': 1742, - 'logprob': -2.5175781, 'special': False, 'text': ' au', }), dict({ 'id': 6105, - 'logprob': -2.0078125, 'special': False, 'text': ' bain', }), dict({ 'id': 88254, - 'logprob': -0.12695312, 'special': False, 'text': '-mar', }), dict({ 'id': 641, - 'logprob': 0.0, 'special': False, 'text': 'ie', }), dict({ 'id': 2940, - 'logprob': -3.5175781, 'special': False, 'text': ' avec', }), @@ -138,27 +117,22 @@ 'prefill': list([ dict({ 'id': 15, - 'logprob': None, 'text': ',', }), dict({ 'id': 1669, - 'logprob': -5.4414062, 'text': ' il', }), dict({ 'id': 11580, - 'logprob': -2.3378906, 'text': ' faut', }), dict({ 'id': 3913, - 'logprob': -4.3554688, 'text': ' tout', }), dict({ 'id': 39261, - 'logprob': -2.9238281, 'text': " d'abord", }), ]), @@ -166,61 +140,51 @@ 'tokens': list([ dict({ 'id': 408, - 'logprob': -1.9267578, 'special': False, 'text': ' que', }), dict({ 'id': 20288, - 'logprob': -2.9257812, 'special': False, 'text': " l'on", }), dict({ 'id': 22255, - 'logprob': -2.8964844, 'special': False, 'text': ' trouve', }), dict({ 'id': 1622, - 'logprob': -1.1083984, 'special': False, 'text': ' une', }), dict({ 'id': 187079, - 'logprob': -7.796875, 'special': False, 'text': ' posture', }), dict({ 'id': 501, - 'logprob': -5.390625, 'special': False, 'text': ' par', }), dict({ 'id': 8741, - 'logprob': -0.34936523, 'special': False, 'text': ' rapport', }), dict({ 'id': 693, - 'logprob': 0.0, 'special': False, 'text': ' à', }), dict({ 'id': 366, - 'logprob': -2.3378906, 'special': False, 'text': ' la', }), dict({ 'id': 36503, - 'logprob': -3.6640625, 'special': False, 'text': ' pratique', }), @@ -231,9 +195,433 @@ # --- # name: test_bloom_560m_load list([ - " le faire cuire dans de l'eau bouillante sal", - " le faire cuire dans de l'eau bouillante sal", - " le faire cuire dans de l'eau bouillante sal", - " le faire cuire dans de l'eau bouillante sal", + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), ]) # --- diff --git a/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr b/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr index 1c842ddc..667a0373 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr +++ b/integration-tests/models/__snapshots__/test_bloom_560m_sharded.ambr @@ -8,57 +8,46 @@ 'prefill': list([ dict({ 'id': 17934, - 'logprob': None, 'text': 'Pour', }), dict({ 'id': 49833, - 'logprob': -10.5390625, 'text': ' dég', }), dict({ 'id': 21543, - 'logprob': -0.14758301, 'text': 'uster', }), dict({ 'id': 447, - 'logprob': -1.9296875, 'text': ' un', }), dict({ 'id': 46341, - 'logprob': -15.4453125, 'text': ' ort', }), dict({ 'id': 35567, - 'logprob': -7.59375, 'text': 'olan', }), dict({ 'id': 15, - 'logprob': -1.3994141, 'text': ',', }), dict({ 'id': 1669, - 'logprob': -1.578125, 'text': ' il', }), dict({ 'id': 11580, - 'logprob': -0.9453125, 'text': ' faut', }), dict({ 'id': 3913, - 'logprob': -3.7011719, 'text': ' tout', }), dict({ 'id': 39261, - 'logprob': -1.5732422, 'text': " d'abord", }), ]), @@ -66,61 +55,51 @@ 'tokens': list([ dict({ 'id': 578, - 'logprob': -1.6474609, 'special': False, 'text': ' le', }), dict({ 'id': 5608, - 'logprob': -2.5097656, 'special': False, 'text': ' faire', }), dict({ 'id': 159570, - 'logprob': -6.65625, 'special': False, 'text': ' réch', }), dict({ 'id': 810, - 'logprob': 0.0, 'special': False, 'text': 'au', }), dict({ 'id': 12736, - 'logprob': 0.0, 'special': False, 'text': 'ffer', }), dict({ 'id': 1742, - 'logprob': -2.5859375, 'special': False, 'text': ' au', }), dict({ 'id': 6105, - 'logprob': -2.03125, 'special': False, 'text': ' bain', }), dict({ 'id': 88254, - 'logprob': -0.12695312, 'special': False, 'text': '-mar', }), dict({ 'id': 641, - 'logprob': 0.0, 'special': False, 'text': 'ie', }), dict({ 'id': 2940, - 'logprob': -3.5175781, 'special': False, 'text': ' avec', }), @@ -131,9 +110,433 @@ # --- # name: test_bloom_560m_sharded_load list([ - " le faire cuire dans de l'eau bouillante sal", - " le faire cuire dans de l'eau bouillante sal", - " le faire cuire dans de l'eau bouillante sal", - " le faire cuire dans de l'eau bouillante sal", + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 17934, + 'text': 'Pour', + }), + dict({ + 'id': 49833, + 'text': ' dég', + }), + dict({ + 'id': 21543, + 'text': 'uster', + }), + dict({ + 'id': 447, + 'text': ' un', + }), + dict({ + 'id': 46341, + 'text': ' ort', + }), + dict({ + 'id': 35567, + 'text': 'olan', + }), + dict({ + 'id': 15, + 'text': ',', + }), + dict({ + 'id': 1669, + 'text': ' il', + }), + dict({ + 'id': 11580, + 'text': ' faut', + }), + dict({ + 'id': 3913, + 'text': ' tout', + }), + dict({ + 'id': 39261, + 'text': " d'abord", + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 578, + 'special': False, + 'text': ' le', + }), + dict({ + 'id': 5608, + 'special': False, + 'text': ' faire', + }), + dict({ + 'id': 1767, + 'special': False, + 'text': ' cu', + }), + dict({ + 'id': 1273, + 'special': False, + 'text': 'ire', + }), + dict({ + 'id': 1486, + 'special': False, + 'text': ' dans', + }), + dict({ + 'id': 283, + 'special': False, + 'text': ' de', + }), + dict({ + 'id': 40410, + 'special': False, + 'text': " l'eau", + }), + dict({ + 'id': 20226, + 'special': False, + 'text': ' bou', + }), + dict({ + 'id': 172483, + 'special': False, + 'text': 'illante', + }), + dict({ + 'id': 2805, + 'special': False, + 'text': ' sal', + }), + ]), + }), + 'generated_text': " le faire cuire dans de l'eau bouillante sal", + }), ]) # --- diff --git a/integration-tests/models/__snapshots__/test_flash_llama.ambr b/integration-tests/models/__snapshots__/test_flash_llama.ambr index 2fde1b01..f4e3a4c1 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama.ambr +++ b/integration-tests/models/__snapshots__/test_flash_llama.ambr @@ -8,17 +8,14 @@ 'prefill': list([ dict({ 'id': 1, - 'logprob': None, 'text': '', }), dict({ 'id': 4321, - 'logprob': -8.6875, 'text': 'Test', }), dict({ 'id': 2009, - 'logprob': -11.5546875, 'text': 'request', }), ]), @@ -26,61 +23,51 @@ 'tokens': list([ dict({ 'id': 363, - 'logprob': -1.5380859, 'special': False, 'text': ' for', }), dict({ 'id': 847, - 'logprob': -2.5917969, 'special': False, 'text': ' /', }), dict({ 'id': 2754, - 'logprob': -2.2773438, 'special': False, 'text': 'api', }), dict({ 'id': 29914, - 'logprob': -0.034362793, 'special': False, 'text': '/', }), dict({ 'id': 29894, - 'logprob': -0.96533203, 'special': False, 'text': 'v', }), dict({ 'id': 29896, - 'logprob': -0.36669922, 'special': False, 'text': '1', }), dict({ 'id': 29914, - 'logprob': -0.013122559, 'special': False, 'text': '/', }), dict({ 'id': 16418, - 'logprob': -3.1503906, 'special': False, 'text': 'projects', }), dict({ 'id': 29914, - 'logprob': -0.43652344, 'special': False, 'text': '/', }), dict({ 'id': 29896, - 'logprob': -1.9404297, 'special': False, 'text': '1', }), @@ -98,17 +85,14 @@ 'prefill': list([ dict({ 'id': 1, - 'logprob': None, 'text': '', }), dict({ 'id': 4321, - 'logprob': -8.6875, 'text': 'Test', }), dict({ 'id': 2009, - 'logprob': -11.5546875, 'text': 'request', }), ]), @@ -116,55 +100,46 @@ 'tokens': list([ dict({ 'id': 5229, - 'logprob': -3.3085938, 'special': False, 'text': ' failed', }), dict({ 'id': 363, - 'logprob': -3.984375, 'special': False, 'text': ' for', }), dict({ 'id': 5641, - 'logprob': -6.53125, 'special': False, 'text': ' IP', }), dict({ 'id': 16428, - 'logprob': -3.1835938, 'special': False, 'text': ' Address', }), dict({ 'id': 29901, - 'logprob': -1.2324219, 'special': False, 'text': ':', }), dict({ 'id': 525, - 'logprob': -2.6855469, 'special': False, 'text': " '", }), dict({ 'id': 8516, - 'logprob': -7.1601562, 'special': False, 'text': 'None', }), dict({ 'id': 4286, - 'logprob': -2.4433594, 'special': False, 'text': "'.", }), dict({ 'id': 13, - 'logprob': -0.06530762, 'special': False, 'text': ''' @@ -173,7 +148,6 @@ }), dict({ 'id': 294, - 'logprob': -7.953125, 'special': False, 'text': 'as', }), @@ -187,9 +161,305 @@ # --- # name: test_flash_llama_load list([ - 'for /api/v1/projects/1', - 'for /api/v1/projects/1', - 'for /api/v1/projects/1', - 'for /api/v1/projects/1', + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'text': '', + }), + dict({ + 'id': 4321, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'text': 'request', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 363, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 847, + 'special': False, + 'text': ' /', + }), + dict({ + 'id': 2754, + 'special': False, + 'text': 'api', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29894, + 'special': False, + 'text': 'v', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 16418, + 'special': False, + 'text': 'projects', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + ]), + }), + 'generated_text': 'for /api/v1/projects/1', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'text': '', + }), + dict({ + 'id': 4321, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'text': 'request', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 363, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 847, + 'special': False, + 'text': ' /', + }), + dict({ + 'id': 2754, + 'special': False, + 'text': 'api', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29894, + 'special': False, + 'text': 'v', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 16418, + 'special': False, + 'text': 'projects', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + ]), + }), + 'generated_text': 'for /api/v1/projects/1', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'text': '', + }), + dict({ + 'id': 4321, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'text': 'request', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 363, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 847, + 'special': False, + 'text': ' /', + }), + dict({ + 'id': 2754, + 'special': False, + 'text': 'api', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29894, + 'special': False, + 'text': 'v', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 16418, + 'special': False, + 'text': 'projects', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + ]), + }), + 'generated_text': 'for /api/v1/projects/1', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1, + 'text': '', + }), + dict({ + 'id': 4321, + 'text': 'Test', + }), + dict({ + 'id': 2009, + 'text': 'request', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 363, + 'special': False, + 'text': ' for', + }), + dict({ + 'id': 847, + 'special': False, + 'text': ' /', + }), + dict({ + 'id': 2754, + 'special': False, + 'text': 'api', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29894, + 'special': False, + 'text': 'v', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 16418, + 'special': False, + 'text': 'projects', + }), + dict({ + 'id': 29914, + 'special': False, + 'text': '/', + }), + dict({ + 'id': 29896, + 'special': False, + 'text': '1', + }), + ]), + }), + 'generated_text': 'for /api/v1/projects/1', + }), ]) # --- diff --git a/integration-tests/models/__snapshots__/test_flash_neox.ambr b/integration-tests/models/__snapshots__/test_flash_neox.ambr index 671a1c0c..4330db6b 100644 --- a/integration-tests/models/__snapshots__/test_flash_neox.ambr +++ b/integration-tests/models/__snapshots__/test_flash_neox.ambr @@ -8,92 +8,74 @@ 'prefill': list([ dict({ 'id': 50278, - 'logprob': None, 'text': '<|prompter|>', }), dict({ 'id': 1276, - 'logprob': -8.03125, 'text': 'What', }), dict({ 'id': 310, - 'logprob': -5.421875, 'text': ' is', }), dict({ 'id': 247, - 'logprob': -2.1601562, 'text': ' a', }), dict({ 'id': 1167, - 'logprob': -5.4609375, 'text': ' mem', }), dict({ 'id': 70, - 'logprob': -0.005657196, 'text': 'e', }), dict({ 'id': 13, - 'logprob': -7.28125, 'text': ',', }), dict({ 'id': 285, - 'logprob': -0.2980957, 'text': ' and', }), dict({ 'id': 752, - 'logprob': -2.1679688, 'text': ' what', }), dict({ 'id': 434, - 'logprob': -5.6210938, 'text': "'s", }), dict({ 'id': 253, - 'logprob': -0.81103516, 'text': ' the', }), dict({ 'id': 2892, - 'logprob': -6.6640625, 'text': ' history', }), dict({ 'id': 3212, - 'logprob': -2.265625, 'text': ' behind', }), dict({ 'id': 436, - 'logprob': -11.5078125, 'text': ' this', }), dict({ 'id': 3159, - 'logprob': -2.1582031, 'text': ' word', }), dict({ 'id': 32, - 'logprob': -0.008720398, 'text': '?', }), dict({ 'id': 0, - 'logprob': -2.4726562, 'text': '<|endoftext|>', }), dict({ 'id': 50281, - 'logprob': -18.265625, 'text': '<|assistant|>', }), ]), @@ -101,61 +83,51 @@ 'tokens': list([ dict({ 'id': 510, - 'logprob': -0.63183594, 'special': False, 'text': 'The', }), dict({ 'id': 3159, - 'logprob': -0.5390625, 'special': False, 'text': ' word', }), dict({ 'id': 346, - 'logprob': -0.045684814, 'special': False, 'text': ' "', }), dict({ 'id': 6441, - 'logprob': -0.002090454, 'special': False, 'text': 'mem', }), dict({ 'id': 70, - 'logprob': -1.3589859e-05, 'special': False, 'text': 'e', }), dict({ 'id': 3, - 'logprob': -0.0009455681, 'special': False, 'text': '"', }), dict({ 'id': 369, - 'logprob': -0.088012695, 'special': False, 'text': ' was', }), dict({ 'id': 806, - 'logprob': -0.12585449, 'special': False, 'text': ' first', }), dict({ 'id': 908, - 'logprob': -0.017196655, 'special': False, 'text': ' used', }), dict({ 'id': 275, - 'logprob': -0.49731445, 'special': False, 'text': ' in', }), @@ -166,9 +138,545 @@ # --- # name: test_flash_neox_load list([ - 'The word "meme" was first used in', - 'The word "meme" was first used in', - 'The word "meme" was first used in', - 'The word "meme" was first used in', + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 50278, + 'text': '<|prompter|>', + }), + dict({ + 'id': 1276, + 'text': 'What', + }), + dict({ + 'id': 310, + 'text': ' is', + }), + dict({ + 'id': 247, + 'text': ' a', + }), + dict({ + 'id': 1167, + 'text': ' mem', + }), + dict({ + 'id': 70, + 'text': 'e', + }), + dict({ + 'id': 13, + 'text': ',', + }), + dict({ + 'id': 285, + 'text': ' and', + }), + dict({ + 'id': 752, + 'text': ' what', + }), + dict({ + 'id': 434, + 'text': "'s", + }), + dict({ + 'id': 253, + 'text': ' the', + }), + dict({ + 'id': 2892, + 'text': ' history', + }), + dict({ + 'id': 3212, + 'text': ' behind', + }), + dict({ + 'id': 436, + 'text': ' this', + }), + dict({ + 'id': 3159, + 'text': ' word', + }), + dict({ + 'id': 32, + 'text': '?', + }), + dict({ + 'id': 0, + 'text': '<|endoftext|>', + }), + dict({ + 'id': 50281, + 'text': '<|assistant|>', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 510, + 'special': False, + 'text': 'The', + }), + dict({ + 'id': 3159, + 'special': False, + 'text': ' word', + }), + dict({ + 'id': 346, + 'special': False, + 'text': ' "', + }), + dict({ + 'id': 6441, + 'special': False, + 'text': 'mem', + }), + dict({ + 'id': 70, + 'special': False, + 'text': 'e', + }), + dict({ + 'id': 3, + 'special': False, + 'text': '"', + }), + dict({ + 'id': 369, + 'special': False, + 'text': ' was', + }), + dict({ + 'id': 806, + 'special': False, + 'text': ' first', + }), + dict({ + 'id': 908, + 'special': False, + 'text': ' used', + }), + dict({ + 'id': 275, + 'special': False, + 'text': ' in', + }), + ]), + }), + 'generated_text': 'The word "meme" was first used in', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 50278, + 'text': '<|prompter|>', + }), + dict({ + 'id': 1276, + 'text': 'What', + }), + dict({ + 'id': 310, + 'text': ' is', + }), + dict({ + 'id': 247, + 'text': ' a', + }), + dict({ + 'id': 1167, + 'text': ' mem', + }), + dict({ + 'id': 70, + 'text': 'e', + }), + dict({ + 'id': 13, + 'text': ',', + }), + dict({ + 'id': 285, + 'text': ' and', + }), + dict({ + 'id': 752, + 'text': ' what', + }), + dict({ + 'id': 434, + 'text': "'s", + }), + dict({ + 'id': 253, + 'text': ' the', + }), + dict({ + 'id': 2892, + 'text': ' history', + }), + dict({ + 'id': 3212, + 'text': ' behind', + }), + dict({ + 'id': 436, + 'text': ' this', + }), + dict({ + 'id': 3159, + 'text': ' word', + }), + dict({ + 'id': 32, + 'text': '?', + }), + dict({ + 'id': 0, + 'text': '<|endoftext|>', + }), + dict({ + 'id': 50281, + 'text': '<|assistant|>', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 510, + 'special': False, + 'text': 'The', + }), + dict({ + 'id': 3159, + 'special': False, + 'text': ' word', + }), + dict({ + 'id': 346, + 'special': False, + 'text': ' "', + }), + dict({ + 'id': 6441, + 'special': False, + 'text': 'mem', + }), + dict({ + 'id': 70, + 'special': False, + 'text': 'e', + }), + dict({ + 'id': 3, + 'special': False, + 'text': '"', + }), + dict({ + 'id': 369, + 'special': False, + 'text': ' was', + }), + dict({ + 'id': 806, + 'special': False, + 'text': ' first', + }), + dict({ + 'id': 908, + 'special': False, + 'text': ' used', + }), + dict({ + 'id': 275, + 'special': False, + 'text': ' in', + }), + ]), + }), + 'generated_text': 'The word "meme" was first used in', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 50278, + 'text': '<|prompter|>', + }), + dict({ + 'id': 1276, + 'text': 'What', + }), + dict({ + 'id': 310, + 'text': ' is', + }), + dict({ + 'id': 247, + 'text': ' a', + }), + dict({ + 'id': 1167, + 'text': ' mem', + }), + dict({ + 'id': 70, + 'text': 'e', + }), + dict({ + 'id': 13, + 'text': ',', + }), + dict({ + 'id': 285, + 'text': ' and', + }), + dict({ + 'id': 752, + 'text': ' what', + }), + dict({ + 'id': 434, + 'text': "'s", + }), + dict({ + 'id': 253, + 'text': ' the', + }), + dict({ + 'id': 2892, + 'text': ' history', + }), + dict({ + 'id': 3212, + 'text': ' behind', + }), + dict({ + 'id': 436, + 'text': ' this', + }), + dict({ + 'id': 3159, + 'text': ' word', + }), + dict({ + 'id': 32, + 'text': '?', + }), + dict({ + 'id': 0, + 'text': '<|endoftext|>', + }), + dict({ + 'id': 50281, + 'text': '<|assistant|>', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 510, + 'special': False, + 'text': 'The', + }), + dict({ + 'id': 3159, + 'special': False, + 'text': ' word', + }), + dict({ + 'id': 346, + 'special': False, + 'text': ' "', + }), + dict({ + 'id': 6441, + 'special': False, + 'text': 'mem', + }), + dict({ + 'id': 70, + 'special': False, + 'text': 'e', + }), + dict({ + 'id': 3, + 'special': False, + 'text': '"', + }), + dict({ + 'id': 369, + 'special': False, + 'text': ' was', + }), + dict({ + 'id': 806, + 'special': False, + 'text': ' first', + }), + dict({ + 'id': 908, + 'special': False, + 'text': ' used', + }), + dict({ + 'id': 275, + 'special': False, + 'text': ' in', + }), + ]), + }), + 'generated_text': 'The word "meme" was first used in', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 50278, + 'text': '<|prompter|>', + }), + dict({ + 'id': 1276, + 'text': 'What', + }), + dict({ + 'id': 310, + 'text': ' is', + }), + dict({ + 'id': 247, + 'text': ' a', + }), + dict({ + 'id': 1167, + 'text': ' mem', + }), + dict({ + 'id': 70, + 'text': 'e', + }), + dict({ + 'id': 13, + 'text': ',', + }), + dict({ + 'id': 285, + 'text': ' and', + }), + dict({ + 'id': 752, + 'text': ' what', + }), + dict({ + 'id': 434, + 'text': "'s", + }), + dict({ + 'id': 253, + 'text': ' the', + }), + dict({ + 'id': 2892, + 'text': ' history', + }), + dict({ + 'id': 3212, + 'text': ' behind', + }), + dict({ + 'id': 436, + 'text': ' this', + }), + dict({ + 'id': 3159, + 'text': ' word', + }), + dict({ + 'id': 32, + 'text': '?', + }), + dict({ + 'id': 0, + 'text': '<|endoftext|>', + }), + dict({ + 'id': 50281, + 'text': '<|assistant|>', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 510, + 'special': False, + 'text': 'The', + }), + dict({ + 'id': 3159, + 'special': False, + 'text': ' word', + }), + dict({ + 'id': 346, + 'special': False, + 'text': ' "', + }), + dict({ + 'id': 6441, + 'special': False, + 'text': 'mem', + }), + dict({ + 'id': 70, + 'special': False, + 'text': 'e', + }), + dict({ + 'id': 3, + 'special': False, + 'text': '"', + }), + dict({ + 'id': 369, + 'special': False, + 'text': ' was', + }), + dict({ + 'id': 806, + 'special': False, + 'text': ' first', + }), + dict({ + 'id': 908, + 'special': False, + 'text': ' used', + }), + dict({ + 'id': 275, + 'special': False, + 'text': ' in', + }), + ]), + }), + 'generated_text': 'The word "meme" was first used in', + }), ]) # --- diff --git a/integration-tests/models/__snapshots__/test_flash_santacoder.ambr b/integration-tests/models/__snapshots__/test_flash_santacoder.ambr index a6a8e599..030820cb 100644 --- a/integration-tests/models/__snapshots__/test_flash_santacoder.ambr +++ b/integration-tests/models/__snapshots__/test_flash_santacoder.ambr @@ -8,22 +8,18 @@ 'prefill': list([ dict({ 'id': 563, - 'logprob': None, 'text': 'def', }), dict({ 'id': 942, - 'logprob': -5.1367188, 'text': ' print', }), dict({ 'id': 62, - 'logprob': -0.24450684, 'text': '_', }), dict({ 'id': 7196, - 'logprob': -6.9609375, 'text': 'hello', }), ]), @@ -31,13 +27,11 @@ 'tokens': list([ dict({ 'id': 1241, - 'logprob': -0.9863281, 'special': False, 'text': '():', }), dict({ 'id': 258, - 'logprob': -0.21447754, 'special': False, 'text': ''' @@ -46,37 +40,31 @@ }), dict({ 'id': 942, - 'logprob': -0.43701172, 'special': False, 'text': ' print', }), dict({ 'id': 372, - 'logprob': -0.5361328, 'special': False, 'text': '("', }), dict({ 'id': 7371, - 'logprob': -0.44555664, 'special': False, 'text': 'Hello', }), dict({ 'id': 9956, - 'logprob': -1.2412109, 'special': False, 'text': ' World', }), dict({ 'id': 8657, - 'logprob': -0.7583008, 'special': False, 'text': '!")', }), dict({ 'id': 185, - 'logprob': -0.76171875, 'special': False, 'text': ''' @@ -85,7 +73,6 @@ }), dict({ 'id': 185, - 'logprob': -0.20837402, 'special': False, 'text': ''' @@ -94,7 +81,6 @@ }), dict({ 'id': 1018, - 'logprob': -1.2470703, 'special': False, 'text': 'print', }), @@ -110,29 +96,377 @@ # --- # name: test_flash_santacoder_load list([ - ''' - (): - print("Hello World!") - - print - ''', - ''' - (): - print("Hello World!") - - print - ''', - ''' - (): - print("Hello World!") - - print - ''', - ''' - (): - print("Hello World!") - - print - ''', + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 563, + 'text': 'def', + }), + dict({ + 'id': 942, + 'text': ' print', + }), + dict({ + 'id': 62, + 'text': '_', + }), + dict({ + 'id': 7196, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 1241, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 258, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 942, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 372, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 7371, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 9956, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 8657, + 'special': False, + 'text': '!")', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1018, + 'special': False, + 'text': 'print', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World!") + + print + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 563, + 'text': 'def', + }), + dict({ + 'id': 942, + 'text': ' print', + }), + dict({ + 'id': 62, + 'text': '_', + }), + dict({ + 'id': 7196, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 1241, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 258, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 942, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 372, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 7371, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 9956, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 8657, + 'special': False, + 'text': '!")', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1018, + 'special': False, + 'text': 'print', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World!") + + print + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 563, + 'text': 'def', + }), + dict({ + 'id': 942, + 'text': ' print', + }), + dict({ + 'id': 62, + 'text': '_', + }), + dict({ + 'id': 7196, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 1241, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 258, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 942, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 372, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 7371, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 9956, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 8657, + 'special': False, + 'text': '!")', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1018, + 'special': False, + 'text': 'print', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World!") + + print + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 563, + 'text': 'def', + }), + dict({ + 'id': 942, + 'text': ' print', + }), + dict({ + 'id': 62, + 'text': '_', + }), + dict({ + 'id': 7196, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 1241, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 258, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 942, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 372, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 7371, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 9956, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 8657, + 'special': False, + 'text': '!")', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 185, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1018, + 'special': False, + 'text': 'print', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World!") + + print + ''', + }), ]) # --- diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder.ambr b/integration-tests/models/__snapshots__/test_flash_starcoder.ambr index 65e14581..e0f4b568 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder.ambr +++ b/integration-tests/models/__snapshots__/test_flash_starcoder.ambr @@ -8,22 +8,18 @@ 'prefill': list([ dict({ 'id': 589, - 'logprob': None, 'text': 'def', }), dict({ 'id': 1459, - 'logprob': -5.6289062, 'text': ' print', }), dict({ 'id': 81, - 'logprob': -1.6005859, 'text': '_', }), dict({ 'id': 7656, - 'logprob': -5.9921875, 'text': 'hello', }), ]), @@ -31,13 +27,11 @@ 'tokens': list([ dict({ 'id': 2262, - 'logprob': -0.7705078, 'special': False, 'text': '():', }), dict({ 'id': 284, - 'logprob': -0.2590332, 'special': False, 'text': ''' @@ -46,37 +40,31 @@ }), dict({ 'id': 1459, - 'logprob': -0.39379883, 'special': False, 'text': ' print', }), dict({ 'id': 440, - 'logprob': -0.61376953, 'special': False, 'text': '("', }), dict({ 'id': 8279, - 'logprob': -0.47338867, 'special': False, 'text': 'Hello', }), dict({ 'id': 10896, - 'logprob': -1.5068359, 'special': False, 'text': ' World', }), dict({ 'id': 657, - 'logprob': -0.80810547, 'special': False, 'text': '")', }), dict({ 'id': 203, - 'logprob': -0.7397461, 'special': False, 'text': ''' @@ -85,7 +73,6 @@ }), dict({ 'id': 203, - 'logprob': -0.35229492, 'special': False, 'text': ''' @@ -94,7 +81,6 @@ }), dict({ 'id': 589, - 'logprob': -1.0371094, 'special': False, 'text': 'def', }), @@ -117,22 +103,18 @@ 'prefill': list([ dict({ 'id': 589, - 'logprob': None, 'text': 'def', }), dict({ 'id': 1459, - 'logprob': -5.6289062, 'text': ' print', }), dict({ 'id': 81, - 'logprob': -1.6005859, 'text': '_', }), dict({ 'id': 7656, - 'logprob': -5.9921875, 'text': 'hello', }), ]), @@ -140,13 +122,11 @@ 'tokens': list([ dict({ 'id': 2262, - 'logprob': -0.7451172, 'special': False, 'text': '():', }), dict({ 'id': 284, - 'logprob': -0.21325684, 'special': False, 'text': ''' @@ -155,55 +135,46 @@ }), dict({ 'id': 5741, - 'logprob': -5.734375, 'special': False, 'text': ' logging', }), dict({ 'id': 32, - 'logprob': 0.0, 'special': False, 'text': '.', }), dict({ 'id': 1338, - 'logprob': -0.3232422, 'special': False, 'text': 'info', }), dict({ 'id': 463, - 'logprob': -1.0380859, 'special': False, 'text': "('", }), dict({ 'id': 8279, - 'logprob': -0.8378906, 'special': False, 'text': 'Hello', }), dict({ 'id': 30, - 'logprob': -1.9501953, 'special': False, 'text': ',', }), dict({ 'id': 10896, - 'logprob': -1.3476562, 'special': False, 'text': ' World', }), dict({ 'id': 683, - 'logprob': -1.796875, 'special': False, 'text': "')", }), dict({ 'id': 203, - 'logprob': -0.9873047, 'special': False, 'text': ''' @@ -212,7 +183,6 @@ }), dict({ 'id': 0, - 'logprob': -0.7495117, 'special': True, 'text': '<|endoftext|>', }), @@ -227,29 +197,377 @@ # --- # name: test_flash_starcoder_load list([ - ''' - (): - print("Hello World") - - def - ''', - ''' - (): - print("Hello World") - - def - ''', - ''' - (): - print("Hello World") - - def - ''', - ''' - (): - print("Hello World") - - def - ''', + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 589, + 'text': 'def', + }), + dict({ + 'id': 1459, + 'text': ' print', + }), + dict({ + 'id': 81, + 'text': '_', + }), + dict({ + 'id': 7656, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 2262, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 284, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1459, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 440, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 8279, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 10896, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 657, + 'special': False, + 'text': '")', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 589, + 'special': False, + 'text': 'def', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World") + + def + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 589, + 'text': 'def', + }), + dict({ + 'id': 1459, + 'text': ' print', + }), + dict({ + 'id': 81, + 'text': '_', + }), + dict({ + 'id': 7656, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 2262, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 284, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1459, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 440, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 8279, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 10896, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 657, + 'special': False, + 'text': '")', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 589, + 'special': False, + 'text': 'def', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World") + + def + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 589, + 'text': 'def', + }), + dict({ + 'id': 1459, + 'text': ' print', + }), + dict({ + 'id': 81, + 'text': '_', + }), + dict({ + 'id': 7656, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 2262, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 284, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1459, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 440, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 8279, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 10896, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 657, + 'special': False, + 'text': '")', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 589, + 'special': False, + 'text': 'def', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World") + + def + ''', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 589, + 'text': 'def', + }), + dict({ + 'id': 1459, + 'text': ' print', + }), + dict({ + 'id': 81, + 'text': '_', + }), + dict({ + 'id': 7656, + 'text': 'hello', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 2262, + 'special': False, + 'text': '():', + }), + dict({ + 'id': 284, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 1459, + 'special': False, + 'text': ' print', + }), + dict({ + 'id': 440, + 'special': False, + 'text': '("', + }), + dict({ + 'id': 8279, + 'special': False, + 'text': 'Hello', + }), + dict({ + 'id': 10896, + 'special': False, + 'text': ' World', + }), + dict({ + 'id': 657, + 'special': False, + 'text': '")', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 203, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 589, + 'special': False, + 'text': 'def', + }), + ]), + }), + 'generated_text': ''' + (): + print("Hello World") + + def + ''', + }), ]) # --- diff --git a/integration-tests/models/__snapshots__/test_mt0_base.ambr b/integration-tests/models/__snapshots__/test_mt0_base.ambr index dc974891..d7c6eaf6 100644 --- a/integration-tests/models/__snapshots__/test_mt0_base.ambr +++ b/integration-tests/models/__snapshots__/test_mt0_base.ambr @@ -8,7 +8,6 @@ 'prefill': list([ dict({ 'id': 0, - 'logprob': None, 'text': '', }), ]), @@ -16,31 +15,26 @@ 'tokens': list([ dict({ 'id': 926, - 'logprob': -4.3554688, 'special': False, 'text': 'To', }), dict({ 'id': 18295, - 'logprob': -7.7734375, 'special': False, 'text': ' sell', }), dict({ 'id': 7868, - 'logprob': -3.9257812, 'special': False, 'text': ' things', }), dict({ 'id': 260, - 'logprob': -2.4179688, 'special': False, 'text': '.', }), dict({ 'id': 1, - 'logprob': 0.0, 'special': True, 'text': '', }), @@ -58,7 +52,6 @@ 'prefill': list([ dict({ 'id': 0, - 'logprob': None, 'text': '', }), ]), @@ -66,61 +59,51 @@ 'tokens': list([ dict({ 'id': 16017, - 'logprob': -1.3505859, 'special': False, 'text': 'blue', }), dict({ 'id': 20495, - 'logprob': -0.50439453, 'special': False, 'text': ' sky', }), dict({ 'id': 259, - 'logprob': -1.2011719, 'special': False, 'text': ' ', }), dict({ 'id': 15484, - 'logprob': -2.8378906, 'special': False, 'text': 'appear', }), dict({ 'id': 345, - 'logprob': -0.87597656, 'special': False, 'text': 'ed', }), dict({ 'id': 288, - 'logprob': -1.8447266, 'special': False, 'text': ' to', }), dict({ 'id': 35622, - 'logprob': -7.1445312, 'special': False, 'text': ' cloud', }), dict({ 'id': 263, - 'logprob': -1.2929688, 'special': False, 'text': 's', }), dict({ 'id': 14701, - 'logprob': -3.0761719, 'special': False, 'text': ' above', }), dict({ 'id': 751, - 'logprob': -4.4375, 'special': False, 'text': ' all', }), @@ -131,9 +114,193 @@ # --- # name: test_mt0_base_load list([ - 'Because it is blue', - 'Because it is blue', - 'Because it is blue', - 'Because it is blue', + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 6, + 'prefill': list([ + dict({ + 'id': 0, + 'text': '', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 259, + 'special': False, + 'text': '', + }), + dict({ + 'id': 39261, + 'special': False, + 'text': 'Because', + }), + dict({ + 'id': 609, + 'special': False, + 'text': ' it', + }), + dict({ + 'id': 339, + 'special': False, + 'text': ' is', + }), + dict({ + 'id': 16017, + 'special': False, + 'text': ' blue', + }), + dict({ + 'id': 1, + 'special': True, + 'text': '', + }), + ]), + }), + 'generated_text': 'Because it is blue', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 6, + 'prefill': list([ + dict({ + 'id': 0, + 'text': '', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 259, + 'special': False, + 'text': '', + }), + dict({ + 'id': 39261, + 'special': False, + 'text': 'Because', + }), + dict({ + 'id': 609, + 'special': False, + 'text': ' it', + }), + dict({ + 'id': 339, + 'special': False, + 'text': ' is', + }), + dict({ + 'id': 16017, + 'special': False, + 'text': ' blue', + }), + dict({ + 'id': 1, + 'special': True, + 'text': '', + }), + ]), + }), + 'generated_text': 'Because it is blue', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 6, + 'prefill': list([ + dict({ + 'id': 0, + 'text': '', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 259, + 'special': False, + 'text': '', + }), + dict({ + 'id': 39261, + 'special': False, + 'text': 'Because', + }), + dict({ + 'id': 609, + 'special': False, + 'text': ' it', + }), + dict({ + 'id': 339, + 'special': False, + 'text': ' is', + }), + dict({ + 'id': 16017, + 'special': False, + 'text': ' blue', + }), + dict({ + 'id': 1, + 'special': True, + 'text': '', + }), + ]), + }), + 'generated_text': 'Because it is blue', + }), + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 6, + 'prefill': list([ + dict({ + 'id': 0, + 'text': '', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 259, + 'special': False, + 'text': '', + }), + dict({ + 'id': 39261, + 'special': False, + 'text': 'Because', + }), + dict({ + 'id': 609, + 'special': False, + 'text': ' it', + }), + dict({ + 'id': 339, + 'special': False, + 'text': ' is', + }), + dict({ + 'id': 16017, + 'special': False, + 'text': ' blue', + }), + dict({ + 'id': 1, + 'special': True, + 'text': '', + }), + ]), + }), + 'generated_text': 'Because it is blue', + }), ]) # --- diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py index 39850cad..e13606f7 100644 --- a/integration-tests/models/test_bloom_560m.py +++ b/integration-tests/models/test_bloom_560m.py @@ -10,7 +10,7 @@ def bloom_560(launcher): @pytest.mark.asyncio -async def test_bloom_560m(bloom_560, snapshot): +async def test_bloom_560m(bloom_560, snapshot_test): await health_check(bloom_560, 60) response = await bloom_560.generate( @@ -21,11 +21,11 @@ async def test_bloom_560m(bloom_560, snapshot): ) assert response.details.generated_tokens == 10 - assert response == snapshot + assert snapshot_test(response) @pytest.mark.asyncio -async def test_bloom_560m_all_params(bloom_560, snapshot): +async def test_bloom_560m_all_params(bloom_560, snapshot_test): await health_check(bloom_560, 60) response = await bloom_560.generate( @@ -44,11 +44,11 @@ async def test_bloom_560m_all_params(bloom_560, snapshot): ) assert response.details.generated_tokens == 10 - assert response == snapshot + assert snapshot_test(response) @pytest.mark.asyncio -async def test_bloom_560m_load(bloom_560, generate_load, snapshot): +async def test_bloom_560m_load(bloom_560, generate_load, snapshot_test): await health_check(bloom_560, 60) responses = await generate_load( @@ -60,4 +60,4 @@ async def test_bloom_560m_load(bloom_560, generate_load, snapshot): assert len(responses) == 4 - assert responses == snapshot + assert snapshot_test(responses) diff --git a/integration-tests/models/test_bloom_560m_sharded.py b/integration-tests/models/test_bloom_560m_sharded.py index 89d95a23..bfb70253 100644 --- a/integration-tests/models/test_bloom_560m_sharded.py +++ b/integration-tests/models/test_bloom_560m_sharded.py @@ -10,7 +10,7 @@ def bloom_560m_sharded(launcher): @pytest.mark.asyncio -async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot): +async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot_test): await health_check(bloom_560m_sharded, 60) response = await bloom_560m_sharded.generate( @@ -21,11 +21,13 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot): ) assert response.details.generated_tokens == 10 - assert response == snapshot + assert snapshot_test(response) @pytest.mark.asyncio -async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapshot): +async def test_bloom_560m_sharded_load( + bloom_560m_sharded, generate_load, snapshot_test +): await health_check(bloom_560m_sharded, 60) responses = await generate_load( @@ -37,4 +39,4 @@ async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapsh assert len(responses) == 4 - assert responses == snapshot + assert snapshot_test(responses) diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py index e1e23cd7..4d1f2bcf 100644 --- a/integration-tests/models/test_flash_llama.py +++ b/integration-tests/models/test_flash_llama.py @@ -11,18 +11,18 @@ def flash_llama(launcher): @pytest.mark.asyncio @pytest.mark.private -async def test_flash_llama(flash_llama, snapshot): +async def test_flash_llama(flash_llama, snapshot_test): await health_check(flash_llama, 120) response = await flash_llama.generate("Test request", max_new_tokens=10) assert response.details.generated_tokens == 10 - assert response == snapshot + assert snapshot_test(response) @pytest.mark.asyncio @pytest.mark.private -async def test_flash_llama_all_params(flash_llama, snapshot): +async def test_flash_llama_all_params(flash_llama, snapshot_test): await health_check(flash_llama, 120) response = await flash_llama.generate( @@ -41,16 +41,16 @@ async def test_flash_llama_all_params(flash_llama, snapshot): ) assert response.details.generated_tokens == 10 - assert response == snapshot + assert snapshot_test(response) @pytest.mark.asyncio @pytest.mark.private -async def test_flash_llama_load(flash_llama, generate_load, snapshot): +async def test_flash_llama_load(flash_llama, generate_load, snapshot_test): await health_check(flash_llama, 120) responses = await generate_load(flash_llama, "Test request", max_new_tokens=10, n=4) assert len(responses) == 4 - assert responses == snapshot + assert snapshot_test(responses) diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py index 42d8182a..8c981028 100644 --- a/integration-tests/models/test_flash_neox.py +++ b/integration-tests/models/test_flash_neox.py @@ -10,7 +10,7 @@ def flash_neox(launcher): @pytest.mark.asyncio -async def test_flash_neox(flash_neox, snapshot): +async def test_flash_neox(flash_neox, snapshot_test): await health_check(flash_neox, 240) response = await flash_neox.generate( @@ -19,11 +19,11 @@ async def test_flash_neox(flash_neox, snapshot): ) assert response.details.generated_tokens == 10 - assert response == snapshot + assert snapshot_test(response) @pytest.mark.asyncio -async def test_flash_neox_load(flash_neox, generate_load, snapshot): +async def test_flash_neox_load(flash_neox, generate_load, snapshot_test): await health_check(flash_neox, 240) responses = await generate_load( @@ -35,4 +35,4 @@ async def test_flash_neox_load(flash_neox, generate_load, snapshot): assert len(responses) == 4 - assert responses == snapshot + assert snapshot_test(responses) diff --git a/integration-tests/models/test_flash_santacoder.py b/integration-tests/models/test_flash_santacoder.py index 8ee44839..64a59d78 100644 --- a/integration-tests/models/test_flash_santacoder.py +++ b/integration-tests/models/test_flash_santacoder.py @@ -10,17 +10,17 @@ def flash_santacoder(launcher): @pytest.mark.asyncio -async def test_flash_santacoder(flash_santacoder, snapshot): +async def test_flash_santacoder(flash_santacoder, snapshot_test): await health_check(flash_santacoder, 60) response = await flash_santacoder.generate("def print_hello", max_new_tokens=10) assert response.details.generated_tokens == 10 - assert response == snapshot + assert snapshot_test(response) @pytest.mark.asyncio -async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot): +async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot_test): await health_check(flash_santacoder, 60) responses = await generate_load( @@ -29,4 +29,4 @@ async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot): assert len(responses) == 4 - assert responses == snapshot + assert snapshot_test(responses) diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py index 52e55296..d43e92dc 100644 --- a/integration-tests/models/test_flash_starcoder.py +++ b/integration-tests/models/test_flash_starcoder.py @@ -11,18 +11,18 @@ def flash_starcoder(launcher): @pytest.mark.asyncio @pytest.mark.private -async def test_flash_starcoder(flash_starcoder, snapshot): +async def test_flash_starcoder(flash_starcoder, snapshot_test): await health_check(flash_starcoder, 240) response = await flash_starcoder.generate("def print_hello", max_new_tokens=10) assert response.details.generated_tokens == 10 - assert response == snapshot + assert snapshot_test(response) @pytest.mark.asyncio @pytest.mark.private -async def test_flash_starcoder_default_params(flash_starcoder, snapshot): +async def test_flash_starcoder_default_params(flash_starcoder, snapshot_test): await health_check(flash_starcoder, 240) response = await flash_starcoder.generate( @@ -30,12 +30,12 @@ async def test_flash_starcoder_default_params(flash_starcoder, snapshot): ) assert response.details.generated_tokens == 12 - assert response == snapshot + assert snapshot_test(response) @pytest.mark.asyncio @pytest.mark.private -async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot): +async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot_test): await health_check(flash_starcoder, 240) responses = await generate_load( @@ -44,4 +44,4 @@ async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot): assert len(responses) == 4 - assert responses == snapshot + assert snapshot_test(responses) diff --git a/integration-tests/models/test_mt0_base.py b/integration-tests/models/test_mt0_base.py index 70ac470a..7310a30f 100644 --- a/integration-tests/models/test_mt0_base.py +++ b/integration-tests/models/test_mt0_base.py @@ -10,7 +10,7 @@ def mt0_base(launcher): @pytest.mark.asyncio -async def test_mt0_base(mt0_base, snapshot): +async def test_mt0_base(mt0_base, snapshot_test): await health_check(mt0_base, 60) response = await mt0_base.generate( @@ -21,11 +21,11 @@ async def test_mt0_base(mt0_base, snapshot): ) assert response.details.generated_tokens == 5 - assert response == snapshot + assert snapshot_test(response) @pytest.mark.asyncio -async def test_mt0_base_all_params(mt0_base, snapshot): +async def test_mt0_base_all_params(mt0_base, snapshot_test): await health_check(mt0_base, 60) response = await mt0_base.generate( @@ -44,11 +44,11 @@ async def test_mt0_base_all_params(mt0_base, snapshot): ) assert response.details.generated_tokens == 10 - assert response == snapshot + assert snapshot_test(response) @pytest.mark.asyncio -async def test_mt0_base_load(mt0_base, generate_load, snapshot): +async def test_mt0_base_load(mt0_base, generate_load, snapshot_test): await health_check(mt0_base, 60) responses = await generate_load( @@ -60,4 +60,4 @@ async def test_mt0_base_load(mt0_base, generate_load, snapshot): assert len(responses) == 4 - assert responses == snapshot + assert snapshot_test(responses)