diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index cc64c064..b4d35697 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -1,20 +1,18 @@ import subprocess -import time import contextlib import pytest import asyncio import os import docker -from datetime import datetime from docker.errors import NotFound from typing import Optional, List -from aiohttp import ClientConnectorError from text_generation import AsyncClient from text_generation.types import Response DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) +HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) @pytest.fixture(scope="module") @@ -92,10 +90,15 @@ def launcher(event_loop): gpu_count = num_shard if num_shard is not None else 1 + env = {} + if HUGGING_FACE_HUB_TOKEN is not None: + env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN + container = client.containers.run( DOCKER_IMAGE, command=args, name=container_name, + environment=env, auto_remove=True, detach=True, device_requests=[ diff --git a/integration-tests/models/__snapshots__/test_flash_santacoder.ambr b/integration-tests/models/__snapshots__/test_flash_santacoder.ambr new file mode 100644 index 00000000..bb5dfbba --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_santacoder.ambr @@ -0,0 +1,101 @@ +# serializer version: 1 +# name: test_flash_santacoder + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 804, + 'logprob': None, + 'text': 'Test', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 25, + 'logprob': -2.828125, + 'special': False, + 'text': ':', + }), + dict({ + 'id': 287, + 'logprob': -1.5703125, + 'special': False, + 'text': ' "', + }), + dict({ + 'id': 385, + 'logprob': -0.03955078, + 'special': False, + 'text': ' +', + }), + dict({ + 'id': 1028, + 'logprob': -1.453125, + 'special': False, + 'text': ' request', + }), + dict({ + 'id': 13, 
+ 'logprob': -0.796875, + 'special': False, + 'text': '.', + }), + dict({ + 'id': 1832, + 'logprob': -1.4296875, + 'special': False, + 'text': 'toString', + }), + dict({ + 'id': 782, + 'logprob': -0.17871094, + 'special': False, + 'text': '());', + }), + dict({ + 'id': 259, + 'logprob': -1.359375, + 'special': False, + 'text': ''' + + + ''', + }), + dict({ + 'id': 294, + 'logprob': -1.6484375, + 'special': False, + 'text': ' }', + }), + dict({ + 'id': 437, + 'logprob': -1.0625, + 'special': False, + 'text': ''' + + + + ''', + }), + ]), + }), + 'generated_text': ''' + : " + request.toString()); + } + + + ''', + }) +# --- +# name: test_flash_santacoder_load + list([ + Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)), + Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, 
special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)), + Response(generated_text=': " + request.toString());\n }\n\n ', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=287, text=' "', logprob=-1.5859375, special=False), Token(id=385, text=' +', logprob=-0.037841797, special=False), Token(id=1028, text=' request', logprob=-1.4453125, special=False), Token(id=13, text='.', logprob=-0.79296875, special=False), Token(id=1832, text='toString', logprob=-1.4375, special=False), Token(id=782, text='());', logprob=-0.19335938, special=False), Token(id=259, text='\n ', logprob=-1.359375, special=False), Token(id=294, text=' }', logprob=-1.609375, special=False), Token(id=437, text='\n\n ', logprob=-1.0546875, special=False)], best_of_sequences=None)), + Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, 
special=False)], best_of_sequences=None)), + ]) +# --- diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder.ambr b/integration-tests/models/__snapshots__/test_flash_starcoder.ambr new file mode 100644 index 00000000..2ecdbf6c --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder.ambr @@ -0,0 +1,89 @@ +# serializer version: 1 +# name: test_flash_starcoder + dict({ + 'details': dict({ + 'best_of_sequences': None, + 'finish_reason': , + 'generated_tokens': 10, + 'prefill': list([ + dict({ + 'id': 1006, + 'logprob': None, + 'text': 'Test', + }), + ]), + 'seed': None, + 'tokens': list([ + dict({ + 'id': 30, + 'logprob': -2.734375, + 'special': False, + 'text': ',', + }), + dict({ + 'id': 892, + 'logprob': -2.828125, + 'special': False, + 'text': ' String', + }), + dict({ + 'id': 1984, + 'logprob': -3.28125, + 'special': False, + 'text': ' url', + }), + dict({ + 'id': 30, + 'logprob': -0.796875, + 'special': False, + 'text': ',', + }), + dict({ + 'id': 892, + 'logprob': -1.0390625, + 'special': False, + 'text': ' String', + }), + dict({ + 'id': 1411, + 'logprob': -2.078125, + 'special': False, + 'text': ' method', + }), + dict({ + 'id': 30, + 'logprob': -0.49609375, + 'special': False, + 'text': ',', + }), + dict({ + 'id': 892, + 'logprob': -1.0546875, + 'special': False, + 'text': ' String', + }), + dict({ + 'id': 3361, + 'logprob': -1.71875, + 'special': False, + 'text': ' body', + }), + dict({ + 'id': 27, + 'logprob': -0.69921875, + 'special': False, + 'text': ')', + }), + ]), + }), + 'generated_text': ', String url, String method, String body)', + }) +# --- +# name: test_flash_starcoder_load + list([ + Response(generated_text=', String url, String method, String body)', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=892, text=' String', logprob=-2.84375, special=False), 
Token(id=1984, text=' url', logprob=-3.28125, special=False), Token(id=30, text=',', logprob=-0.796875, special=False), Token(id=892, text=' String', logprob=-1.03125, special=False), Token(id=1411, text=' method', logprob=-2.09375, special=False), Token(id=30, text=',', logprob=-0.49804688, special=False), Token(id=892, text=' String', logprob=-1.0546875, special=False), Token(id=3361, text=' body', logprob=-1.7265625, special=False), Token(id=27, text=')', logprob=-0.703125, special=False)], best_of_sequences=None)), + Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)), + Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', 
logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)), + Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)), + ]) +# --- diff --git a/integration-tests/models/test_flash_santacoder.py b/integration-tests/models/test_flash_santacoder.py new file mode 100644 index 00000000..ea357b64 --- /dev/null +++ b/integration-tests/models/test_flash_santacoder.py @@ -0,0 +1,32 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def flash_santacoder(launcher): + with launcher("bigcode/santacoder") as client: + yield client + + +@pytest.mark.asyncio +async def test_flash_santacoder(flash_santacoder, snapshot): + await health_check(flash_santacoder, 60) + + response = await flash_santacoder.generate("Test request", max_new_tokens=10) + + assert response.details.generated_tokens == 10 + assert response == snapshot + + +@pytest.mark.asyncio +async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot): + await health_check(flash_santacoder, 60) + + responses = await generate_load( + flash_santacoder, "Test request", 
max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + + assert responses == snapshot diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py new file mode 100644 index 00000000..406caa6c --- /dev/null +++ b/integration-tests/models/test_flash_starcoder.py @@ -0,0 +1,32 @@ +import pytest + +from utils import health_check + + +@pytest.fixture(scope="module") +def flash_starcoder(launcher): + with launcher("bigcode/large-model", num_shard=2) as client: + yield client + + +@pytest.mark.asyncio +async def test_flash_starcoder(flash_starcoder, snapshot): + await health_check(flash_starcoder, 60) + + response = await flash_starcoder.generate("Test request", max_new_tokens=10) + + assert response.details.generated_tokens == 10 + assert response == snapshot + + +@pytest.mark.asyncio +async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot): + await health_check(flash_starcoder, 60) + + responses = await generate_load( + flash_starcoder, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + + assert responses == snapshot