text-generation-inference/integration-tests/models/test_flash_starcoder.py
2023-05-15 23:36:30 +02:00

48 lines
1.2 KiB
Python

import pytest
from utils import health_check
@pytest.fixture(scope="module")
def flash_starcoder(launcher):
    """Launch ``bigcode/starcoder`` with ``num_shard=2`` and yield its client.

    The fixture is module-scoped, so every test in this file shares one
    launched server. Leaving the ``with`` block after the last test hands
    shutdown back to the ``launcher`` context manager.
    """
    with launcher("bigcode/starcoder", num_shard=2) as starcoder_client:
        yield starcoder_client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder(flash_starcoder, snapshot_test):
    """Generate 10 new tokens from a short code prompt and snapshot the response.

    No sampling parameters are passed, so whatever decoding the server does by
    default is what gets snapshotted.
    """
    # The second argument (240) is presumably a timeout in seconds; confirm in utils.
    await health_check(flash_starcoder, 240)

    prompt = "def print_hello"
    response = await flash_starcoder.generate(prompt, max_new_tokens=10)

    # Exactly max_new_tokens tokens are expected for this prompt.
    assert response.details.generated_tokens == 10
    assert snapshot_test(response)
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_default_params(flash_starcoder, snapshot_test):
    """Generate with explicit sampling parameters and snapshot the response.

    A fixed ``seed`` is passed alongside ``temperature``/``top_p``, presumably so
    the sampled output stays reproducible across runs and the snapshot stays stable.
    """
    # The second argument (240) is presumably a timeout in seconds; confirm in utils.
    await health_check(flash_starcoder, 240)

    sampling = {"temperature": 0.2, "top_p": 0.95, "seed": 0}
    response = await flash_starcoder.generate(
        "def print_hello", max_new_tokens=60, **sampling
    )

    # Fewer than max_new_tokens tokens come back; generation presumably stops
    # at an end-of-sequence token here. Confirm against the stored snapshot.
    assert response.details.generated_tokens == 12
    assert snapshot_test(response)
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_load(flash_starcoder, generate_load, snapshot_test):
    """Send a batch of requests through ``generate_load`` and snapshot every response.

    ``generate_load`` presumably issues the same prompt ``n`` times concurrently;
    see that fixture for the exact behavior.
    """
    # The second argument (240) is presumably a timeout in seconds; confirm in utils.
    await health_check(flash_starcoder, 240)

    request_count = 4
    responses = await generate_load(
        flash_starcoder, "def print_hello", max_new_tokens=10, n=request_count
    )

    # One response per request.
    assert len(responses) == request_count
    assert snapshot_test(responses)