diff --git a/integration-tests/models/test_neox.py b/integration-tests/models/test_neox.py
index 58659319..8d949ddb 100644
--- a/integration-tests/models/test_neox.py
+++ b/integration-tests/models/test_neox.py
@@ -1,44 +1,46 @@
-# import pytest
-#
-#
-# @pytest.fixture(scope="module")
-# def neox_handle(launcher):
-#     with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle:
-#         yield handle
-#
-#
-# @pytest.fixture(scope="module")
-# async def neox(neox_handle):
-#     await neox_handle.health(300)
-#     return neox_handle.client
-#
-#
-# @pytest.mark.asyncio
-# async def test_neox(neox, response_snapshot):
-#     response = await neox.generate(
-#         "<|USER|>What's your mood today?<|ASSISTANT|>",
-#         max_new_tokens=10,
-#         decoder_input_details=True,
-#     )
-#
-#     assert response.details.generated_tokens == 10
-#     assert response == response_snapshot
-#
-#
-# @pytest.mark.asyncio
-# async def test_neox_load(neox, generate_load, response_snapshot):
-#     responses = await generate_load(
-#         neox,
-#         "<|USER|>What's your mood today?<|ASSISTANT|>",
-#         max_new_tokens=10,
-#         n=4,
-#     )
-#
-#     generated_texts = [r.generated_text for r in responses]
-#
-#     assert len(generated_texts) == 4
-#     assert generated_texts, all(
-#         [text == generated_texts[0] for text in generated_texts]
-#     )
-#
-#     assert responses == response_snapshot
+import pytest
+
+
+@pytest.fixture(scope="module")
+def neox_handle(launcher):
+    with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def neox(neox_handle):
+    await neox_handle.health(300)
+    return neox_handle.client
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+async def test_neox(neox, response_snapshot):
+    response = await neox.generate(
+        "<|USER|>What's your mood today?<|ASSISTANT|>",
+        max_new_tokens=10,
+        decoder_input_details=True,
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+async def test_neox_load(neox, generate_load, response_snapshot):
+    responses = await generate_load(
+        neox,
+        "<|USER|>What's your mood today?<|ASSISTANT|>",
+        max_new_tokens=10,
+        n=4,
+    )
+
+    generated_texts = [r.generated_text for r in responses]
+
+    assert len(generated_texts) == 4
+    assert all(
+        [text == generated_texts[0] for text in generated_texts]
+    )
+
+    assert responses == response_snapshot
diff --git a/integration-tests/models/test_neox_sharded.py b/integration-tests/models/test_neox_sharded.py
index 97f2d8a5..fd691a1a 100644
--- a/integration-tests/models/test_neox_sharded.py
+++ b/integration-tests/models/test_neox_sharded.py
@@ -1,40 +1,42 @@
-# import pytest
-#
-#
-# @pytest.fixture(scope="module")
-# def neox_sharded_handle(launcher):
-#     with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle:
-#         yield handle
-#
-#
-# @pytest.fixture(scope="module")
-# async def neox_sharded(neox_sharded_handle):
-#     await neox_sharded_handle.health(300)
-#     return neox_sharded_handle.client
-#
-#
-# @pytest.mark.asyncio
-# async def test_neox(neox_sharded, response_snapshot):
-#     response = await neox_sharded.generate(
-#         "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
-#         max_new_tokens=10,
-#         decoder_input_details=True,
-#     )
-#
-#     assert response.details.generated_tokens == 10
-#     assert response == response_snapshot
-#
-#
-# @pytest.mark.asyncio
-# async def test_neox_load(neox_sharded, generate_load, response_snapshot):
-#     responses = await generate_load(
-#         neox_sharded,
-#         "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
-#         max_new_tokens=10,
-#         n=4,
-#     )
-#
-#     assert len(responses) == 4
-#     assert all([r.generated_text == responses[0].generated_text for r in responses])
-#
-#     assert responses == response_snapshot
+import pytest
+
+
+@pytest.fixture(scope="module")
+def neox_sharded_handle(launcher):
+    with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def neox_sharded(neox_sharded_handle):
+    await neox_sharded_handle.health(300)
+    return neox_sharded_handle.client
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+async def test_neox(neox_sharded, response_snapshot):
+    response = await neox_sharded.generate(
+        "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
+        max_new_tokens=10,
+        decoder_input_details=True,
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+async def test_neox_load(neox_sharded, generate_load, response_snapshot):
+    responses = await generate_load(
+        neox_sharded,
+        "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
+        max_new_tokens=10,
+        n=4,
+    )
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot