mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
add flash bigcode models
This commit is contained in:
parent
421372f271
commit
9a9244937b
@ -1,20 +1,18 @@
|
||||
import subprocess
|
||||
import time
|
||||
import contextlib
|
||||
import pytest
|
||||
import asyncio
|
||||
import os
|
||||
import docker
|
||||
|
||||
from datetime import datetime
|
||||
from docker.errors import NotFound
|
||||
from typing import Optional, List
|
||||
from aiohttp import ClientConnectorError
|
||||
|
||||
from text_generation import AsyncClient
|
||||
from text_generation.types import Response
|
||||
|
||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -92,10 +90,15 @@ def launcher(event_loop):
|
||||
|
||||
gpu_count = num_shard if num_shard is not None else 1
|
||||
|
||||
env = {}
|
||||
if HUGGING_FACE_HUB_TOKEN is not None:
|
||||
env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
|
||||
|
||||
container = client.containers.run(
|
||||
DOCKER_IMAGE,
|
||||
command=args,
|
||||
name=container_name,
|
||||
environment=env,
|
||||
auto_remove=True,
|
||||
detach=True,
|
||||
device_requests=[
|
||||
|
@ -0,0 +1,101 @@
|
||||
# serializer version: 1
|
||||
# name: test_flash_santacoder
|
||||
dict({
|
||||
'details': dict({
|
||||
'best_of_sequences': None,
|
||||
'finish_reason': <FinishReason.Length: 'length'>,
|
||||
'generated_tokens': 10,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 804,
|
||||
'logprob': None,
|
||||
'text': 'Test',
|
||||
}),
|
||||
]),
|
||||
'seed': None,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 25,
|
||||
'logprob': -2.828125,
|
||||
'special': False,
|
||||
'text': ':',
|
||||
}),
|
||||
dict({
|
||||
'id': 287,
|
||||
'logprob': -1.5703125,
|
||||
'special': False,
|
||||
'text': ' "',
|
||||
}),
|
||||
dict({
|
||||
'id': 385,
|
||||
'logprob': -0.03955078,
|
||||
'special': False,
|
||||
'text': ' +',
|
||||
}),
|
||||
dict({
|
||||
'id': 1028,
|
||||
'logprob': -1.453125,
|
||||
'special': False,
|
||||
'text': ' request',
|
||||
}),
|
||||
dict({
|
||||
'id': 13,
|
||||
'logprob': -0.796875,
|
||||
'special': False,
|
||||
'text': '.',
|
||||
}),
|
||||
dict({
|
||||
'id': 1832,
|
||||
'logprob': -1.4296875,
|
||||
'special': False,
|
||||
'text': 'toString',
|
||||
}),
|
||||
dict({
|
||||
'id': 782,
|
||||
'logprob': -0.17871094,
|
||||
'special': False,
|
||||
'text': '());',
|
||||
}),
|
||||
dict({
|
||||
'id': 259,
|
||||
'logprob': -1.359375,
|
||||
'special': False,
|
||||
'text': '''
|
||||
|
||||
|
||||
''',
|
||||
}),
|
||||
dict({
|
||||
'id': 294,
|
||||
'logprob': -1.6484375,
|
||||
'special': False,
|
||||
'text': ' }',
|
||||
}),
|
||||
dict({
|
||||
'id': 437,
|
||||
'logprob': -1.0625,
|
||||
'special': False,
|
||||
'text': '''
|
||||
|
||||
|
||||
|
||||
''',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': '''
|
||||
: " + request.toString());
|
||||
}
|
||||
|
||||
|
||||
''',
|
||||
})
|
||||
# ---
|
||||
# name: test_flash_santacoder_load
|
||||
list([
|
||||
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=': " + request.toString());\n }\n\n ', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=287, text=' "', logprob=-1.5859375, special=False), Token(id=385, text=' +', logprob=-0.037841797, special=False), Token(id=1028, text=' request', logprob=-1.4453125, special=False), Token(id=13, text='.', logprob=-0.79296875, special=False), Token(id=1832, text='toString', logprob=-1.4375, special=False), Token(id=782, text='());', logprob=-0.19335938, special=False), Token(id=259, text='\n ', logprob=-1.359375, special=False), Token(id=294, text=' }', logprob=-1.609375, special=False), Token(id=437, text='\n\n ', logprob=-1.0546875, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=':\n def __init__(self, name,', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=804, text='Test', logprob=None)], tokens=[Token(id=25, text=':', logprob=-2.828125, special=False), Token(id=258, text='\n ', logprob=-2.828125, special=False), Token(id=458, text=' def', logprob=-0.7421875, special=False), Token(id=945, text=' __', logprob=-0.46679688, special=False), Token(id=955, text='init', logprob=-0.00680542, special=False), Token(id=1218, text='__(', logprob=-0.0049743652, special=False), Token(id=314, text='self', logprob=-0.020629883, special=False), Token(id=11, text=',', logprob=-0.34179688, special=False), Token(id=693, text=' name', logprob=-3.40625, special=False), Token(id=11, text=',', logprob=-0.6875, special=False)], best_of_sequences=None)),
|
||||
])
|
||||
# ---
|
@ -0,0 +1,89 @@
|
||||
# serializer version: 1
|
||||
# name: test_flash_starcoder
|
||||
dict({
|
||||
'details': dict({
|
||||
'best_of_sequences': None,
|
||||
'finish_reason': <FinishReason.Length: 'length'>,
|
||||
'generated_tokens': 10,
|
||||
'prefill': list([
|
||||
dict({
|
||||
'id': 1006,
|
||||
'logprob': None,
|
||||
'text': 'Test',
|
||||
}),
|
||||
]),
|
||||
'seed': None,
|
||||
'tokens': list([
|
||||
dict({
|
||||
'id': 30,
|
||||
'logprob': -2.734375,
|
||||
'special': False,
|
||||
'text': ',',
|
||||
}),
|
||||
dict({
|
||||
'id': 892,
|
||||
'logprob': -2.828125,
|
||||
'special': False,
|
||||
'text': ' String',
|
||||
}),
|
||||
dict({
|
||||
'id': 1984,
|
||||
'logprob': -3.28125,
|
||||
'special': False,
|
||||
'text': ' url',
|
||||
}),
|
||||
dict({
|
||||
'id': 30,
|
||||
'logprob': -0.796875,
|
||||
'special': False,
|
||||
'text': ',',
|
||||
}),
|
||||
dict({
|
||||
'id': 892,
|
||||
'logprob': -1.0390625,
|
||||
'special': False,
|
||||
'text': ' String',
|
||||
}),
|
||||
dict({
|
||||
'id': 1411,
|
||||
'logprob': -2.078125,
|
||||
'special': False,
|
||||
'text': ' method',
|
||||
}),
|
||||
dict({
|
||||
'id': 30,
|
||||
'logprob': -0.49609375,
|
||||
'special': False,
|
||||
'text': ',',
|
||||
}),
|
||||
dict({
|
||||
'id': 892,
|
||||
'logprob': -1.0546875,
|
||||
'special': False,
|
||||
'text': ' String',
|
||||
}),
|
||||
dict({
|
||||
'id': 3361,
|
||||
'logprob': -1.71875,
|
||||
'special': False,
|
||||
'text': ' body',
|
||||
}),
|
||||
dict({
|
||||
'id': 27,
|
||||
'logprob': -0.69921875,
|
||||
'special': False,
|
||||
'text': ')',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'generated_text': ', String url, String method, String body)',
|
||||
})
|
||||
# ---
|
||||
# name: test_flash_starcoder_load
|
||||
list([
|
||||
Response(generated_text=', String url, String method, String body)', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=892, text=' String', logprob=-2.84375, special=False), Token(id=1984, text=' url', logprob=-3.28125, special=False), Token(id=30, text=',', logprob=-0.796875, special=False), Token(id=892, text=' String', logprob=-1.03125, special=False), Token(id=1411, text=' method', logprob=-2.09375, special=False), Token(id=30, text=',', logprob=-0.49804688, special=False), Token(id=892, text=' String', logprob=-1.0546875, special=False), Token(id=3361, text=' body', logprob=-1.7265625, special=False), Token(id=27, text=')', logprob=-0.703125, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
|
||||
Response(generated_text=', GetNext) {\n std::vector<', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=10, seed=None, prefill=[PrefillToken(id=1006, text='Test', logprob=None)], tokens=[Token(id=30, text=',', logprob=-2.734375, special=False), Token(id=1390, text=' Get', logprob=-3.421875, special=False), Token(id=3353, text='Next', logprob=-4.625, special=False), Token(id=27, text=')', logprob=-2.875, special=False), Token(id=301, text=' {', logprob=-0.14453125, special=False), Token(id=334, text='\n ', logprob=-0.17871094, special=False), Token(id=1230, text=' std', logprob=-2.328125, special=False), Token(id=403, text='::', logprob=-0.0007247925, special=False), Token(id=2402, text='vector', logprob=-0.81640625, special=False), Token(id=46, text='<', logprob=-0.0026397705, special=False)], best_of_sequences=None)),
|
||||
])
|
||||
# ---
|
32
integration-tests/models/test_flash_santacoder.py
Normal file
32
integration-tests/models/test_flash_santacoder.py
Normal file
@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
|
||||
from utils import health_check
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_santacoder(launcher):
|
||||
with launcher("bigcode/santacoder") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_santacoder(flash_santacoder, snapshot):
|
||||
await health_check(flash_santacoder, 60)
|
||||
|
||||
response = await flash_santacoder.generate("Test request", max_new_tokens=10)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_santacoder_load(flash_santacoder, generate_load, snapshot):
|
||||
await health_check(flash_santacoder, 60)
|
||||
|
||||
responses = await generate_load(
|
||||
flash_santacoder, "Test request", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
|
||||
assert responses == snapshot
|
32
integration-tests/models/test_flash_starcoder.py
Normal file
32
integration-tests/models/test_flash_starcoder.py
Normal file
@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
|
||||
from utils import health_check
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_starcoder(launcher):
|
||||
with launcher("bigcode/large-model", num_shard=2) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder(flash_starcoder, snapshot):
|
||||
await health_check(flash_starcoder, 60)
|
||||
|
||||
response = await flash_starcoder.generate("Test request", max_new_tokens=10)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot):
|
||||
await health_check(flash_starcoder, 60)
|
||||
|
||||
responses = await generate_load(
|
||||
flash_starcoder, "Test request", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
|
||||
assert responses == snapshot
|
Loading…
Reference in New Issue
Block a user