From 24c0f1cc7a68213fcb76fa5ffcda7194000f90c1 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 30 Jun 2023 21:55:37 +0000
Subject: [PATCH] Adding (failing) integration tests.

---
 integration-tests/models/test_mpt.py        | 48 +++++++++++++++++++++
 server/text_generation_server/models/mpt.py |  2 +-
 2 files changed, 49 insertions(+), 1 deletion(-)
 create mode 100644 integration-tests/models/test_mpt.py

diff --git a/integration-tests/models/test_mpt.py b/integration-tests/models/test_mpt.py
new file mode 100644
index 00000000..d58a8c5a
--- /dev/null
+++ b/integration-tests/models/test_mpt.py
@@ -0,0 +1,48 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def mpt_sharded_handle(launcher):
+    with launcher("mosaicml/mpt-7b", num_shard=2) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def mpt_sharded(mpt_sharded_handle):
+    await mpt_sharded_handle.health(300)
+    return mpt_sharded_handle.client
+
+
+@pytest.mark.asyncio
+async def test_mpt(mpt_sharded, response_snapshot):
+    response = await mpt_sharded.generate(
+        "What is Deep Learning?",
+        max_new_tokens=17,
+        decoder_input_details=True,
+    )
+
+    assert response.details.generated_tokens == 17
+    assert (
+        response.generated_text
+        == " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
+    )
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
+    responses = await generate_load(
+        mpt_sharded,
+        "What is Deep Learning?",
+        max_new_tokens=17,
+        n=4,
+    )
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert (
+        responses[0].generated_text
+        == " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
+    )
+
+    assert responses == response_snapshot
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 889b3c95..b38f6218 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -66,7 +66,7 @@ class MPTSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            requires_padding=True,
+            requires_padding=False,
             dtype=dtype,
             device=device,
             rank=rank,
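
Note: the new tests lean on fixtures supplied by the integration-test harness's conftest.py (`launcher`, `generate_load`, `response_snapshot`), which is not part of this patch. As a minimal sketch of the same flow outside pytest — assuming the `launcher` context manager and the async client behave as the fixtures above imply — one could write:

    import asyncio

    # Hedged sketch only: `launcher` stands in for the conftest fixture of the
    # same name (not shown in this patch); how it is constructed is assumed.
    async def smoke_check(launcher):
        # Start a two-shard MPT server, mirroring the mpt_sharded_handle fixture.
        with launcher("mosaicml/mpt-7b", num_shard=2) as handle:
            # Wait up to 300 seconds for the server to report healthy.
            await handle.health(300)
            response = await handle.client.generate(
                "What is Deep Learning?",
                max_new_tokens=17,
                decoder_input_details=True,
            )
            # The tests above pin both the token count and the exact text.
            assert response.details.generated_tokens == 17
            print(response.generated_text)

    # Usage, once a launcher exists: asyncio.run(smoke_check(launcher))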