remove logprobs

This commit is contained in:
OlivierDehaene 2023-05-15 23:06:57 +02:00
parent 9fcf03d13c
commit 8d0f8c2c30
16 changed files with 2674 additions and 278 deletions

View File

@ -217,7 +217,7 @@ jobs:
run: | run: |
export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }} export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
make integration-tests pytest -s -vv integration-tests
stop-runner: stop-runner:
name: Stop self-hosted EC2 runner name: Stop self-hosted EC2 runner

View File

@ -7,6 +7,7 @@ import docker
from docker.errors import NotFound from docker.errors import NotFound
from typing import Optional, List from typing import Optional, List
from syrupy.filters import props
from text_generation import AsyncClient from text_generation import AsyncClient
from text_generation.types import Response from text_generation.types import Response
@ -16,6 +17,11 @@ HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
@pytest.fixture
def snapshot_test(snapshot):
return lambda value: value == snapshot(exclude=props("logprob"))
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def event_loop(): def event_loop():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -135,6 +141,6 @@ def generate_load():
] ]
results = await asyncio.gather(*futures) results = await asyncio.gather(*futures)
return [r.generated_text for r in results] return [r.dict() for r in results]
return generate_load_inner return generate_load_inner

View File

@ -8,57 +8,46 @@
'prefill': list([ 'prefill': list([
dict({ dict({
'id': 17934, 'id': 17934,
'logprob': None,
'text': 'Pour', 'text': 'Pour',
}), }),
dict({ dict({
'id': 49833, 'id': 49833,
'logprob': -10.5625,
'text': ' dég', 'text': ' dég',
}), }),
dict({ dict({
'id': 21543, 'id': 21543,
'logprob': -0.14770508,
'text': 'uster', 'text': 'uster',
}), }),
dict({ dict({
'id': 447, 'id': 447,
'logprob': -1.9287109,
'text': ' un', 'text': ' un',
}), }),
dict({ dict({
'id': 46341, 'id': 46341,
'logprob': -15.4609375,
'text': ' ort', 'text': ' ort',
}), }),
dict({ dict({
'id': 35567, 'id': 35567,
'logprob': -7.5585938,
'text': 'olan', 'text': 'olan',
}), }),
dict({ dict({
'id': 15, 'id': 15,
'logprob': -1.4003906,
'text': ',', 'text': ',',
}), }),
dict({ dict({
'id': 1669, 'id': 1669,
'logprob': -1.5673828,
'text': ' il', 'text': ' il',
}), }),
dict({ dict({
'id': 11580, 'id': 11580,
'logprob': -0.94628906,
'text': ' faut', 'text': ' faut',
}), }),
dict({ dict({
'id': 3913, 'id': 3913,
'logprob': -3.703125,
'text': ' tout', 'text': ' tout',
}), }),
dict({ dict({
'id': 39261, 'id': 39261,
'logprob': -1.5732422,
'text': " d'abord", 'text': " d'abord",
}), }),
]), ]),
@ -66,61 +55,51 @@
'tokens': list([ 'tokens': list([
dict({ dict({
'id': 578, 'id': 578,
'logprob': -1.6591797,
'special': False, 'special': False,
'text': ' le', 'text': ' le',
}), }),
dict({ dict({
'id': 5608, 'id': 5608,
'logprob': -2.4492188,
'special': False, 'special': False,
'text': ' faire', 'text': ' faire',
}), }),
dict({ dict({
'id': 159570, 'id': 159570,
'logprob': -6.6835938,
'special': False, 'special': False,
'text': ' réch', 'text': ' réch',
}), }),
dict({ dict({
'id': 810, 'id': 810,
'logprob': 0.0,
'special': False, 'special': False,
'text': 'au', 'text': 'au',
}), }),
dict({ dict({
'id': 12736, 'id': 12736,
'logprob': 0.0,
'special': False, 'special': False,
'text': 'ffer', 'text': 'ffer',
}), }),
dict({ dict({
'id': 1742, 'id': 1742,
'logprob': -2.5175781,
'special': False, 'special': False,
'text': ' au', 'text': ' au',
}), }),
dict({ dict({
'id': 6105, 'id': 6105,
'logprob': -2.0078125,
'special': False, 'special': False,
'text': ' bain', 'text': ' bain',
}), }),
dict({ dict({
'id': 88254, 'id': 88254,
'logprob': -0.12695312,
'special': False, 'special': False,
'text': '-mar', 'text': '-mar',
}), }),
dict({ dict({
'id': 641, 'id': 641,
'logprob': 0.0,
'special': False, 'special': False,
'text': 'ie', 'text': 'ie',
}), }),
dict({ dict({
'id': 2940, 'id': 2940,
'logprob': -3.5175781,
'special': False, 'special': False,
'text': ' avec', 'text': ' avec',
}), }),
@ -138,27 +117,22 @@
'prefill': list([ 'prefill': list([
dict({ dict({
'id': 15, 'id': 15,
'logprob': None,
'text': ',', 'text': ',',
}), }),
dict({ dict({
'id': 1669, 'id': 1669,
'logprob': -5.4414062,
'text': ' il', 'text': ' il',
}), }),
dict({ dict({
'id': 11580, 'id': 11580,
'logprob': -2.3378906,
'text': ' faut', 'text': ' faut',
}), }),
dict({ dict({
'id': 3913, 'id': 3913,
'logprob': -4.3554688,
'text': ' tout', 'text': ' tout',
}), }),
dict({ dict({
'id': 39261, 'id': 39261,
'logprob': -2.9238281,
'text': " d'abord", 'text': " d'abord",
}), }),
]), ]),
@ -166,61 +140,51 @@
'tokens': list([ 'tokens': list([
dict({ dict({
'id': 408, 'id': 408,
'logprob': -1.9267578,
'special': False, 'special': False,
'text': ' que', 'text': ' que',
}), }),
dict({ dict({
'id': 20288, 'id': 20288,
'logprob': -2.9257812,
'special': False, 'special': False,
'text': " l'on", 'text': " l'on",
}), }),
dict({ dict({
'id': 22255, 'id': 22255,
'logprob': -2.8964844,
'special': False, 'special': False,
'text': ' trouve', 'text': ' trouve',
}), }),
dict({ dict({
'id': 1622, 'id': 1622,
'logprob': -1.1083984,
'special': False, 'special': False,
'text': ' une', 'text': ' une',
}), }),
dict({ dict({
'id': 187079, 'id': 187079,
'logprob': -7.796875,
'special': False, 'special': False,
'text': ' posture', 'text': ' posture',
}), }),
dict({ dict({
'id': 501, 'id': 501,
'logprob': -5.390625,
'special': False, 'special': False,
'text': ' par', 'text': ' par',
}), }),
dict({ dict({
'id': 8741, 'id': 8741,
'logprob': -0.34936523,
'special': False, 'special': False,
'text': ' rapport', 'text': ' rapport',
}), }),
dict({ dict({
'id': 693, 'id': 693,
'logprob': 0.0,
'special': False, 'special': False,
'text': ' à', 'text': ' à',
}), }),
dict({ dict({
'id': 366, 'id': 366,
'logprob': -2.3378906,
'special': False, 'special': False,
'text': ' la', 'text': ' la',
}), }),
dict({ dict({
'id': 36503, 'id': 36503,
'logprob': -3.6640625,
'special': False, 'special': False,
'text': ' pratique', 'text': ' pratique',
}), }),
@ -231,9 +195,433 @@
# --- # ---
# name: test_bloom_560m_load # name: test_bloom_560m_load
list([ list([
" le faire cuire dans de l'eau bouillante sal", dict({
" le faire cuire dans de l'eau bouillante sal", 'details': dict({
" le faire cuire dans de l'eau bouillante sal", 'best_of_sequences': None,
" le faire cuire dans de l'eau bouillante sal", 'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 17934,
'text': 'Pour',
}),
dict({
'id': 49833,
'text': ' dég',
}),
dict({
'id': 21543,
'text': 'uster',
}),
dict({
'id': 447,
'text': ' un',
}),
dict({
'id': 46341,
'text': ' ort',
}),
dict({
'id': 35567,
'text': 'olan',
}),
dict({
'id': 15,
'text': ',',
}),
dict({
'id': 1669,
'text': ' il',
}),
dict({
'id': 11580,
'text': ' faut',
}),
dict({
'id': 3913,
'text': ' tout',
}),
dict({
'id': 39261,
'text': " d'abord",
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 578,
'special': False,
'text': ' le',
}),
dict({
'id': 5608,
'special': False,
'text': ' faire',
}),
dict({
'id': 1767,
'special': False,
'text': ' cu',
}),
dict({
'id': 1273,
'special': False,
'text': 'ire',
}),
dict({
'id': 1486,
'special': False,
'text': ' dans',
}),
dict({
'id': 283,
'special': False,
'text': ' de',
}),
dict({
'id': 40410,
'special': False,
'text': " l'eau",
}),
dict({
'id': 20226,
'special': False,
'text': ' bou',
}),
dict({
'id': 172483,
'special': False,
'text': 'illante',
}),
dict({
'id': 2805,
'special': False,
'text': ' sal',
}),
]),
}),
'generated_text': " le faire cuire dans de l'eau bouillante sal",
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 17934,
'text': 'Pour',
}),
dict({
'id': 49833,
'text': ' dég',
}),
dict({
'id': 21543,
'text': 'uster',
}),
dict({
'id': 447,
'text': ' un',
}),
dict({
'id': 46341,
'text': ' ort',
}),
dict({
'id': 35567,
'text': 'olan',
}),
dict({
'id': 15,
'text': ',',
}),
dict({
'id': 1669,
'text': ' il',
}),
dict({
'id': 11580,
'text': ' faut',
}),
dict({
'id': 3913,
'text': ' tout',
}),
dict({
'id': 39261,
'text': " d'abord",
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 578,
'special': False,
'text': ' le',
}),
dict({
'id': 5608,
'special': False,
'text': ' faire',
}),
dict({
'id': 1767,
'special': False,
'text': ' cu',
}),
dict({
'id': 1273,
'special': False,
'text': 'ire',
}),
dict({
'id': 1486,
'special': False,
'text': ' dans',
}),
dict({
'id': 283,
'special': False,
'text': ' de',
}),
dict({
'id': 40410,
'special': False,
'text': " l'eau",
}),
dict({
'id': 20226,
'special': False,
'text': ' bou',
}),
dict({
'id': 172483,
'special': False,
'text': 'illante',
}),
dict({
'id': 2805,
'special': False,
'text': ' sal',
}),
]),
}),
'generated_text': " le faire cuire dans de l'eau bouillante sal",
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 17934,
'text': 'Pour',
}),
dict({
'id': 49833,
'text': ' dég',
}),
dict({
'id': 21543,
'text': 'uster',
}),
dict({
'id': 447,
'text': ' un',
}),
dict({
'id': 46341,
'text': ' ort',
}),
dict({
'id': 35567,
'text': 'olan',
}),
dict({
'id': 15,
'text': ',',
}),
dict({
'id': 1669,
'text': ' il',
}),
dict({
'id': 11580,
'text': ' faut',
}),
dict({
'id': 3913,
'text': ' tout',
}),
dict({
'id': 39261,
'text': " d'abord",
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 578,
'special': False,
'text': ' le',
}),
dict({
'id': 5608,
'special': False,
'text': ' faire',
}),
dict({
'id': 1767,
'special': False,
'text': ' cu',
}),
dict({
'id': 1273,
'special': False,
'text': 'ire',
}),
dict({
'id': 1486,
'special': False,
'text': ' dans',
}),
dict({
'id': 283,
'special': False,
'text': ' de',
}),
dict({
'id': 40410,
'special': False,
'text': " l'eau",
}),
dict({
'id': 20226,
'special': False,
'text': ' bou',
}),
dict({
'id': 172483,
'special': False,
'text': 'illante',
}),
dict({
'id': 2805,
'special': False,
'text': ' sal',
}),
]),
}),
'generated_text': " le faire cuire dans de l'eau bouillante sal",
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 17934,
'text': 'Pour',
}),
dict({
'id': 49833,
'text': ' dég',
}),
dict({
'id': 21543,
'text': 'uster',
}),
dict({
'id': 447,
'text': ' un',
}),
dict({
'id': 46341,
'text': ' ort',
}),
dict({
'id': 35567,
'text': 'olan',
}),
dict({
'id': 15,
'text': ',',
}),
dict({
'id': 1669,
'text': ' il',
}),
dict({
'id': 11580,
'text': ' faut',
}),
dict({
'id': 3913,
'text': ' tout',
}),
dict({
'id': 39261,
'text': " d'abord",
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 578,
'special': False,
'text': ' le',
}),
dict({
'id': 5608,
'special': False,
'text': ' faire',
}),
dict({
'id': 1767,
'special': False,
'text': ' cu',
}),
dict({
'id': 1273,
'special': False,
'text': 'ire',
}),
dict({
'id': 1486,
'special': False,
'text': ' dans',
}),
dict({
'id': 283,
'special': False,
'text': ' de',
}),
dict({
'id': 40410,
'special': False,
'text': " l'eau",
}),
dict({
'id': 20226,
'special': False,
'text': ' bou',
}),
dict({
'id': 172483,
'special': False,
'text': 'illante',
}),
dict({
'id': 2805,
'special': False,
'text': ' sal',
}),
]),
}),
'generated_text': " le faire cuire dans de l'eau bouillante sal",
}),
]) ])
# --- # ---

View File

@ -8,57 +8,46 @@
'prefill': list([ 'prefill': list([
dict({ dict({
'id': 17934, 'id': 17934,
'logprob': None,
'text': 'Pour', 'text': 'Pour',
}), }),
dict({ dict({
'id': 49833, 'id': 49833,
'logprob': -10.5390625,
'text': ' dég', 'text': ' dég',
}), }),
dict({ dict({
'id': 21543, 'id': 21543,
'logprob': -0.14758301,
'text': 'uster', 'text': 'uster',
}), }),
dict({ dict({
'id': 447, 'id': 447,
'logprob': -1.9296875,
'text': ' un', 'text': ' un',
}), }),
dict({ dict({
'id': 46341, 'id': 46341,
'logprob': -15.4453125,
'text': ' ort', 'text': ' ort',
}), }),
dict({ dict({
'id': 35567, 'id': 35567,
'logprob': -7.59375,
'text': 'olan', 'text': 'olan',
}), }),
dict({ dict({
'id': 15, 'id': 15,
'logprob': -1.3994141,
'text': ',', 'text': ',',
}), }),
dict({ dict({
'id': 1669, 'id': 1669,
'logprob': -1.578125,
'text': ' il', 'text': ' il',
}), }),
dict({ dict({
'id': 11580, 'id': 11580,
'logprob': -0.9453125,
'text': ' faut', 'text': ' faut',
}), }),
dict({ dict({
'id': 3913, 'id': 3913,
'logprob': -3.7011719,
'text': ' tout', 'text': ' tout',
}), }),
dict({ dict({
'id': 39261, 'id': 39261,
'logprob': -1.5732422,
'text': " d'abord", 'text': " d'abord",
}), }),
]), ]),
@ -66,61 +55,51 @@
'tokens': list([ 'tokens': list([
dict({ dict({
'id': 578, 'id': 578,
'logprob': -1.6474609,
'special': False, 'special': False,
'text': ' le', 'text': ' le',
}), }),
dict({ dict({
'id': 5608, 'id': 5608,
'logprob': -2.5097656,
'special': False, 'special': False,
'text': ' faire', 'text': ' faire',
}), }),
dict({ dict({
'id': 159570, 'id': 159570,
'logprob': -6.65625,
'special': False, 'special': False,
'text': ' réch', 'text': ' réch',
}), }),
dict({ dict({
'id': 810, 'id': 810,
'logprob': 0.0,
'special': False, 'special': False,
'text': 'au', 'text': 'au',
}), }),
dict({ dict({
'id': 12736, 'id': 12736,
'logprob': 0.0,
'special': False, 'special': False,
'text': 'ffer', 'text': 'ffer',
}), }),
dict({ dict({
'id': 1742, 'id': 1742,
'logprob': -2.5859375,
'special': False, 'special': False,
'text': ' au', 'text': ' au',
}), }),
dict({ dict({
'id': 6105, 'id': 6105,
'logprob': -2.03125,
'special': False, 'special': False,
'text': ' bain', 'text': ' bain',
}), }),
dict({ dict({
'id': 88254, 'id': 88254,
'logprob': -0.12695312,
'special': False, 'special': False,
'text': '-mar', 'text': '-mar',
}), }),
dict({ dict({
'id': 641, 'id': 641,
'logprob': 0.0,
'special': False, 'special': False,
'text': 'ie', 'text': 'ie',
}), }),
dict({ dict({
'id': 2940, 'id': 2940,
'logprob': -3.5175781,
'special': False, 'special': False,
'text': ' avec', 'text': ' avec',
}), }),
@ -131,9 +110,433 @@
# --- # ---
# name: test_bloom_560m_sharded_load # name: test_bloom_560m_sharded_load
list([ list([
" le faire cuire dans de l'eau bouillante sal", dict({
" le faire cuire dans de l'eau bouillante sal", 'details': dict({
" le faire cuire dans de l'eau bouillante sal", 'best_of_sequences': None,
" le faire cuire dans de l'eau bouillante sal", 'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 17934,
'text': 'Pour',
}),
dict({
'id': 49833,
'text': ' dég',
}),
dict({
'id': 21543,
'text': 'uster',
}),
dict({
'id': 447,
'text': ' un',
}),
dict({
'id': 46341,
'text': ' ort',
}),
dict({
'id': 35567,
'text': 'olan',
}),
dict({
'id': 15,
'text': ',',
}),
dict({
'id': 1669,
'text': ' il',
}),
dict({
'id': 11580,
'text': ' faut',
}),
dict({
'id': 3913,
'text': ' tout',
}),
dict({
'id': 39261,
'text': " d'abord",
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 578,
'special': False,
'text': ' le',
}),
dict({
'id': 5608,
'special': False,
'text': ' faire',
}),
dict({
'id': 1767,
'special': False,
'text': ' cu',
}),
dict({
'id': 1273,
'special': False,
'text': 'ire',
}),
dict({
'id': 1486,
'special': False,
'text': ' dans',
}),
dict({
'id': 283,
'special': False,
'text': ' de',
}),
dict({
'id': 40410,
'special': False,
'text': " l'eau",
}),
dict({
'id': 20226,
'special': False,
'text': ' bou',
}),
dict({
'id': 172483,
'special': False,
'text': 'illante',
}),
dict({
'id': 2805,
'special': False,
'text': ' sal',
}),
]),
}),
'generated_text': " le faire cuire dans de l'eau bouillante sal",
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 17934,
'text': 'Pour',
}),
dict({
'id': 49833,
'text': ' dég',
}),
dict({
'id': 21543,
'text': 'uster',
}),
dict({
'id': 447,
'text': ' un',
}),
dict({
'id': 46341,
'text': ' ort',
}),
dict({
'id': 35567,
'text': 'olan',
}),
dict({
'id': 15,
'text': ',',
}),
dict({
'id': 1669,
'text': ' il',
}),
dict({
'id': 11580,
'text': ' faut',
}),
dict({
'id': 3913,
'text': ' tout',
}),
dict({
'id': 39261,
'text': " d'abord",
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 578,
'special': False,
'text': ' le',
}),
dict({
'id': 5608,
'special': False,
'text': ' faire',
}),
dict({
'id': 1767,
'special': False,
'text': ' cu',
}),
dict({
'id': 1273,
'special': False,
'text': 'ire',
}),
dict({
'id': 1486,
'special': False,
'text': ' dans',
}),
dict({
'id': 283,
'special': False,
'text': ' de',
}),
dict({
'id': 40410,
'special': False,
'text': " l'eau",
}),
dict({
'id': 20226,
'special': False,
'text': ' bou',
}),
dict({
'id': 172483,
'special': False,
'text': 'illante',
}),
dict({
'id': 2805,
'special': False,
'text': ' sal',
}),
]),
}),
'generated_text': " le faire cuire dans de l'eau bouillante sal",
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 17934,
'text': 'Pour',
}),
dict({
'id': 49833,
'text': ' dég',
}),
dict({
'id': 21543,
'text': 'uster',
}),
dict({
'id': 447,
'text': ' un',
}),
dict({
'id': 46341,
'text': ' ort',
}),
dict({
'id': 35567,
'text': 'olan',
}),
dict({
'id': 15,
'text': ',',
}),
dict({
'id': 1669,
'text': ' il',
}),
dict({
'id': 11580,
'text': ' faut',
}),
dict({
'id': 3913,
'text': ' tout',
}),
dict({
'id': 39261,
'text': " d'abord",
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 578,
'special': False,
'text': ' le',
}),
dict({
'id': 5608,
'special': False,
'text': ' faire',
}),
dict({
'id': 1767,
'special': False,
'text': ' cu',
}),
dict({
'id': 1273,
'special': False,
'text': 'ire',
}),
dict({
'id': 1486,
'special': False,
'text': ' dans',
}),
dict({
'id': 283,
'special': False,
'text': ' de',
}),
dict({
'id': 40410,
'special': False,
'text': " l'eau",
}),
dict({
'id': 20226,
'special': False,
'text': ' bou',
}),
dict({
'id': 172483,
'special': False,
'text': 'illante',
}),
dict({
'id': 2805,
'special': False,
'text': ' sal',
}),
]),
}),
'generated_text': " le faire cuire dans de l'eau bouillante sal",
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 17934,
'text': 'Pour',
}),
dict({
'id': 49833,
'text': ' dég',
}),
dict({
'id': 21543,
'text': 'uster',
}),
dict({
'id': 447,
'text': ' un',
}),
dict({
'id': 46341,
'text': ' ort',
}),
dict({
'id': 35567,
'text': 'olan',
}),
dict({
'id': 15,
'text': ',',
}),
dict({
'id': 1669,
'text': ' il',
}),
dict({
'id': 11580,
'text': ' faut',
}),
dict({
'id': 3913,
'text': ' tout',
}),
dict({
'id': 39261,
'text': " d'abord",
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 578,
'special': False,
'text': ' le',
}),
dict({
'id': 5608,
'special': False,
'text': ' faire',
}),
dict({
'id': 1767,
'special': False,
'text': ' cu',
}),
dict({
'id': 1273,
'special': False,
'text': 'ire',
}),
dict({
'id': 1486,
'special': False,
'text': ' dans',
}),
dict({
'id': 283,
'special': False,
'text': ' de',
}),
dict({
'id': 40410,
'special': False,
'text': " l'eau",
}),
dict({
'id': 20226,
'special': False,
'text': ' bou',
}),
dict({
'id': 172483,
'special': False,
'text': 'illante',
}),
dict({
'id': 2805,
'special': False,
'text': ' sal',
}),
]),
}),
'generated_text': " le faire cuire dans de l'eau bouillante sal",
}),
]) ])
# --- # ---

View File

@ -8,17 +8,14 @@
'prefill': list([ 'prefill': list([
dict({ dict({
'id': 1, 'id': 1,
'logprob': None,
'text': '<s>', 'text': '<s>',
}), }),
dict({ dict({
'id': 4321, 'id': 4321,
'logprob': -8.6875,
'text': 'Test', 'text': 'Test',
}), }),
dict({ dict({
'id': 2009, 'id': 2009,
'logprob': -11.5546875,
'text': 'request', 'text': 'request',
}), }),
]), ]),
@ -26,61 +23,51 @@
'tokens': list([ 'tokens': list([
dict({ dict({
'id': 363, 'id': 363,
'logprob': -1.5380859,
'special': False, 'special': False,
'text': ' for', 'text': ' for',
}), }),
dict({ dict({
'id': 847, 'id': 847,
'logprob': -2.5917969,
'special': False, 'special': False,
'text': ' /', 'text': ' /',
}), }),
dict({ dict({
'id': 2754, 'id': 2754,
'logprob': -2.2773438,
'special': False, 'special': False,
'text': 'api', 'text': 'api',
}), }),
dict({ dict({
'id': 29914, 'id': 29914,
'logprob': -0.034362793,
'special': False, 'special': False,
'text': '/', 'text': '/',
}), }),
dict({ dict({
'id': 29894, 'id': 29894,
'logprob': -0.96533203,
'special': False, 'special': False,
'text': 'v', 'text': 'v',
}), }),
dict({ dict({
'id': 29896, 'id': 29896,
'logprob': -0.36669922,
'special': False, 'special': False,
'text': '1', 'text': '1',
}), }),
dict({ dict({
'id': 29914, 'id': 29914,
'logprob': -0.013122559,
'special': False, 'special': False,
'text': '/', 'text': '/',
}), }),
dict({ dict({
'id': 16418, 'id': 16418,
'logprob': -3.1503906,
'special': False, 'special': False,
'text': 'projects', 'text': 'projects',
}), }),
dict({ dict({
'id': 29914, 'id': 29914,
'logprob': -0.43652344,
'special': False, 'special': False,
'text': '/', 'text': '/',
}), }),
dict({ dict({
'id': 29896, 'id': 29896,
'logprob': -1.9404297,
'special': False, 'special': False,
'text': '1', 'text': '1',
}), }),
@ -98,17 +85,14 @@
'prefill': list([ 'prefill': list([
dict({ dict({
'id': 1, 'id': 1,
'logprob': None,
'text': '<s>', 'text': '<s>',
}), }),
dict({ dict({
'id': 4321, 'id': 4321,
'logprob': -8.6875,
'text': 'Test', 'text': 'Test',
}), }),
dict({ dict({
'id': 2009, 'id': 2009,
'logprob': -11.5546875,
'text': 'request', 'text': 'request',
}), }),
]), ]),
@ -116,55 +100,46 @@
'tokens': list([ 'tokens': list([
dict({ dict({
'id': 5229, 'id': 5229,
'logprob': -3.3085938,
'special': False, 'special': False,
'text': ' failed', 'text': ' failed',
}), }),
dict({ dict({
'id': 363, 'id': 363,
'logprob': -3.984375,
'special': False, 'special': False,
'text': ' for', 'text': ' for',
}), }),
dict({ dict({
'id': 5641, 'id': 5641,
'logprob': -6.53125,
'special': False, 'special': False,
'text': ' IP', 'text': ' IP',
}), }),
dict({ dict({
'id': 16428, 'id': 16428,
'logprob': -3.1835938,
'special': False, 'special': False,
'text': ' Address', 'text': ' Address',
}), }),
dict({ dict({
'id': 29901, 'id': 29901,
'logprob': -1.2324219,
'special': False, 'special': False,
'text': ':', 'text': ':',
}), }),
dict({ dict({
'id': 525, 'id': 525,
'logprob': -2.6855469,
'special': False, 'special': False,
'text': " '", 'text': " '",
}), }),
dict({ dict({
'id': 8516, 'id': 8516,
'logprob': -7.1601562,
'special': False, 'special': False,
'text': 'None', 'text': 'None',
}), }),
dict({ dict({
'id': 4286, 'id': 4286,
'logprob': -2.4433594,
'special': False, 'special': False,
'text': "'.", 'text': "'.",
}), }),
dict({ dict({
'id': 13, 'id': 13,
'logprob': -0.06530762,
'special': False, 'special': False,
'text': ''' 'text': '''
@ -173,7 +148,6 @@
}), }),
dict({ dict({
'id': 294, 'id': 294,
'logprob': -7.953125,
'special': False, 'special': False,
'text': 'as', 'text': 'as',
}), }),
@ -187,9 +161,305 @@
# --- # ---
# name: test_flash_llama_load # name: test_flash_llama_load
list([ list([
'for /api/v1/projects/1', dict({
'for /api/v1/projects/1', 'details': dict({
'for /api/v1/projects/1', 'best_of_sequences': None,
'for /api/v1/projects/1', 'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 1,
'text': '<s>',
}),
dict({
'id': 4321,
'text': 'Test',
}),
dict({
'id': 2009,
'text': 'request',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 363,
'special': False,
'text': ' for',
}),
dict({
'id': 847,
'special': False,
'text': ' /',
}),
dict({
'id': 2754,
'special': False,
'text': 'api',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 29894,
'special': False,
'text': 'v',
}),
dict({
'id': 29896,
'special': False,
'text': '1',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 16418,
'special': False,
'text': 'projects',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 29896,
'special': False,
'text': '1',
}),
]),
}),
'generated_text': 'for /api/v1/projects/1',
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 1,
'text': '<s>',
}),
dict({
'id': 4321,
'text': 'Test',
}),
dict({
'id': 2009,
'text': 'request',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 363,
'special': False,
'text': ' for',
}),
dict({
'id': 847,
'special': False,
'text': ' /',
}),
dict({
'id': 2754,
'special': False,
'text': 'api',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 29894,
'special': False,
'text': 'v',
}),
dict({
'id': 29896,
'special': False,
'text': '1',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 16418,
'special': False,
'text': 'projects',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 29896,
'special': False,
'text': '1',
}),
]),
}),
'generated_text': 'for /api/v1/projects/1',
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 1,
'text': '<s>',
}),
dict({
'id': 4321,
'text': 'Test',
}),
dict({
'id': 2009,
'text': 'request',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 363,
'special': False,
'text': ' for',
}),
dict({
'id': 847,
'special': False,
'text': ' /',
}),
dict({
'id': 2754,
'special': False,
'text': 'api',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 29894,
'special': False,
'text': 'v',
}),
dict({
'id': 29896,
'special': False,
'text': '1',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 16418,
'special': False,
'text': 'projects',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 29896,
'special': False,
'text': '1',
}),
]),
}),
'generated_text': 'for /api/v1/projects/1',
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 1,
'text': '<s>',
}),
dict({
'id': 4321,
'text': 'Test',
}),
dict({
'id': 2009,
'text': 'request',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 363,
'special': False,
'text': ' for',
}),
dict({
'id': 847,
'special': False,
'text': ' /',
}),
dict({
'id': 2754,
'special': False,
'text': 'api',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 29894,
'special': False,
'text': 'v',
}),
dict({
'id': 29896,
'special': False,
'text': '1',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 16418,
'special': False,
'text': 'projects',
}),
dict({
'id': 29914,
'special': False,
'text': '/',
}),
dict({
'id': 29896,
'special': False,
'text': '1',
}),
]),
}),
'generated_text': 'for /api/v1/projects/1',
}),
]) ])
# --- # ---

View File

@ -8,92 +8,74 @@
'prefill': list([ 'prefill': list([
dict({ dict({
'id': 50278, 'id': 50278,
'logprob': None,
'text': '<|prompter|>', 'text': '<|prompter|>',
}), }),
dict({ dict({
'id': 1276, 'id': 1276,
'logprob': -8.03125,
'text': 'What', 'text': 'What',
}), }),
dict({ dict({
'id': 310, 'id': 310,
'logprob': -5.421875,
'text': ' is', 'text': ' is',
}), }),
dict({ dict({
'id': 247, 'id': 247,
'logprob': -2.1601562,
'text': ' a', 'text': ' a',
}), }),
dict({ dict({
'id': 1167, 'id': 1167,
'logprob': -5.4609375,
'text': ' mem', 'text': ' mem',
}), }),
dict({ dict({
'id': 70, 'id': 70,
'logprob': -0.005657196,
'text': 'e', 'text': 'e',
}), }),
dict({ dict({
'id': 13, 'id': 13,
'logprob': -7.28125,
'text': ',', 'text': ',',
}), }),
dict({ dict({
'id': 285, 'id': 285,
'logprob': -0.2980957,
'text': ' and', 'text': ' and',
}), }),
dict({ dict({
'id': 752, 'id': 752,
'logprob': -2.1679688,
'text': ' what', 'text': ' what',
}), }),
dict({ dict({
'id': 434, 'id': 434,
'logprob': -5.6210938,
'text': "'s", 'text': "'s",
}), }),
dict({ dict({
'id': 253, 'id': 253,
'logprob': -0.81103516,
'text': ' the', 'text': ' the',
}), }),
dict({ dict({
'id': 2892, 'id': 2892,
'logprob': -6.6640625,
'text': ' history', 'text': ' history',
}), }),
dict({ dict({
'id': 3212, 'id': 3212,
'logprob': -2.265625,
'text': ' behind', 'text': ' behind',
}), }),
dict({ dict({
'id': 436, 'id': 436,
'logprob': -11.5078125,
'text': ' this', 'text': ' this',
}), }),
dict({ dict({
'id': 3159, 'id': 3159,
'logprob': -2.1582031,
'text': ' word', 'text': ' word',
}), }),
dict({ dict({
'id': 32, 'id': 32,
'logprob': -0.008720398,
'text': '?', 'text': '?',
}), }),
dict({ dict({
'id': 0, 'id': 0,
'logprob': -2.4726562,
'text': '<|endoftext|>', 'text': '<|endoftext|>',
}), }),
dict({ dict({
'id': 50281, 'id': 50281,
'logprob': -18.265625,
'text': '<|assistant|>', 'text': '<|assistant|>',
}), }),
]), ]),
@ -101,61 +83,51 @@
'tokens': list([ 'tokens': list([
dict({ dict({
'id': 510, 'id': 510,
'logprob': -0.63183594,
'special': False, 'special': False,
'text': 'The', 'text': 'The',
}), }),
dict({ dict({
'id': 3159, 'id': 3159,
'logprob': -0.5390625,
'special': False, 'special': False,
'text': ' word', 'text': ' word',
}), }),
dict({ dict({
'id': 346, 'id': 346,
'logprob': -0.045684814,
'special': False, 'special': False,
'text': ' "', 'text': ' "',
}), }),
dict({ dict({
'id': 6441, 'id': 6441,
'logprob': -0.002090454,
'special': False, 'special': False,
'text': 'mem', 'text': 'mem',
}), }),
dict({ dict({
'id': 70, 'id': 70,
'logprob': -1.3589859e-05,
'special': False, 'special': False,
'text': 'e', 'text': 'e',
}), }),
dict({ dict({
'id': 3, 'id': 3,
'logprob': -0.0009455681,
'special': False, 'special': False,
'text': '"', 'text': '"',
}), }),
dict({ dict({
'id': 369, 'id': 369,
'logprob': -0.088012695,
'special': False, 'special': False,
'text': ' was', 'text': ' was',
}), }),
dict({ dict({
'id': 806, 'id': 806,
'logprob': -0.12585449,
'special': False, 'special': False,
'text': ' first', 'text': ' first',
}), }),
dict({ dict({
'id': 908, 'id': 908,
'logprob': -0.017196655,
'special': False, 'special': False,
'text': ' used', 'text': ' used',
}), }),
dict({ dict({
'id': 275, 'id': 275,
'logprob': -0.49731445,
'special': False, 'special': False,
'text': ' in', 'text': ' in',
}), }),
@ -166,9 +138,545 @@
# --- # ---
# name: test_flash_neox_load # name: test_flash_neox_load
list([ list([
'The word "meme" was first used in', dict({
'The word "meme" was first used in', 'details': dict({
'The word "meme" was first used in', 'best_of_sequences': None,
'The word "meme" was first used in', 'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 50278,
'text': '<|prompter|>',
}),
dict({
'id': 1276,
'text': 'What',
}),
dict({
'id': 310,
'text': ' is',
}),
dict({
'id': 247,
'text': ' a',
}),
dict({
'id': 1167,
'text': ' mem',
}),
dict({
'id': 70,
'text': 'e',
}),
dict({
'id': 13,
'text': ',',
}),
dict({
'id': 285,
'text': ' and',
}),
dict({
'id': 752,
'text': ' what',
}),
dict({
'id': 434,
'text': "'s",
}),
dict({
'id': 253,
'text': ' the',
}),
dict({
'id': 2892,
'text': ' history',
}),
dict({
'id': 3212,
'text': ' behind',
}),
dict({
'id': 436,
'text': ' this',
}),
dict({
'id': 3159,
'text': ' word',
}),
dict({
'id': 32,
'text': '?',
}),
dict({
'id': 0,
'text': '<|endoftext|>',
}),
dict({
'id': 50281,
'text': '<|assistant|>',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 510,
'special': False,
'text': 'The',
}),
dict({
'id': 3159,
'special': False,
'text': ' word',
}),
dict({
'id': 346,
'special': False,
'text': ' "',
}),
dict({
'id': 6441,
'special': False,
'text': 'mem',
}),
dict({
'id': 70,
'special': False,
'text': 'e',
}),
dict({
'id': 3,
'special': False,
'text': '"',
}),
dict({
'id': 369,
'special': False,
'text': ' was',
}),
dict({
'id': 806,
'special': False,
'text': ' first',
}),
dict({
'id': 908,
'special': False,
'text': ' used',
}),
dict({
'id': 275,
'special': False,
'text': ' in',
}),
]),
}),
'generated_text': 'The word "meme" was first used in',
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 50278,
'text': '<|prompter|>',
}),
dict({
'id': 1276,
'text': 'What',
}),
dict({
'id': 310,
'text': ' is',
}),
dict({
'id': 247,
'text': ' a',
}),
dict({
'id': 1167,
'text': ' mem',
}),
dict({
'id': 70,
'text': 'e',
}),
dict({
'id': 13,
'text': ',',
}),
dict({
'id': 285,
'text': ' and',
}),
dict({
'id': 752,
'text': ' what',
}),
dict({
'id': 434,
'text': "'s",
}),
dict({
'id': 253,
'text': ' the',
}),
dict({
'id': 2892,
'text': ' history',
}),
dict({
'id': 3212,
'text': ' behind',
}),
dict({
'id': 436,
'text': ' this',
}),
dict({
'id': 3159,
'text': ' word',
}),
dict({
'id': 32,
'text': '?',
}),
dict({
'id': 0,
'text': '<|endoftext|>',
}),
dict({
'id': 50281,
'text': '<|assistant|>',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 510,
'special': False,
'text': 'The',
}),
dict({
'id': 3159,
'special': False,
'text': ' word',
}),
dict({
'id': 346,
'special': False,
'text': ' "',
}),
dict({
'id': 6441,
'special': False,
'text': 'mem',
}),
dict({
'id': 70,
'special': False,
'text': 'e',
}),
dict({
'id': 3,
'special': False,
'text': '"',
}),
dict({
'id': 369,
'special': False,
'text': ' was',
}),
dict({
'id': 806,
'special': False,
'text': ' first',
}),
dict({
'id': 908,
'special': False,
'text': ' used',
}),
dict({
'id': 275,
'special': False,
'text': ' in',
}),
]),
}),
'generated_text': 'The word "meme" was first used in',
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 50278,
'text': '<|prompter|>',
}),
dict({
'id': 1276,
'text': 'What',
}),
dict({
'id': 310,
'text': ' is',
}),
dict({
'id': 247,
'text': ' a',
}),
dict({
'id': 1167,
'text': ' mem',
}),
dict({
'id': 70,
'text': 'e',
}),
dict({
'id': 13,
'text': ',',
}),
dict({
'id': 285,
'text': ' and',
}),
dict({
'id': 752,
'text': ' what',
}),
dict({
'id': 434,
'text': "'s",
}),
dict({
'id': 253,
'text': ' the',
}),
dict({
'id': 2892,
'text': ' history',
}),
dict({
'id': 3212,
'text': ' behind',
}),
dict({
'id': 436,
'text': ' this',
}),
dict({
'id': 3159,
'text': ' word',
}),
dict({
'id': 32,
'text': '?',
}),
dict({
'id': 0,
'text': '<|endoftext|>',
}),
dict({
'id': 50281,
'text': '<|assistant|>',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 510,
'special': False,
'text': 'The',
}),
dict({
'id': 3159,
'special': False,
'text': ' word',
}),
dict({
'id': 346,
'special': False,
'text': ' "',
}),
dict({
'id': 6441,
'special': False,
'text': 'mem',
}),
dict({
'id': 70,
'special': False,
'text': 'e',
}),
dict({
'id': 3,
'special': False,
'text': '"',
}),
dict({
'id': 369,
'special': False,
'text': ' was',
}),
dict({
'id': 806,
'special': False,
'text': ' first',
}),
dict({
'id': 908,
'special': False,
'text': ' used',
}),
dict({
'id': 275,
'special': False,
'text': ' in',
}),
]),
}),
'generated_text': 'The word "meme" was first used in',
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 50278,
'text': '<|prompter|>',
}),
dict({
'id': 1276,
'text': 'What',
}),
dict({
'id': 310,
'text': ' is',
}),
dict({
'id': 247,
'text': ' a',
}),
dict({
'id': 1167,
'text': ' mem',
}),
dict({
'id': 70,
'text': 'e',
}),
dict({
'id': 13,
'text': ',',
}),
dict({
'id': 285,
'text': ' and',
}),
dict({
'id': 752,
'text': ' what',
}),
dict({
'id': 434,
'text': "'s",
}),
dict({
'id': 253,
'text': ' the',
}),
dict({
'id': 2892,
'text': ' history',
}),
dict({
'id': 3212,
'text': ' behind',
}),
dict({
'id': 436,
'text': ' this',
}),
dict({
'id': 3159,
'text': ' word',
}),
dict({
'id': 32,
'text': '?',
}),
dict({
'id': 0,
'text': '<|endoftext|>',
}),
dict({
'id': 50281,
'text': '<|assistant|>',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 510,
'special': False,
'text': 'The',
}),
dict({
'id': 3159,
'special': False,
'text': ' word',
}),
dict({
'id': 346,
'special': False,
'text': ' "',
}),
dict({
'id': 6441,
'special': False,
'text': 'mem',
}),
dict({
'id': 70,
'special': False,
'text': 'e',
}),
dict({
'id': 3,
'special': False,
'text': '"',
}),
dict({
'id': 369,
'special': False,
'text': ' was',
}),
dict({
'id': 806,
'special': False,
'text': ' first',
}),
dict({
'id': 908,
'special': False,
'text': ' used',
}),
dict({
'id': 275,
'special': False,
'text': ' in',
}),
]),
}),
'generated_text': 'The word "meme" was first used in',
}),
]) ])
# --- # ---

View File

@ -8,22 +8,18 @@
'prefill': list([ 'prefill': list([
dict({ dict({
'id': 563, 'id': 563,
'logprob': None,
'text': 'def', 'text': 'def',
}), }),
dict({ dict({
'id': 942, 'id': 942,
'logprob': -5.1367188,
'text': ' print', 'text': ' print',
}), }),
dict({ dict({
'id': 62, 'id': 62,
'logprob': -0.24450684,
'text': '_', 'text': '_',
}), }),
dict({ dict({
'id': 7196, 'id': 7196,
'logprob': -6.9609375,
'text': 'hello', 'text': 'hello',
}), }),
]), ]),
@ -31,13 +27,11 @@
'tokens': list([ 'tokens': list([
dict({ dict({
'id': 1241, 'id': 1241,
'logprob': -0.9863281,
'special': False, 'special': False,
'text': '():', 'text': '():',
}), }),
dict({ dict({
'id': 258, 'id': 258,
'logprob': -0.21447754,
'special': False, 'special': False,
'text': ''' 'text': '''
@ -46,37 +40,31 @@
}), }),
dict({ dict({
'id': 942, 'id': 942,
'logprob': -0.43701172,
'special': False, 'special': False,
'text': ' print', 'text': ' print',
}), }),
dict({ dict({
'id': 372, 'id': 372,
'logprob': -0.5361328,
'special': False, 'special': False,
'text': '("', 'text': '("',
}), }),
dict({ dict({
'id': 7371, 'id': 7371,
'logprob': -0.44555664,
'special': False, 'special': False,
'text': 'Hello', 'text': 'Hello',
}), }),
dict({ dict({
'id': 9956, 'id': 9956,
'logprob': -1.2412109,
'special': False, 'special': False,
'text': ' World', 'text': ' World',
}), }),
dict({ dict({
'id': 8657, 'id': 8657,
'logprob': -0.7583008,
'special': False, 'special': False,
'text': '!")', 'text': '!")',
}), }),
dict({ dict({
'id': 185, 'id': 185,
'logprob': -0.76171875,
'special': False, 'special': False,
'text': ''' 'text': '''
@ -85,7 +73,6 @@
}), }),
dict({ dict({
'id': 185, 'id': 185,
'logprob': -0.20837402,
'special': False, 'special': False,
'text': ''' 'text': '''
@ -94,7 +81,6 @@
}), }),
dict({ dict({
'id': 1018, 'id': 1018,
'logprob': -1.2470703,
'special': False, 'special': False,
'text': 'print', 'text': 'print',
}), }),
@ -110,29 +96,377 @@
# --- # ---
# name: test_flash_santacoder_load # name: test_flash_santacoder_load
list([ list([
''' dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 563,
'text': 'def',
}),
dict({
'id': 942,
'text': ' print',
}),
dict({
'id': 62,
'text': '_',
}),
dict({
'id': 7196,
'text': 'hello',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 1241,
'special': False,
'text': '():',
}),
dict({
'id': 258,
'special': False,
'text': '''
''',
}),
dict({
'id': 942,
'special': False,
'text': ' print',
}),
dict({
'id': 372,
'special': False,
'text': '("',
}),
dict({
'id': 7371,
'special': False,
'text': 'Hello',
}),
dict({
'id': 9956,
'special': False,
'text': ' World',
}),
dict({
'id': 8657,
'special': False,
'text': '!")',
}),
dict({
'id': 185,
'special': False,
'text': '''
''',
}),
dict({
'id': 185,
'special': False,
'text': '''
''',
}),
dict({
'id': 1018,
'special': False,
'text': 'print',
}),
]),
}),
'generated_text': '''
(): ():
print("Hello World!") print("Hello World!")
print print
''', ''',
''' }),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 563,
'text': 'def',
}),
dict({
'id': 942,
'text': ' print',
}),
dict({
'id': 62,
'text': '_',
}),
dict({
'id': 7196,
'text': 'hello',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 1241,
'special': False,
'text': '():',
}),
dict({
'id': 258,
'special': False,
'text': '''
''',
}),
dict({
'id': 942,
'special': False,
'text': ' print',
}),
dict({
'id': 372,
'special': False,
'text': '("',
}),
dict({
'id': 7371,
'special': False,
'text': 'Hello',
}),
dict({
'id': 9956,
'special': False,
'text': ' World',
}),
dict({
'id': 8657,
'special': False,
'text': '!")',
}),
dict({
'id': 185,
'special': False,
'text': '''
''',
}),
dict({
'id': 185,
'special': False,
'text': '''
''',
}),
dict({
'id': 1018,
'special': False,
'text': 'print',
}),
]),
}),
'generated_text': '''
(): ():
print("Hello World!") print("Hello World!")
print print
''', ''',
''' }),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 563,
'text': 'def',
}),
dict({
'id': 942,
'text': ' print',
}),
dict({
'id': 62,
'text': '_',
}),
dict({
'id': 7196,
'text': 'hello',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 1241,
'special': False,
'text': '():',
}),
dict({
'id': 258,
'special': False,
'text': '''
''',
}),
dict({
'id': 942,
'special': False,
'text': ' print',
}),
dict({
'id': 372,
'special': False,
'text': '("',
}),
dict({
'id': 7371,
'special': False,
'text': 'Hello',
}),
dict({
'id': 9956,
'special': False,
'text': ' World',
}),
dict({
'id': 8657,
'special': False,
'text': '!")',
}),
dict({
'id': 185,
'special': False,
'text': '''
''',
}),
dict({
'id': 185,
'special': False,
'text': '''
''',
}),
dict({
'id': 1018,
'special': False,
'text': 'print',
}),
]),
}),
'generated_text': '''
(): ():
print("Hello World!") print("Hello World!")
print print
''', ''',
''' }),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 563,
'text': 'def',
}),
dict({
'id': 942,
'text': ' print',
}),
dict({
'id': 62,
'text': '_',
}),
dict({
'id': 7196,
'text': 'hello',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 1241,
'special': False,
'text': '():',
}),
dict({
'id': 258,
'special': False,
'text': '''
''',
}),
dict({
'id': 942,
'special': False,
'text': ' print',
}),
dict({
'id': 372,
'special': False,
'text': '("',
}),
dict({
'id': 7371,
'special': False,
'text': 'Hello',
}),
dict({
'id': 9956,
'special': False,
'text': ' World',
}),
dict({
'id': 8657,
'special': False,
'text': '!")',
}),
dict({
'id': 185,
'special': False,
'text': '''
''',
}),
dict({
'id': 185,
'special': False,
'text': '''
''',
}),
dict({
'id': 1018,
'special': False,
'text': 'print',
}),
]),
}),
'generated_text': '''
(): ():
print("Hello World!") print("Hello World!")
print print
''', ''',
}),
]) ])
# --- # ---

View File

@ -8,22 +8,18 @@
'prefill': list([ 'prefill': list([
dict({ dict({
'id': 589, 'id': 589,
'logprob': None,
'text': 'def', 'text': 'def',
}), }),
dict({ dict({
'id': 1459, 'id': 1459,
'logprob': -5.6289062,
'text': ' print', 'text': ' print',
}), }),
dict({ dict({
'id': 81, 'id': 81,
'logprob': -1.6005859,
'text': '_', 'text': '_',
}), }),
dict({ dict({
'id': 7656, 'id': 7656,
'logprob': -5.9921875,
'text': 'hello', 'text': 'hello',
}), }),
]), ]),
@ -31,13 +27,11 @@
'tokens': list([ 'tokens': list([
dict({ dict({
'id': 2262, 'id': 2262,
'logprob': -0.7705078,
'special': False, 'special': False,
'text': '():', 'text': '():',
}), }),
dict({ dict({
'id': 284, 'id': 284,
'logprob': -0.2590332,
'special': False, 'special': False,
'text': ''' 'text': '''
@ -46,37 +40,31 @@
}), }),
dict({ dict({
'id': 1459, 'id': 1459,
'logprob': -0.39379883,
'special': False, 'special': False,
'text': ' print', 'text': ' print',
}), }),
dict({ dict({
'id': 440, 'id': 440,
'logprob': -0.61376953,
'special': False, 'special': False,
'text': '("', 'text': '("',
}), }),
dict({ dict({
'id': 8279, 'id': 8279,
'logprob': -0.47338867,
'special': False, 'special': False,
'text': 'Hello', 'text': 'Hello',
}), }),
dict({ dict({
'id': 10896, 'id': 10896,
'logprob': -1.5068359,
'special': False, 'special': False,
'text': ' World', 'text': ' World',
}), }),
dict({ dict({
'id': 657, 'id': 657,
'logprob': -0.80810547,
'special': False, 'special': False,
'text': '")', 'text': '")',
}), }),
dict({ dict({
'id': 203, 'id': 203,
'logprob': -0.7397461,
'special': False, 'special': False,
'text': ''' 'text': '''
@ -85,7 +73,6 @@
}), }),
dict({ dict({
'id': 203, 'id': 203,
'logprob': -0.35229492,
'special': False, 'special': False,
'text': ''' 'text': '''
@ -94,7 +81,6 @@
}), }),
dict({ dict({
'id': 589, 'id': 589,
'logprob': -1.0371094,
'special': False, 'special': False,
'text': 'def', 'text': 'def',
}), }),
@ -117,22 +103,18 @@
'prefill': list([ 'prefill': list([
dict({ dict({
'id': 589, 'id': 589,
'logprob': None,
'text': 'def', 'text': 'def',
}), }),
dict({ dict({
'id': 1459, 'id': 1459,
'logprob': -5.6289062,
'text': ' print', 'text': ' print',
}), }),
dict({ dict({
'id': 81, 'id': 81,
'logprob': -1.6005859,
'text': '_', 'text': '_',
}), }),
dict({ dict({
'id': 7656, 'id': 7656,
'logprob': -5.9921875,
'text': 'hello', 'text': 'hello',
}), }),
]), ]),
@ -140,13 +122,11 @@
'tokens': list([ 'tokens': list([
dict({ dict({
'id': 2262, 'id': 2262,
'logprob': -0.7451172,
'special': False, 'special': False,
'text': '():', 'text': '():',
}), }),
dict({ dict({
'id': 284, 'id': 284,
'logprob': -0.21325684,
'special': False, 'special': False,
'text': ''' 'text': '''
@ -155,55 +135,46 @@
}), }),
dict({ dict({
'id': 5741, 'id': 5741,
'logprob': -5.734375,
'special': False, 'special': False,
'text': ' logging', 'text': ' logging',
}), }),
dict({ dict({
'id': 32, 'id': 32,
'logprob': 0.0,
'special': False, 'special': False,
'text': '.', 'text': '.',
}), }),
dict({ dict({
'id': 1338, 'id': 1338,
'logprob': -0.3232422,
'special': False, 'special': False,
'text': 'info', 'text': 'info',
}), }),
dict({ dict({
'id': 463, 'id': 463,
'logprob': -1.0380859,
'special': False, 'special': False,
'text': "('", 'text': "('",
}), }),
dict({ dict({
'id': 8279, 'id': 8279,
'logprob': -0.8378906,
'special': False, 'special': False,
'text': 'Hello', 'text': 'Hello',
}), }),
dict({ dict({
'id': 30, 'id': 30,
'logprob': -1.9501953,
'special': False, 'special': False,
'text': ',', 'text': ',',
}), }),
dict({ dict({
'id': 10896, 'id': 10896,
'logprob': -1.3476562,
'special': False, 'special': False,
'text': ' World', 'text': ' World',
}), }),
dict({ dict({
'id': 683, 'id': 683,
'logprob': -1.796875,
'special': False, 'special': False,
'text': "')", 'text': "')",
}), }),
dict({ dict({
'id': 203, 'id': 203,
'logprob': -0.9873047,
'special': False, 'special': False,
'text': ''' 'text': '''
@ -212,7 +183,6 @@
}), }),
dict({ dict({
'id': 0, 'id': 0,
'logprob': -0.7495117,
'special': True, 'special': True,
'text': '<|endoftext|>', 'text': '<|endoftext|>',
}), }),
@ -227,29 +197,377 @@
# --- # ---
# name: test_flash_starcoder_load # name: test_flash_starcoder_load
list([ list([
''' dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 589,
'text': 'def',
}),
dict({
'id': 1459,
'text': ' print',
}),
dict({
'id': 81,
'text': '_',
}),
dict({
'id': 7656,
'text': 'hello',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 2262,
'special': False,
'text': '():',
}),
dict({
'id': 284,
'special': False,
'text': '''
''',
}),
dict({
'id': 1459,
'special': False,
'text': ' print',
}),
dict({
'id': 440,
'special': False,
'text': '("',
}),
dict({
'id': 8279,
'special': False,
'text': 'Hello',
}),
dict({
'id': 10896,
'special': False,
'text': ' World',
}),
dict({
'id': 657,
'special': False,
'text': '")',
}),
dict({
'id': 203,
'special': False,
'text': '''
''',
}),
dict({
'id': 203,
'special': False,
'text': '''
''',
}),
dict({
'id': 589,
'special': False,
'text': 'def',
}),
]),
}),
'generated_text': '''
(): ():
print("Hello World") print("Hello World")
def def
''', ''',
''' }),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 589,
'text': 'def',
}),
dict({
'id': 1459,
'text': ' print',
}),
dict({
'id': 81,
'text': '_',
}),
dict({
'id': 7656,
'text': 'hello',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 2262,
'special': False,
'text': '():',
}),
dict({
'id': 284,
'special': False,
'text': '''
''',
}),
dict({
'id': 1459,
'special': False,
'text': ' print',
}),
dict({
'id': 440,
'special': False,
'text': '("',
}),
dict({
'id': 8279,
'special': False,
'text': 'Hello',
}),
dict({
'id': 10896,
'special': False,
'text': ' World',
}),
dict({
'id': 657,
'special': False,
'text': '")',
}),
dict({
'id': 203,
'special': False,
'text': '''
''',
}),
dict({
'id': 203,
'special': False,
'text': '''
''',
}),
dict({
'id': 589,
'special': False,
'text': 'def',
}),
]),
}),
'generated_text': '''
(): ():
print("Hello World") print("Hello World")
def def
''', ''',
''' }),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 589,
'text': 'def',
}),
dict({
'id': 1459,
'text': ' print',
}),
dict({
'id': 81,
'text': '_',
}),
dict({
'id': 7656,
'text': 'hello',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 2262,
'special': False,
'text': '():',
}),
dict({
'id': 284,
'special': False,
'text': '''
''',
}),
dict({
'id': 1459,
'special': False,
'text': ' print',
}),
dict({
'id': 440,
'special': False,
'text': '("',
}),
dict({
'id': 8279,
'special': False,
'text': 'Hello',
}),
dict({
'id': 10896,
'special': False,
'text': ' World',
}),
dict({
'id': 657,
'special': False,
'text': '")',
}),
dict({
'id': 203,
'special': False,
'text': '''
''',
}),
dict({
'id': 203,
'special': False,
'text': '''
''',
}),
dict({
'id': 589,
'special': False,
'text': 'def',
}),
]),
}),
'generated_text': '''
(): ():
print("Hello World") print("Hello World")
def def
''', ''',
''' }),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 589,
'text': 'def',
}),
dict({
'id': 1459,
'text': ' print',
}),
dict({
'id': 81,
'text': '_',
}),
dict({
'id': 7656,
'text': 'hello',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 2262,
'special': False,
'text': '():',
}),
dict({
'id': 284,
'special': False,
'text': '''
''',
}),
dict({
'id': 1459,
'special': False,
'text': ' print',
}),
dict({
'id': 440,
'special': False,
'text': '("',
}),
dict({
'id': 8279,
'special': False,
'text': 'Hello',
}),
dict({
'id': 10896,
'special': False,
'text': ' World',
}),
dict({
'id': 657,
'special': False,
'text': '")',
}),
dict({
'id': 203,
'special': False,
'text': '''
''',
}),
dict({
'id': 203,
'special': False,
'text': '''
''',
}),
dict({
'id': 589,
'special': False,
'text': 'def',
}),
]),
}),
'generated_text': '''
(): ():
print("Hello World") print("Hello World")
def def
''', ''',
}),
]) ])
# --- # ---

View File

@ -8,7 +8,6 @@
'prefill': list([ 'prefill': list([
dict({ dict({
'id': 0, 'id': 0,
'logprob': None,
'text': '<pad>', 'text': '<pad>',
}), }),
]), ]),
@ -16,31 +15,26 @@
'tokens': list([ 'tokens': list([
dict({ dict({
'id': 926, 'id': 926,
'logprob': -4.3554688,
'special': False, 'special': False,
'text': 'To', 'text': 'To',
}), }),
dict({ dict({
'id': 18295, 'id': 18295,
'logprob': -7.7734375,
'special': False, 'special': False,
'text': ' sell', 'text': ' sell',
}), }),
dict({ dict({
'id': 7868, 'id': 7868,
'logprob': -3.9257812,
'special': False, 'special': False,
'text': ' things', 'text': ' things',
}), }),
dict({ dict({
'id': 260, 'id': 260,
'logprob': -2.4179688,
'special': False, 'special': False,
'text': '.', 'text': '.',
}), }),
dict({ dict({
'id': 1, 'id': 1,
'logprob': 0.0,
'special': True, 'special': True,
'text': '</s>', 'text': '</s>',
}), }),
@ -58,7 +52,6 @@
'prefill': list([ 'prefill': list([
dict({ dict({
'id': 0, 'id': 0,
'logprob': None,
'text': '<pad>', 'text': '<pad>',
}), }),
]), ]),
@ -66,61 +59,51 @@
'tokens': list([ 'tokens': list([
dict({ dict({
'id': 16017, 'id': 16017,
'logprob': -1.3505859,
'special': False, 'special': False,
'text': 'blue', 'text': 'blue',
}), }),
dict({ dict({
'id': 20495, 'id': 20495,
'logprob': -0.50439453,
'special': False, 'special': False,
'text': ' sky', 'text': ' sky',
}), }),
dict({ dict({
'id': 259, 'id': 259,
'logprob': -1.2011719,
'special': False, 'special': False,
'text': ' ', 'text': ' ',
}), }),
dict({ dict({
'id': 15484, 'id': 15484,
'logprob': -2.8378906,
'special': False, 'special': False,
'text': 'appear', 'text': 'appear',
}), }),
dict({ dict({
'id': 345, 'id': 345,
'logprob': -0.87597656,
'special': False, 'special': False,
'text': 'ed', 'text': 'ed',
}), }),
dict({ dict({
'id': 288, 'id': 288,
'logprob': -1.8447266,
'special': False, 'special': False,
'text': ' to', 'text': ' to',
}), }),
dict({ dict({
'id': 35622, 'id': 35622,
'logprob': -7.1445312,
'special': False, 'special': False,
'text': ' cloud', 'text': ' cloud',
}), }),
dict({ dict({
'id': 263, 'id': 263,
'logprob': -1.2929688,
'special': False, 'special': False,
'text': 's', 'text': 's',
}), }),
dict({ dict({
'id': 14701, 'id': 14701,
'logprob': -3.0761719,
'special': False, 'special': False,
'text': ' above', 'text': ' above',
}), }),
dict({ dict({
'id': 751, 'id': 751,
'logprob': -4.4375,
'special': False, 'special': False,
'text': ' all', 'text': ' all',
}), }),
@ -131,9 +114,193 @@
# --- # ---
# name: test_mt0_base_load # name: test_mt0_base_load
list([ list([
'Because it is blue', dict({
'Because it is blue', 'details': dict({
'Because it is blue', 'best_of_sequences': None,
'Because it is blue', 'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
'generated_tokens': 6,
'prefill': list([
dict({
'id': 0,
'text': '<pad>',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 259,
'special': False,
'text': '',
}),
dict({
'id': 39261,
'special': False,
'text': 'Because',
}),
dict({
'id': 609,
'special': False,
'text': ' it',
}),
dict({
'id': 339,
'special': False,
'text': ' is',
}),
dict({
'id': 16017,
'special': False,
'text': ' blue',
}),
dict({
'id': 1,
'special': True,
'text': '</s>',
}),
]),
}),
'generated_text': 'Because it is blue',
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
'generated_tokens': 6,
'prefill': list([
dict({
'id': 0,
'text': '<pad>',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 259,
'special': False,
'text': '',
}),
dict({
'id': 39261,
'special': False,
'text': 'Because',
}),
dict({
'id': 609,
'special': False,
'text': ' it',
}),
dict({
'id': 339,
'special': False,
'text': ' is',
}),
dict({
'id': 16017,
'special': False,
'text': ' blue',
}),
dict({
'id': 1,
'special': True,
'text': '</s>',
}),
]),
}),
'generated_text': 'Because it is blue',
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
'generated_tokens': 6,
'prefill': list([
dict({
'id': 0,
'text': '<pad>',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 259,
'special': False,
'text': '',
}),
dict({
'id': 39261,
'special': False,
'text': 'Because',
}),
dict({
'id': 609,
'special': False,
'text': ' it',
}),
dict({
'id': 339,
'special': False,
'text': ' is',
}),
dict({
'id': 16017,
'special': False,
'text': ' blue',
}),
dict({
'id': 1,
'special': True,
'text': '</s>',
}),
]),
}),
'generated_text': 'Because it is blue',
}),
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.EndOfSequenceToken: 'eos_token'>,
'generated_tokens': 6,
'prefill': list([
dict({
'id': 0,
'text': '<pad>',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 259,
'special': False,
'text': '',
}),
dict({
'id': 39261,
'special': False,
'text': 'Because',
}),
dict({
'id': 609,
'special': False,
'text': ' it',
}),
dict({
'id': 339,
'special': False,
'text': ' is',
}),
dict({
'id': 16017,
'special': False,
'text': ' blue',
}),
dict({
'id': 1,
'special': True,
'text': '</s>',
}),
]),
}),
'generated_text': 'Because it is blue',
}),
]) ])
# --- # ---

View File

@ -10,7 +10,7 @@ def bloom_560(launcher):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m(bloom_560, snapshot): async def test_bloom_560m(bloom_560, snapshot_test):
await health_check(bloom_560, 60) await health_check(bloom_560, 60)
response = await bloom_560.generate( response = await bloom_560.generate(
@ -21,11 +21,11 @@ async def test_bloom_560m(bloom_560, snapshot):
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == snapshot assert snapshot_test(response)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_all_params(bloom_560, snapshot): async def test_bloom_560m_all_params(bloom_560, snapshot_test):
await health_check(bloom_560, 60) await health_check(bloom_560, 60)
response = await bloom_560.generate( response = await bloom_560.generate(
@ -44,11 +44,11 @@ async def test_bloom_560m_all_params(bloom_560, snapshot):
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == snapshot assert snapshot_test(response)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_load(bloom_560, generate_load, snapshot): async def test_bloom_560m_load(bloom_560, generate_load, snapshot_test):
await health_check(bloom_560, 60) await health_check(bloom_560, 60)
responses = await generate_load( responses = await generate_load(
@ -60,4 +60,4 @@ async def test_bloom_560m_load(bloom_560, generate_load, snapshot):
assert len(responses) == 4 assert len(responses) == 4
assert responses == snapshot assert snapshot_test(responses)

View File

@ -10,7 +10,7 @@ def bloom_560m_sharded(launcher):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot): async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot_test):
await health_check(bloom_560m_sharded, 60) await health_check(bloom_560m_sharded, 60)
response = await bloom_560m_sharded.generate( response = await bloom_560m_sharded.generate(
@ -21,11 +21,13 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, snapshot):
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == snapshot assert snapshot_test(response)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapshot): async def test_bloom_560m_sharded_load(
bloom_560m_sharded, generate_load, snapshot_test
):
await health_check(bloom_560m_sharded, 60) await health_check(bloom_560m_sharded, 60)
responses = await generate_load( responses = await generate_load(
@ -37,4 +39,4 @@ async def test_bloom_560m_sharded_load(bloom_560m_sharded, generate_load, snapsh
assert len(responses) == 4 assert len(responses) == 4
assert responses == snapshot assert snapshot_test(responses)

View File

@ -11,18 +11,18 @@ def flash_llama(launcher):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama(flash_llama, snapshot): async def test_flash_llama(flash_llama, snapshot_test):
await health_check(flash_llama, 120) await health_check(flash_llama, 120)
response = await flash_llama.generate("Test request", max_new_tokens=10) response = await flash_llama.generate("Test request", max_new_tokens=10)
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == snapshot assert snapshot_test(response)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_all_params(flash_llama, snapshot): async def test_flash_llama_all_params(flash_llama, snapshot_test):
await health_check(flash_llama, 120) await health_check(flash_llama, 120)
response = await flash_llama.generate( response = await flash_llama.generate(
@ -41,16 +41,16 @@ async def test_flash_llama_all_params(flash_llama, snapshot):
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == snapshot assert snapshot_test(response)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_load(flash_llama, generate_load, snapshot): async def test_flash_llama_load(flash_llama, generate_load, snapshot_test):
await health_check(flash_llama, 120) await health_check(flash_llama, 120)
responses = await generate_load(flash_llama, "Test request", max_new_tokens=10, n=4) responses = await generate_load(flash_llama, "Test request", max_new_tokens=10, n=4)
assert len(responses) == 4 assert len(responses) == 4
assert responses == snapshot assert snapshot_test(responses)

View File

@ -10,7 +10,7 @@ def flash_neox(launcher):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_neox(flash_neox, snapshot): async def test_flash_neox(flash_neox, snapshot_test):
await health_check(flash_neox, 240) await health_check(flash_neox, 240)
response = await flash_neox.generate( response = await flash_neox.generate(
@ -19,11 +19,11 @@ async def test_flash_neox(flash_neox, snapshot):
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == snapshot assert snapshot_test(response)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_neox_load(flash_neox, generate_load, snapshot): async def test_flash_neox_load(flash_neox, generate_load, snapshot_test):
await health_check(flash_neox, 240) await health_check(flash_neox, 240)
responses = await generate_load( responses = await generate_load(
@ -35,4 +35,4 @@ async def test_flash_neox_load(flash_neox, generate_load, snapshot):
assert len(responses) == 4 assert len(responses) == 4
assert responses == snapshot assert snapshot_test(responses)

View File

@ -10,17 +10,17 @@ def flash_santacoder(launcher):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_santacoder(flash_santacoder, snapshot): async def test_flash_santacoder(flash_santacoder, snapshot_test):
await health_check(flash_santacoder, 60) await health_check(flash_santacoder, 60)
response = await flash_santacoder.generate("def print_hello", max_new_tokens=10) response = await flash_santacoder.generate("def print_hello", max_new_tokens=10)
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == snapshot assert snapshot_test(response)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot): async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot_test):
await health_check(flash_santacoder, 60) await health_check(flash_santacoder, 60)
responses = await generate_load( responses = await generate_load(
@ -29,4 +29,4 @@ async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot):
assert len(responses) == 4 assert len(responses) == 4
assert responses == snapshot assert snapshot_test(responses)

View File

@ -11,18 +11,18 @@ def flash_starcoder(launcher):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_starcoder(flash_starcoder, snapshot): async def test_flash_starcoder(flash_starcoder, snapshot_test):
await health_check(flash_starcoder, 240) await health_check(flash_starcoder, 240)
response = await flash_starcoder.generate("def print_hello", max_new_tokens=10) response = await flash_starcoder.generate("def print_hello", max_new_tokens=10)
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == snapshot assert snapshot_test(response)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_starcoder_default_params(flash_starcoder, snapshot): async def test_flash_starcoder_default_params(flash_starcoder, snapshot_test):
await health_check(flash_starcoder, 240) await health_check(flash_starcoder, 240)
response = await flash_starcoder.generate( response = await flash_starcoder.generate(
@ -30,12 +30,12 @@ async def test_flash_starcoder_default_params(flash_starcoder, snapshot):
) )
assert response.details.generated_tokens == 12 assert response.details.generated_tokens == 12
assert response == snapshot assert snapshot_test(response)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot): async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot_test):
await health_check(flash_starcoder, 240) await health_check(flash_starcoder, 240)
responses = await generate_load( responses = await generate_load(
@ -44,4 +44,4 @@ async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot):
assert len(responses) == 4 assert len(responses) == 4
assert responses == snapshot assert snapshot_test(responses)

View File

@ -10,7 +10,7 @@ def mt0_base(launcher):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mt0_base(mt0_base, snapshot): async def test_mt0_base(mt0_base, snapshot_test):
await health_check(mt0_base, 60) await health_check(mt0_base, 60)
response = await mt0_base.generate( response = await mt0_base.generate(
@ -21,11 +21,11 @@ async def test_mt0_base(mt0_base, snapshot):
) )
assert response.details.generated_tokens == 5 assert response.details.generated_tokens == 5
assert response == snapshot assert snapshot_test(response)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mt0_base_all_params(mt0_base, snapshot): async def test_mt0_base_all_params(mt0_base, snapshot_test):
await health_check(mt0_base, 60) await health_check(mt0_base, 60)
response = await mt0_base.generate( response = await mt0_base.generate(
@ -44,11 +44,11 @@ async def test_mt0_base_all_params(mt0_base, snapshot):
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == snapshot assert snapshot_test(response)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mt0_base_load(mt0_base, generate_load, snapshot): async def test_mt0_base_load(mt0_base, generate_load, snapshot_test):
await health_check(mt0_base, 60) await health_check(mt0_base, 60)
responses = await generate_load( responses = await generate_load(
@ -60,4 +60,4 @@ async def test_mt0_base_load(mt0_base, generate_load, snapshot):
assert len(responses) == 4 assert len(responses) == 4
assert responses == snapshot assert snapshot_test(responses)