feat(server): Add model tests

OlivierDehaene, 2022-12-08 18:32:07 +01:00
commit 3229fb7b44 (parent 31d76e238d)
16 changed files with 987 additions and 63 deletions


@ -87,9 +87,4 @@ curl 127.0.0.1:3000/generate \
```shell
make server-dev
make router-dev
```
-## TODO:
-- [ ] Add tests for the `server/model` logic
-- [ ] Backport custom CUDA kernels to Transformers


@ -70,7 +70,7 @@ impl Batcher {
// Notify the background task that we have a new entry in the database that needs
// to be batched
-self.shared.batching_task.notify_waiters();
+self.shared.batching_task.notify_one();
// Await on the response from the background task
// We can safely unwrap as the background task will never drop the sender


@ -8,8 +8,9 @@ gen-server:
install-transformers:
# Install specific version of transformers with custom cuda kernels
-rm transformers || true
-rm transformers-text_generation_inference || true
+pip uninstall transformers -y || true
+rm -rf transformers || true
+rm -rf transformers-text_generation_inference || true
curl -L -O https://github.com/OlivierDehaene/transformers/archive/refs/heads/text_generation_inference.zip
unzip text_generation_inference.zip
rm text_generation_inference.zip

server/poetry.lock (generated)

@ -22,6 +22,20 @@ test_prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"]
test_trackers = ["comet-ml", "tensorboard", "wandb"]
testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"]
[[package]]
name = "attrs"
version = "22.1.0"
description = "Classes Without Boilerplate"
category = "dev"
optional = false
python-versions = ">=3.5"
[package.extras]
dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"]
docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"]
tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"]
tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"]
[[package]]
name = "bitsandbytes"
version = "0.35.1"
@ -49,6 +63,17 @@ category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "exceptiongroup"
version = "1.0.4"
description = "Backport of PEP 654 (exception groups)"
category = "dev"
optional = false
python-versions = ">=3.7"
[package.extras]
test = ["pytest (>=6)"]
[[package]]
name = "grpcio"
version = "1.50.0"
@ -88,6 +113,14 @@ grpcio = ">=1.50.0"
protobuf = ">=4.21.6,<5.0dev"
setuptools = "*"
[[package]]
name = "iniconfig"
version = "1.1.1"
description = "iniconfig: brain-dead simple config-ini parsing"
category = "dev"
optional = false
python-versions = "*"
[[package]]
name = "numpy"
version = "1.23.4"
@ -107,6 +140,18 @@ python-versions = ">=3.6"
[package.dependencies]
pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
[[package]]
name = "pluggy"
version = "1.0.0"
description = "plugin and hook calling mechanisms for python"
category = "dev"
optional = false
python-versions = ">=3.6"
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "protobuf"
version = "4.21.8"
@ -137,6 +182,26 @@ python-versions = ">=3.6.8"
[package.extras]
diagrams = ["jinja2", "railroad-diagrams"]
[[package]]
name = "pytest"
version = "7.2.0"
description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false
python-versions = ">=3.7"
[package.dependencies]
attrs = ">=19.2.0"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
[[package]]
name = "PyYAML"
version = "6.0"
@ -178,6 +243,14 @@ category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
[[package]]
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
category = "dev"
optional = false
python-versions = ">=3.7"
[[package]]
name = "torch"
version = "1.12.1"
@ -220,13 +293,17 @@ bnb = ["bitsandbytes"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
-content-hash = "3266187ef14fe8f9e29b3b6530d07781ea952aa670c0fe0de34be43efa231a67"
+content-hash = "51693654531e3229ac64bee250932ace20a60e8d45af074ae7b860ed32b25ef8"
[metadata.files]
accelerate = [
{file = "accelerate-0.12.0-py3-none-any.whl", hash = "sha256:7742ca5c9f15dd1e0a283305599c196e260af4717a561d1f544aeab27d828af6"},
{file = "accelerate-0.12.0.tar.gz", hash = "sha256:e8b119c94fac31877620d5f9de311164ec81fa9dc9e175f0d0d4f50fc8d79473"},
]
attrs = [
{file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"},
{file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"},
]
bitsandbytes = [
{file = "bitsandbytes-0.35.1-py3-none-any.whl", hash = "sha256:4506a9e3778359a743938aa5592d8d043fa91d1df66cd01ba8cc6486e64dea45"},
{file = "bitsandbytes-0.35.1.tar.gz", hash = "sha256:63a6f59c87b713a731a685e43d68c19789ee6381e62196cafab293b87eca5d46"},
@ -239,6 +316,10 @@ colorama = [
{file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"},
{file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"},
]
exceptiongroup = [
{file = "exceptiongroup-1.0.4-py3-none-any.whl", hash = "sha256:542adf9dea4055530d6e1279602fa5cb11dab2395fa650b8674eaec35fc4a828"},
{file = "exceptiongroup-1.0.4.tar.gz", hash = "sha256:bd14967b79cd9bdb54d97323216f8fdf533e278df937aa2a90089e7d6e06e5ec"},
]
grpcio = [
{file = "grpcio-1.50.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:906f4d1beb83b3496be91684c47a5d870ee628715227d5d7c54b04a8de802974"},
{file = "grpcio-1.50.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:2d9fd6e38b16c4d286a01e1776fdf6c7a4123d99ae8d6b3f0b4a03a34bf6ce45"},
@ -337,6 +418,10 @@ grpcio-tools = [
{file = "grpcio_tools-1.50.0-cp39-cp39-win32.whl", hash = "sha256:e1a8f9a57bbcc2e633aaf327e39830527f3c1f7add18c7580f3058fe9a0fa780"},
{file = "grpcio_tools-1.50.0-cp39-cp39-win_amd64.whl", hash = "sha256:b7eb7a84d9171c0ae1550833f4a6ca52372bed9db0fa10f8c9dbe6ca65f97a8c"},
]
iniconfig = [
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
]
numpy = [
{file = "numpy-1.23.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d79ada05005f6f4f337d3bb9de8a7774f259341c70bc88047a1f7b96a4bcb2"},
{file = "numpy-1.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:926db372bc4ac1edf81cfb6c59e2a881606b409ddc0d0920b988174b2e2a767f"},
@ -371,6 +456,10 @@ packaging = [
{file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"},
{file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
]
pluggy = [
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
]
protobuf = [
{file = "protobuf-4.21.8-cp310-abi3-win32.whl", hash = "sha256:c252c55ee15175aa1b21b7b9896e6add5162d066d5202e75c39f96136f08cce3"},
{file = "protobuf-4.21.8-cp310-abi3-win_amd64.whl", hash = "sha256:809ca0b225d3df42655a12f311dd0f4148a943c51f1ad63c38343e457492b689"},
@ -429,6 +518,10 @@ pyparsing = [
{file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
{file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"},
]
pytest = [
{file = "pytest-7.2.0-py3-none-any.whl", hash = "sha256:892f933d339f068883b6fd5a459f03d85bfcb355e4981e146d2c7616c21fef71"},
{file = "pytest-7.2.0.tar.gz", hash = "sha256:c4014eb40e10f11f355ad4e3c2fb2c6c6d1919c73f3b5a433de4708202cade59"},
]
PyYAML = [
{file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"},
{file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"},
@ -512,6 +605,10 @@ six = [
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]
tomli = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
torch = [
{file = "torch-1.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c038662db894a23e49e385df13d47b2a777ffd56d9bcd5b832593fab0a7e286"},
{file = "torch-1.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:4e1b9c14cf13fd2ab8d769529050629a0e68a6fc5cb8e84b4a3cc1dd8c4fe541"},


@ -22,6 +22,7 @@ bnb = ["bitsandbytes"]
[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.49.1"
pytest = "^7.2.0"
[build-system]
requires = ["poetry-core>=1.0.0"]
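With `pytest` added to the dev dependency group, the new test suite can be run from the `server` directory. A minimal sketch of a programmatic invocation (the `tests` path follows the new `server/tests/conftest.py` below; adjust if the layout differs):

```python
# Hedged sketch only: run the new model tests programmatically.
import sys

import pytest

if __name__ == "__main__":
    # "tests" is the directory added in this commit (server/tests), relative to server/.
    sys.exit(pytest.main(["-sv", "tests"]))
```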

server/tests/conftest.py (new file)

@ -0,0 +1,34 @@
import pytest
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
@pytest.fixture
def default_pb_parameters():
return generate_pb2.LogitsWarperParameters(
temperature=1.0,
top_k=0,
top_p=1.0,
do_sample=False,
)
@pytest.fixture(scope="session")
def bloom_560m_tokenizer():
return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
@pytest.fixture(scope="session")
def gpt2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token_id = 50256
return tokenizer
@pytest.fixture(scope="session")
def mt0_small_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-small", padding_side="left")
tokenizer.bos_token_id = 0
return tokenizer
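These fixtures are consumed by the model tests added in this commit. A minimal illustrative sketch of the pattern, assuming only the fixtures defined here and the `CausalLMBatch.from_pb` API exercised below:

```python
# Illustrative sketch only; it mirrors the pattern used by the model tests below.
import torch

from text_generation.models.causal_lm import CausalLMBatch
from text_generation.pb import generate_pb2


def test_causal_lm_batch_from_fixtures(default_pb_parameters, gpt2_tokenizer):
    # Build a single-request protobuf batch from the shared parameters fixture.
    request = generate_pb2.Request(
        id=0,
        inputs="Test",
        input_length=1,
        parameters=default_pb_parameters,
        max_new_tokens=10,
    )
    pb_batch = generate_pb2.Batch(id=0, requests=[request], size=1)

    # Convert it to the internal batch representation on CPU and sanity-check it.
    batch = CausalLMBatch.from_pb(pb_batch, gpt2_tokenizer, torch.device("cpu"))
    assert batch.size == 1
    assert batch.past_key_values is None
```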


@ -0,0 +1,245 @@
import pytest
import torch
from copy import copy
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
@pytest.fixture
def default_pb_request(default_pb_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=1,
parameters=default_pb_parameters,
max_new_tokens=10,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(
id=0,
requests=[default_pb_request],
size=1
)
@pytest.fixture
def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
return BloomCausalLMBatch.from_pb(default_pb_batch, bloom_560m_tokenizer, torch.device("cpu"))
@pytest.fixture
def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
req_0 = copy(default_pb_request)
req_1 = default_pb_request
req_1.id = 1
req_1.max_new_tokens = 5
batch_pb = generate_pb2.Batch(
id=0,
requests=[req_0, req_1],
size=2
)
return BloomCausalLMBatch.from_pb(batch_pb, bloom_560m_tokenizer, torch.device("cpu"))
@pytest.fixture(scope="session")
def default_bloom():
return BLOOM("bigscience/bloom-560m")
def test_batch_from_pb(default_pb_batch, default_bloom_batch):
batch = default_bloom_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert len(batch.input_ids[0]) == 8
assert batch.input_ids[0][-1] == 10264
assert torch.all(batch.input_ids[0][:-1] == 3)
assert batch.attention_mask[0][-1] == 1
assert torch.all(batch.attention_mask[0][:-1] == 0)
assert batch.past_key_values is None
assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
assert batch.input_lengths == [1]
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_sequence_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_bloom_batch):
with pytest.raises(ValueError):
BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch])
def test_causal_lm_batch_type(default_bloom):
assert default_bloom.batch_type == BloomCausalLMBatch
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
assert generated_texts == []
assert isinstance(next_batch, CausalLMBatch)
assert not next_batch.keys_head_dim_last
assert len(next_batch.all_input_ids) == next_batch.size
assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
assert torch.all(next_batch.attention_mask[0][-2:] == 1)
assert torch.all(next_batch.attention_mask[0][:-2] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids[0, 0] == 10264
assert next_batch.input_lengths == [2]
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all([p[0].shape == (16, 64, 8) for p in next_batch.past_key_values])
assert all([p[1].shape == (16, 8, 64) for p in next_batch.past_key_values])
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
next_batch = default_bloom_batch
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert generated_texts[0].request == default_bloom_batch.requests[0]
assert generated_texts[0].tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens
def test_causal_lm_generate_token_completion_multi(default_bloom, default_multi_requests_bloom_batch):
next_batch = default_multi_requests_bloom_batch
for i in range(default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
assert generated_texts[0].tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens -
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
assert generated_texts[0].tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
def test_batch_concatenate(default_bloom, default_bloom_batch, default_multi_requests_bloom_batch):
next_batch_0 = default_bloom_batch
_, next_batch_0 = default_bloom.generate_token(next_batch_0)
_, next_batch_0 = default_bloom.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_bloom_batch
_, next_batch_1 = default_bloom.generate_token(next_batch_1)
next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(next_batch.attention_mask[0] == 1)
assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids == 10264)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_sequence_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all([p[0].shape == (3, 16, 64, 2) for p in next_batch.past_key_values])
assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0])
assert torch.equal(next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, :, -1].reshape(-1, 64, 1))
assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0])
assert torch.equal(next_batch_1.past_key_values[i][1][:, -1:, :], past[1][1:, :, -1, :].reshape(-1, 1, 64))
for _ in range(default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
assert generated_texts[0].tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
for _ in range(
default_bloom_batch.stopping_criterias[0].max_new_tokens -
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert generated_texts[0].request == default_bloom_batch.requests[0]
assert generated_texts[0].tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens -
default_bloom_batch.stopping_criterias[0].max_new_tokens -
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 4):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
assert generated_texts[0].tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens


@ -0,0 +1,244 @@
import pytest
import torch
from copy import copy
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture
def default_pb_request(default_pb_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=1,
parameters=default_pb_parameters,
max_new_tokens=10,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(
id=0,
requests=[default_pb_request],
size=1
)
@pytest.fixture
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
@pytest.fixture
def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
req_0 = copy(default_pb_request)
req_1 = default_pb_request
req_1.id = 1
req_1.max_new_tokens = 5
batch_pb = generate_pb2.Batch(
id=0,
requests=[req_0, req_1],
size=2
)
return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM("gpt2")
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
batch = default_causal_lm_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert len(batch.input_ids[0]) == 8
assert batch.input_ids[0][-1] == 14402
assert torch.all(batch.input_ids[0][:-1] == 50256)
assert batch.attention_mask[0][-1] == 1
assert torch.all(batch.attention_mask[0][:-1] == 0)
assert batch.past_key_values is None
assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
assert batch.input_lengths == [1]
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_sequence_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
with pytest.raises(ValueError):
CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch])
def test_causal_lm_batch_type(default_causal_lm):
assert default_causal_lm.batch_type == CausalLMBatch
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
generated_texts, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
assert generated_texts == []
assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == next_batch.size
assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
assert next_batch.all_input_ids[0][-1] == 6208
assert next_batch.all_input_ids[0][-2] == 14402
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
assert torch.all(next_batch.attention_mask[0][-2:] == 1)
assert torch.all(next_batch.attention_mask[0][:-2] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids[0, 0] == 6208
assert next_batch.input_lengths == [2]
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all([p[0].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
assert all([p[1].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
def test_causal_lm_generate_token_completion(default_causal_lm, default_causal_lm_batch):
next_batch = default_causal_lm_batch
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test Test Test Test Test Test Test Test Test Test Test"
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
assert generated_texts[0].tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
def test_causal_lm_generate_token_completion_multi(default_causal_lm, default_multi_requests_causal_lm_batch):
next_batch = default_multi_requests_causal_lm_batch
for i in range(default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test Test Test Test Test Test"
assert generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
assert generated_texts[0].tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens -
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test Test Test Test Test Test Test Test Test Test Test"
assert generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
assert generated_texts[0].tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
def test_batch_concatenate(default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch):
next_batch_0 = default_causal_lm_batch
_, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
_, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_causal_lm_batch
_, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(next_batch.attention_mask[0] == 1)
assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids == 6208)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_sequence_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all([p[0].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
assert torch.equal(next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :])
assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
assert torch.equal(next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :])
for _ in range(default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test Test Test Test Test Test"
assert generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
assert generated_texts[0].tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
for _ in range(
default_causal_lm_batch.stopping_criterias[0].max_new_tokens -
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test Test Test Test Test Test Test Test Test Test Test"
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
assert generated_texts[0].tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens -
default_causal_lm_batch.stopping_criterias[0].max_new_tokens -
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 4):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test Test Test Test Test Test Test Test Test Test Test"
assert generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
assert generated_texts[0].tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens


@ -0,0 +1,252 @@
import pytest
import torch
from copy import copy
from text_generation.pb import generate_pb2
from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
@pytest.fixture
def default_pb_request(default_pb_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=2,
parameters=default_pb_parameters,
max_new_tokens=10,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(
id=0,
requests=[default_pb_request],
size=1
)
@pytest.fixture
def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
return Seq2SeqLMBatch.from_pb(default_pb_batch, mt0_small_tokenizer, torch.device("cpu"))
@pytest.fixture
def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
req_0 = copy(default_pb_request)
req_1 = default_pb_request
req_1.id = 1
req_1.max_new_tokens = 5
batch_pb = generate_pb2.Batch(
id=0,
requests=[req_0, req_1],
size=2
)
return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
@pytest.fixture(scope="session")
def default_seq2seq_lm():
return Seq2SeqLM("bigscience/mt0-small")
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
batch = default_seq2seq_lm_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert batch.input_ids.shape == (default_pb_batch.size, 8)
assert batch.input_ids[0][-2] == 4268
assert batch.input_ids[0][-1] == 1
assert torch.all(batch.input_ids[0][:-2] == 0)
assert torch.all(batch.attention_mask[0][-2:] == 1)
assert torch.all(batch.attention_mask[0][:-2] == 0)
assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1)
assert batch.decoder_attention_mask is None
assert batch.encoder_last_hidden_state is None
assert batch.past_key_values is None
assert batch.input_lengths == [2]
assert batch.decoder_input_lengths == [1]
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_input_length == batch.input_lengths[0]
assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch):
with pytest.raises(ValueError):
Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch])
def test_seq2seq_lm_batch_type(default_seq2seq_lm):
assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
generated_texts, next_batch = default_seq2seq_lm.generate_token(default_seq2seq_lm_batch)
assert generated_texts == []
assert isinstance(next_batch, Seq2SeqLMBatch)
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
assert torch.equal(next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask)
assert next_batch.input_lengths == default_seq2seq_lm_batch.input_lengths
assert next_batch.max_input_length == default_seq2seq_lm_batch.max_input_length
assert next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers
assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias
assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
assert next_batch.decoder_input_ids[0, 0] == 0
assert next_batch.decoder_input_ids[0, 1] == 259
assert next_batch.decoder_attention_mask is None
assert next_batch.encoder_last_hidden_state.shape == (1, 8, 512)
assert next_batch.decoder_input_lengths == [2]
assert next_batch.max_decoder_input_length == 2
assert next_batch.past_key_values is not None
assert all([p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values])
assert all([p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values])
assert all([p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values])
assert all([p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values])
def test_seq2seq_lm_generate_token_completion(default_seq2seq_lm, default_seq2seq_lm_batch):
next_batch = default_seq2seq_lm_batch
for _ in range(6):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
assert generated_texts[0].tokens == 7
def test_seq2seq_lm_generate_token_completion_multi(default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch):
next_batch = default_multi_requests_seq2seq_lm_batch
for i in range(4):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few "
assert generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[1]
assert generated_texts[0].tokens == 5
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[0]
assert generated_texts[0].tokens == 7
def test_batch_concatenate(default_seq2seq_lm, default_seq2seq_lm_batch, default_multi_requests_seq2seq_lm_batch):
next_batch_0 = default_seq2seq_lm_batch
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_seq2seq_lm_batch
_, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids[:, 0] == 4268)
assert torch.all(next_batch.input_ids[:, 1] == 1)
assert torch.all(next_batch.attention_mask == 1)
assert torch.equal(next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0])
assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
assert torch.equal(next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids)
assert torch.all(next_batch.decoder_attention_mask[0] == 1)
assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1)
assert torch.equal(next_batch.encoder_last_hidden_state[0], next_batch_0.encoder_last_hidden_state[0, -2:])
assert torch.equal(next_batch.encoder_last_hidden_state[1:], next_batch_1.encoder_last_hidden_state[:, -2:])
assert next_batch.input_lengths == [2, 2, 2]
assert next_batch.decoder_input_lengths == [3, 2, 2]
assert next_batch.max_input_length == 2
assert next_batch.max_decoder_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all([p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values])
assert all([p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values])
assert all([p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values])
assert all([p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:, :], past[0][0])
assert torch.equal(next_batch_1.past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :])
assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:, :], past[1][0])
assert torch.equal(next_batch_1.past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :])
assert torch.equal(next_batch_0.past_key_values[i][2][0, :, -2:, :], past[2][0])
assert torch.equal(next_batch_1.past_key_values[i][2][:, :, -2:, :], past[2][1:])
assert torch.equal(next_batch_0.past_key_values[i][3][0, :, -2:, :], past[3][0])
assert torch.equal(next_batch_1.past_key_values[i][3][:, :, -2:, :], past[3][1:])
for _ in range(3):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few "
assert generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[1]
assert generated_texts[0].tokens == 5
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
assert generated_texts[0].tokens == 7
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[0]
assert generated_texts[0].tokens == 7


@ -0,0 +1,30 @@
import pytest
from text_generation.utils import weight_hub_files, download_weights, weight_files, LocalEntryNotFoundError
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
filenames = weight_hub_files("bigscience/bloom", ".errors")
assert filenames == []
def test_download_weights():
files = download_weights("bigscience/bloom-560m")
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")


@ -1,10 +1,10 @@
from text_generation.models.model import Model
from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOMSharded
+from text_generation.models.bloom import BLOOM, BLOOMSharded
from text_generation.models.seq2seq_lm import Seq2SeqLM
from text_generation.models.galactica import Galactica, GalacticaSharded
-__all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
+__all__ = ["Model", "BLOOM", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
@ -12,7 +12,7 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
if sharded:
return BLOOMSharded(model_name, quantize=quantize)
else:
-return CausalLM(model_name, quantize=quantize)
+return BLOOM(model_name, quantize=quantize)
elif model_name.startswith("facebook/galactica"):
if sharded:
return GalacticaSharded(model_name, quantize=quantize)


@ -1,7 +1,7 @@
import torch
import torch.distributed
-from typing import List, Optional
+from typing import List, Optional, Type
from accelerate import init_empty_weights
from safetensors import safe_open
@ -13,6 +13,8 @@ from transformers.models.bloom.parallel_layers import (
)
from text_generation.models import CausalLM
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.pb import generate_pb2
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
@ -29,7 +31,23 @@ except Exception as e:
torch.manual_seed(0)
-class BLOOMSharded(CausalLM):
+class BloomCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "CausalLMBatch":
batch = super(BloomCausalLMBatch, cls).from_pb(pb=pb, tokenizer=tokenizer, device=device)
batch.keys_head_dim_last = False
return batch
class BLOOM(CausalLM):
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch
class BLOOMSharded(BLOOM):
def __init__(self, model_name: str, quantize: bool = False):
if not model_name.startswith("bigscience/bloom"):
raise ValueError(f"Model {model_name} is not supported")
@ -87,17 +105,17 @@ class BLOOMSharded(CausalLM):
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
full_name = f"transformer.{name}"
@ -160,9 +178,9 @@ class BLOOMSharded(CausalLM):
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor.transpose(1, 0),
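For context on why `BloomCausalLMBatch` sets `keys_head_dim_last = False`: BLOOM's attention cache stores past keys with the head dimension before the sequence dimension, unlike GPT-2-style models. A hedged sketch of the two layouts, with shapes taken from the assertions in the new tests (bloom-560m: 16 heads, head dim 64; gpt2: 12 heads, head dim 64; prompts left-padded to 8 tokens, batch size 1):

```python
# Hedged sketch of the cache layouts that keys_head_dim_last distinguishes.
import torch

batch_size, seq_len, head_dim = 1, 8, 64

# GPT-2-style cache: keys and values are both [batch, heads, seq, head_dim],
# so the head dimension is last and keys_head_dim_last stays True.
gpt2_past_keys = torch.empty(batch_size, 12, seq_len, head_dim)

# BLOOM fuses batch and heads and transposes the keys to [batch * heads, head_dim, seq],
# while values stay [batch * heads, seq, head_dim]; hence keys_head_dim_last = False.
bloom_past_keys = torch.empty(batch_size * 16, head_dim, seq_len)
bloom_past_values = torch.empty(batch_size * 16, seq_len, head_dim)

assert gpt2_past_keys.shape[-1] == head_dim   # head dim last
assert bloom_past_keys.shape[-2] == head_dim  # sequence length last
```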


@ -2,7 +2,7 @@ import torch
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
-from typing import Optional, Tuple, List, Type
+from typing import Optional, Tuple, List, Type, ClassVar
from text_generation.models import Model
from text_generation.models.types import GeneratedText
@ -34,6 +34,9 @@ class CausalLMBatch:
size: int
max_sequence_length: int
# Past metadata
keys_head_dim_last: bool = True
def to_pb(self):
return generate_pb2.Batch(
id=self.batch_id,
@ -43,7 +46,7 @@ class CausalLMBatch:
@classmethod
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "CausalLMBatch":
inputs = []
next_token_choosers = []
@ -144,8 +147,8 @@ class CausalLMBatch:
# We need to slice the attention mask to remove padding from previous steps
attention_mask[
start_index:end_index, -batch.max_sequence_length:
] = batch.attention_mask[:, -batch.max_sequence_length:]
for j, past in enumerate(batch.past_key_values):
past_keys, past_values = past
@ -165,20 +168,16 @@ class CausalLMBatch:
head_dim,
)
+if batch.keys_head_dim_last:
+padded_past_keys_shape = padded_past_values_shape
# seq_length is last for BLOOM
-if past_keys.shape[-2] == head_dim:
-past_keys_head_dim_last = False
+else:
padded_past_keys_shape = (
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
)
-elif past_keys.shape[-1] == head_dim:
-past_keys_head_dim_last = True
-padded_past_keys_shape = padded_past_values_shape
-else:
-raise ValueError(f"past_keys shape {past_keys.shape} is not valid")
# This will run only once per layer
if j == len(past_key_values):
@ -195,24 +194,24 @@ class CausalLMBatch:
past_key_values.append((padded_past_keys, padded_past_values))
# We slice the past keys and values to remove the padding from previous batches
-if past_keys_head_dim_last:
+if batch.keys_head_dim_last:
past_key_values[j][0][
start_index:end_index,
:,
-(batch.max_sequence_length - 1):,
:,
] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
else:
past_key_values[j][0][
start_index:end_index,
:,
:,
-(batch.max_sequence_length - 1):,
] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
past_key_values[j][1][
start_index:end_index, :, -(batch.max_sequence_length - 1):, :
] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
start_index += batch.size
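The branch above is the concatenation-time counterpart of that flag: padded past keys share the padded values' shape when the head dimension is last, and swap the last two dimensions otherwise. A hedged sketch of the shape selection (the helper name and signature are illustrative, not part of the codebase):

```python
# Hedged sketch of the shape selection above; the helper is illustrative only.
from typing import Tuple


def padded_past_keys_shape(
    total_batch_size: int,
    num_heads: int,
    head_dim: int,
    max_sequence_length: int,
    keys_head_dim_last: bool,
) -> Tuple[int, int, int, int]:
    if keys_head_dim_last:
        # GPT-2-style layout: padded keys share the padded values' shape.
        return (total_batch_size, num_heads, max_sequence_length - 1, head_dim)
    # BLOOM-style layout: sequence length is the last dimension of the keys.
    return (total_batch_size, num_heads, head_dim, max_sequence_length - 1)


# Matches the shapes asserted in the concatenation tests above:
assert padded_past_keys_shape(3, 12, 64, 3, True) == (3, 12, 2, 64)
assert padded_past_keys_shape(3, 16, 64, 3, False) == (3, 16, 64, 2)
```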
@ -228,6 +227,7 @@ class CausalLMBatch:
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_sequence_length=max_sequence_length,
keys_head_dim_last=batches[0].keys_head_dim_last
)
@ -237,6 +237,9 @@ class CausalLM(Model):
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else: else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32
@ -247,7 +250,7 @@ class CausalLM(Model):
device_map="auto" if torch.cuda.is_available() else None, device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize, load_in_8bit=quantize,
).eval() ).eval()
tokenizer.pad_token_id = self.model.config.pad_token_id tokenizer.pad_token_id = self.model.config.pad_token_id if self.model.config.pad_token_id is not None else self.model.config.eos_token_id
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
@ -259,7 +262,7 @@ class CausalLM(Model):
return CausalLMBatch
def forward(
self, input_ids, attention_mask, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
outputs = self.model.forward(
@ -271,7 +274,7 @@ class CausalLM(Model):
return outputs.logits, outputs.past_key_values
def generate_token(
self, batch: CausalLMBatch
) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
# For some reason, inference_mode does not work well with GLOO which we use on CPU
context_manager = (
@ -309,12 +312,12 @@ class CausalLM(Model):
# For each member of the batch
for i, (
request,
input_length,
logits,
next_token_chooser,
stopping_criteria,
all_tokens,
) in enumerate(iterator):
# Select next token
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
@ -397,5 +400,6 @@ class CausalLM(Model):
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length,
keys_head_dim_last=batch.keys_head_dim_last
)
return generated_texts, next_batch


@ -83,7 +83,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
-) -> "CausalLMBatch":
+) -> "GalacticaCausalLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []


@ -221,8 +221,8 @@ class Seq2SeqLMBatch:
# Copy to correct indices
encoder_last_hidden_state[
-start_index:end_index, -batch.max_decoder_input_length :, :
-] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :]
+start_index:end_index, -batch.max_input_length :, :
+] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
# Iterate over attention layers
for j, past in enumerate(batch.past_key_values):
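The fix above slices `encoder_last_hidden_state` with the encoder-side length: the encoder output carries one hidden vector per input token, so `max_input_length`, not `max_decoder_input_length`, is the right bound. A minimal sketch, with shapes borrowed from the new seq2seq tests (hidden size 512, inputs left-padded to 8 positions, two real input tokens):

```python
# Hedged sketch of the indexing fix above.
import torch

batch_size, padded_length, hidden_size = 1, 8, 512
max_input_length = 2  # encoder-side length; the decoder-side length is irrelevant here

encoder_last_hidden_state = torch.zeros(batch_size, padded_length, hidden_size)

# One hidden vector per *input* token, so padding is stripped with the input length.
kept = encoder_last_hidden_state[:, -max_input_length:, :]
assert kept.shape == (batch_size, max_input_length, hidden_size)
```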
@ -305,6 +305,9 @@ class Seq2SeqLM(Model):
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else: else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32


@ -137,8 +137,8 @@ def download_weights(model_name, extension=".safetensors"):
executor.submit(download_function, filename=filename) for filename in filenames
]
files = [
-file
-for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
+future.result()
+for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
]
return files
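The change above matters because `concurrent.futures.as_completed` yields `Future` objects; the old list comprehension collected the futures themselves rather than the downloaded file paths. A minimal sketch of the corrected pattern, without the tqdm progress bar:

```python
# Hedged sketch of the corrected pattern: the downloaded paths come from future.result().
import concurrent.futures


def fetch_all(download_function, filenames):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(download_function, filename=filename)
            for filename in filenames
        ]
        # Collect the file path returned by each completed download.
        return [
            future.result()
            for future in concurrent.futures.as_completed(futures)
        ]
```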