mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
add flash bigcode models
This commit is contained in:
parent
421372f271
commit
9a9244937b
@ -1,20 +1,18 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
import time
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import docker
|
import docker
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from docker.errors import NotFound
|
from docker.errors import NotFound
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from aiohttp import ClientConnectorError
|
|
||||||
|
|
||||||
from text_generation import AsyncClient
|
from text_generation import AsyncClient
|
||||||
from text_generation.types import Response
|
from text_generation.types import Response
|
||||||
|
|
||||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||||
|
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
@ -92,10 +90,15 @@ def launcher(event_loop):
|
|||||||
|
|
||||||
gpu_count = num_shard if num_shard is not None else 1
|
gpu_count = num_shard if num_shard is not None else 1
|
||||||
|
|
||||||
|
env = {}
|
||||||
|
if HUGGING_FACE_HUB_TOKEN is not None:
|
||||||
|
env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
|
||||||
|
|
||||||
container = client.containers.run(
|
container = client.containers.run(
|
||||||
DOCKER_IMAGE,
|
DOCKER_IMAGE,
|
||||||
command=args,
|
command=args,
|
||||||
name=container_name,
|
name=container_name,
|
||||||
|
environment=env,
|
||||||
auto_remove=True,
|
auto_remove=True,
|
||||||
detach=True,
|
detach=True,
|
||||||
device_requests=[
|
device_requests=[
|
||||||
|
@ -0,0 +1,101 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_flash_santacoder
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 804,
|
||||||
|
'logprob': None,
|
||||||
|
'text': 'Test',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 25,
|
||||||
|
'logprob': -2.828125,
|
||||||
|
'special': False,
|
||||||
|
'text': ':',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 287,
|
||||||
|
'logprob': -1.5703125,
|
||||||
|
'special': False,
|
||||||
|
'text': ' "',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 385,
|
||||||
|
'logprob': -0.03955078,
|
||||||
|
'special': False,
|
||||||
|
'text': ' +',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1028,
|
||||||
|
'logprob': -1.453125,
|
||||||
|
'special': False,
|
||||||
|
'text': ' request',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 13,
|
||||||
|
'logprob': -0.796875,
|
||||||
|
'special': False,
|
||||||
|
'text': '.',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1832,
|
||||||
|
'logprob': -1.4296875,
|
||||||
|
'special': False,
|
||||||
|
'text': 'toString',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 782,
|
||||||
|
'logprob': -0.17871094,
|
||||||
|
'special': False,
|
||||||
|
'text': '());',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 259,
|
||||||
|
'logprob': -1.359375,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 294,
|
||||||
|
'logprob': -1.6484375,
|
||||||
|
'special': False,
|
||||||
|
'text': ' }',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 437,
|
||||||
|
'logprob': -1.0625,
|
||||||
|
'special': False,
|
||||||
|
'text': '''
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': '''
|
||||||
|
: " + request.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
''',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_flash_santacoder_load
|
||||||
|
list([
|
||||||
|
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
|
||||||
|
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
|
||||||
|
Response(generated_text=': " + request.toString());\n }\n\n ', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=287, text=' "', logprob=-1.5859375, special=False), Token(id=385, text=' +', logprob=-0.037841797, special=False), Token(id=1028, text=' request', logprob=-1.4453125, special=False), Token(id=13, text='.', logprob=-0.79296875, special=False), Token(id=1832, text='toString', logprob=-1.4375, special=False), Token(id=782, text='());', logprob=-0.19335938, special=False), Token(id=259, text='\n ', logprob=-1.359375, special=False), Token(id=294, text=' }', logprob=-1.609375, special=False), Token(id=437, text='\n\n ', logprob=-1.0546875, special=False)], best_of_sequences=None)),
|
||||||
|
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
|
||||||
|
])
|
||||||
|
# ---
|
@ -0,0 +1,89 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_flash_starcoder
|
||||||
|
dict({
|
||||||
|
'details': dict({
|
||||||
|
'best_of_sequences': None,
|
||||||
|
'finish_reason': <FinishReason.Length: 'length'>,
|
||||||
|
'generated_tokens': 10,
|
||||||
|
'prefill': list([
|
||||||
|
dict({
|
||||||
|
'id': 1006,
|
||||||
|
'logprob': None,
|
||||||
|
'text': 'Test',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'seed': None,
|
||||||
|
'tokens': list([
|
||||||
|
dict({
|
||||||
|
'id': 30,
|
||||||
|
'logprob': -2.734375,
|
||||||
|
'special': False,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 892,
|
||||||
|
'logprob': -2.828125,
|
||||||
|
'special': False,
|
||||||
|
'text': ' String',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1984,
|
||||||
|
'logprob': -3.28125,
|
||||||
|
'special': False,
|
||||||
|
'text': ' url',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 30,
|
||||||
|
'logprob': -0.796875,
|
||||||
|
'special': False,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 892,
|
||||||
|
'logprob': -1.0390625,
|
||||||
|
'special': False,
|
||||||
|
'text': ' String',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 1411,
|
||||||
|
'logprob': -2.078125,
|
||||||
|
'special': False,
|
||||||
|
'text': ' method',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 30,
|
||||||
|
'logprob': -0.49609375,
|
||||||
|
'special': False,
|
||||||
|
'text': ',',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 892,
|
||||||
|
'logprob': -1.0546875,
|
||||||
|
'special': False,
|
||||||
|
'text': ' String',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 3361,
|
||||||
|
'logprob': -1.71875,
|
||||||
|
'special': False,
|
||||||
|
'text': ' body',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'id': 27,
|
||||||
|
'logprob': -0.69921875,
|
||||||
|
'special': False,
|
||||||
|
'text': ')',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'generated_text': ', String url, String method, String body)',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_flash_starcoder_load
|
||||||
|
list([
|
||||||
|
Response(generated_text=', String url, String method, String body)', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=892, text=' String', logprob=-2.84375, special=False), Token(id=1984, text=' url', logprob=-3.28125, special=False), Token(id=30, text=',', logprob=-0.796875, special=False), Token(id=892, text=' String', logprob=-1.03125, special=False), Token(id=1411, text=' method', logprob=-2.09375, special=False), Token(id=30, text=',', logprob=-0.49804688, special=False), Token(id=892, text=' String', logprob=-1.0546875, special=False), Token(id=3361, text=' body', logprob=-1.7265625, special=False), Token(id=27, text=')', logprob=-0.703125, special=False)], best_of_sequences=None)),
|
||||||
|
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
|
||||||
|
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
|
||||||
|
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
|
||||||
|
])
|
||||||
|
# ---
|
32
integration-tests/models/test_flash_santacoder.py
Normal file
32
integration-tests/models/test_flash_santacoder.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from utils import health_check
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def flash_santacoder(launcher):
    """Launch a text-generation-inference server for ``bigcode/santacoder``.

    Module-scoped so the (expensive) Docker container is started once and
    shared by every test in this file. Yields the connected client provided
    by the ``launcher`` fixture; teardown happens when the ``with`` block
    exits at the end of the module.
    """
    # No num_shard argument: presumably runs on a single GPU — see the
    # launcher fixture in conftest for the default (TODO confirm).
    with launcher("bigcode/santacoder") as client:
        yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_flash_santacoder(flash_santacoder, snapshot):
    """Single generation request against santacoder matches the stored snapshot."""
    # Block until the server is ready; 60 is presumably a timeout in
    # seconds — see utils.health_check for the exact semantics (TODO confirm).
    await health_check(flash_santacoder, 60)

    response = await flash_santacoder.generate("Test request", max_new_tokens=10)

    # The model should use the full budget (no early stop / EOS within 10 tokens).
    assert response.details.generated_tokens == 10
    # Full response (tokens, logprobs, text) is pinned by the syrupy snapshot.
    assert response == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot):
    """Concurrent-load smoke test: 4 identical requests all succeed and match the snapshot."""
    # Block until the server is ready; 60 is presumably a timeout in
    # seconds — see utils.health_check for the exact semantics (TODO confirm).
    await health_check(flash_santacoder, 60)

    # generate_load (conftest fixture) fans out n concurrent generate calls.
    responses = await generate_load(
        flash_santacoder, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    # All four responses are pinned by the syrupy snapshot.
    assert responses == snapshot
|
32
integration-tests/models/test_flash_starcoder.py
Normal file
32
integration-tests/models/test_flash_starcoder.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from utils import health_check
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def flash_starcoder(launcher):
    """Launch a sharded text-generation-inference server for starcoder.

    Module-scoped so the container is started once per test module.
    ``num_shard=2`` runs the model tensor-parallel across 2 GPUs (the
    launcher maps this to the container's GPU count).
    """
    # NOTE(review): "bigcode/large-model" looks like a pre-release alias for
    # starcoder — confirm the repo id is still valid on the Hub.
    with launcher("bigcode/large-model", num_shard=2) as client:
        yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_flash_starcoder(flash_starcoder, snapshot):
    """Single generation request against starcoder matches the stored snapshot."""
    # Block until the sharded server is ready; 60 is presumably a timeout in
    # seconds — see utils.health_check for the exact semantics (TODO confirm).
    await health_check(flash_starcoder, 60)

    response = await flash_starcoder.generate("Test request", max_new_tokens=10)

    # The model should use the full budget (no early stop / EOS within 10 tokens).
    assert response.details.generated_tokens == 10
    # Full response (tokens, logprobs, text) is pinned by the syrupy snapshot.
    assert response == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot):
    """Concurrent-load smoke test: 4 identical requests all succeed and match the snapshot."""
    # Block until the sharded server is ready; 60 is presumably a timeout in
    # seconds — see utils.health_check for the exact semantics (TODO confirm).
    await health_check(flash_starcoder, 60)

    # generate_load (conftest fixture) fans out n concurrent generate calls.
    responses = await generate_load(
        flash_starcoder, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    # All four responses are pinned by the syrupy snapshot.
    assert responses == snapshot
|
Loading…
Reference in New Issue
Block a user