add flash bigcode models

This commit is contained in:
OlivierDehaene 2023-05-04 11:07:39 +02:00
parent 421372f271
commit 9a9244937b
5 changed files with 260 additions and 3 deletions

View File

@ -1,20 +1,18 @@
import subprocess
import time
import contextlib
import pytest
import asyncio
import os
import docker
from datetime import datetime
from docker.errors import NotFound
from typing import Optional, List
from aiohttp import ClientConnectorError
from text_generation import AsyncClient
from text_generation.types import Response
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
@pytest.fixture(scope="module")
@ -92,10 +90,15 @@ def launcher(event_loop):
gpu_count = num_shard if num_shard is not None else 1
env = {}
if HUGGING_FACE_HUB_TOKEN is not None:
env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
container = client.containers.run(
DOCKER_IMAGE,
command=args,
name=container_name,
environment=env,
auto_remove=True,
detach=True,
device_requests=[

View File

@ -0,0 +1,101 @@
# serializer version: 1
# name: test_flash_santacoder
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 804,
'logprob': None,
'text': 'Test',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 25,
'logprob': -2.828125,
'special': False,
'text': ':',
}),
dict({
'id': 287,
'logprob': -1.5703125,
'special': False,
'text': ' "',
}),
dict({
'id': 385,
'logprob': -0.03955078,
'special': False,
'text': ' +',
}),
dict({
'id': 1028,
'logprob': -1.453125,
'special': False,
'text': ' request',
}),
dict({
'id': 13,
'logprob': -0.796875,
'special': False,
'text': '.',
}),
dict({
'id': 1832,
'logprob': -1.4296875,
'special': False,
'text': 'toString',
}),
dict({
'id': 782,
'logprob': -0.17871094,
'special': False,
'text': '());',
}),
dict({
'id': 259,
'logprob': -1.359375,
'special': False,
'text': '''
''',
}),
dict({
'id': 294,
'logprob': -1.6484375,
'special': False,
'text': ' }',
}),
dict({
'id': 437,
'logprob': -1.0625,
'special': False,
'text': '''
''',
}),
]),
}),
'generated_text': '''
: " + request.toString());
}
''',
})
# ---
# name: test_flash_santacoder_load
list([
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
Response(generated_text=': " + request.toString());\n }\n\n ', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=287, text=' "', logprob=-1.5859375, special=False), Token(id=385, text=' +', logprob=-0.037841797, special=False), Token(id=1028, text=' request', logprob=-1.4453125, special=False), Token(id=13, text='.', logprob=-0.79296875, special=False), Token(id=1832, text='toString', logprob=-1.4375, special=False), Token(id=782, text='());', logprob=-0.19335938, special=False), Token(id=259, text='\n ', logprob=-1.359375, special=False), Token(id=294, text=' }', logprob=-1.609375, special=False), Token(id=437, text='\n\n ', logprob=-1.0546875, special=False)], best_of_sequences=None)),
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
])
# ---

View File

@ -0,0 +1,89 @@
# serializer version: 1
# name: test_flash_starcoder
dict({
'details': dict({
'best_of_sequences': None,
'finish_reason': <FinishReason.Length: 'length'>,
'generated_tokens': 10,
'prefill': list([
dict({
'id': 1006,
'logprob': None,
'text': 'Test',
}),
]),
'seed': None,
'tokens': list([
dict({
'id': 30,
'logprob': -2.734375,
'special': False,
'text': ',',
}),
dict({
'id': 892,
'logprob': -2.828125,
'special': False,
'text': ' String',
}),
dict({
'id': 1984,
'logprob': -3.28125,
'special': False,
'text': ' url',
}),
dict({
'id': 30,
'logprob': -0.796875,
'special': False,
'text': ',',
}),
dict({
'id': 892,
'logprob': -1.0390625,
'special': False,
'text': ' String',
}),
dict({
'id': 1411,
'logprob': -2.078125,
'special': False,
'text': ' method',
}),
dict({
'id': 30,
'logprob': -0.49609375,
'special': False,
'text': ',',
}),
dict({
'id': 892,
'logprob': -1.0546875,
'special': False,
'text': ' String',
}),
dict({
'id': 3361,
'logprob': -1.71875,
'special': False,
'text': ' body',
}),
dict({
'id': 27,
'logprob': -0.69921875,
'special': False,
'text': ')',
}),
]),
}),
'generated_text': ', String url, String method, String body)',
})
# ---
# name: test_flash_starcoder_load
list([
Response(generated_text=', String url, String method, String body)', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=892, text=' String', logprob=-2.84375, special=False), Token(id=1984, text=' url', logprob=-3.28125, special=False), Token(id=30, text=',', logprob=-0.796875, special=False), Token(id=892, text=' String', logprob=-1.03125, special=False), Token(id=1411, text=' method', logprob=-2.09375, special=False), Token(id=30, text=',', logprob=-0.49804688, special=False), Token(id=892, text=' String', logprob=-1.0546875, special=False), Token(id=3361, text=' body', logprob=-1.7265625, special=False), Token(id=27, text=')', logprob=-0.703125, special=False)], best_of_sequences=None)),
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
])
# ---

View File

@ -0,0 +1,32 @@
import pytest
from utils import health_check
@pytest.fixture(scope="module")
def flash_santacoder(launcher):
with launcher("bigcode/santacoder") as client:
yield client
@pytest.mark.asyncio
async def test_flash_santacoder(flash_santacoder, snapshot):
    """Greedy generation of 10 new tokens matches the stored snapshot."""
    # Wait (up to 60s) for the container to report healthy before querying it.
    await health_check(flash_santacoder, 60)

    result = await flash_santacoder.generate("Test request", max_new_tokens=10)

    # The model must have produced exactly the requested token budget,
    # and the full response must be identical to the recorded snapshot.
    assert result.details.generated_tokens == 10
    assert result == snapshot
@pytest.mark.asyncio
async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot):
    """Four concurrent generate requests all complete and match the snapshot."""
    # Wait (up to 60s) for the container to report healthy before querying it.
    await health_check(flash_santacoder, 60)

    num_requests = 4
    results = await generate_load(
        flash_santacoder, "Test request", max_new_tokens=10, n=num_requests
    )

    assert len(results) == num_requests
    assert results == snapshot

View File

@ -0,0 +1,32 @@
import pytest
from utils import health_check
@pytest.fixture(scope="module")
def flash_starcoder(launcher):
    """Module-scoped client connected to a launched 2-shard bigcode/large-model container."""
    with launcher("bigcode/large-model", num_shard=2) as starcoder_client:
        yield starcoder_client
@pytest.mark.asyncio
async def test_flash_starcoder(flash_starcoder, snapshot):
    """Greedy generation of 10 new tokens matches the stored snapshot."""
    # Wait (up to 60s) for the container to report healthy before querying it.
    await health_check(flash_starcoder, 60)

    result = await flash_starcoder.generate("Test request", max_new_tokens=10)

    # The model must have produced exactly the requested token budget,
    # and the full response must be identical to the recorded snapshot.
    assert result.details.generated_tokens == 10
    assert result == snapshot
@pytest.mark.asyncio
async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot):
    """Four concurrent generate requests all complete and match the snapshot."""
    # Wait (up to 60s) for the container to report healthy before querying it.
    await health_check(flash_starcoder, 60)

    num_requests = 4
    results = await generate_load(
        flash_starcoder, "Test request", max_new_tokens=10, n=num_requests
    )

    assert len(results) == num_requests
    assert results == snapshot