Make GPTQ test less flaky.

This commit is contained in:
Nicolas Patry 2023-11-28 18:04:52 +01:00
parent ba552e1a82
commit cf7c17c66b
2 changed files with 17 additions and 8 deletions

View File

@ -24,6 +24,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
class ResponseComparator(JSONSnapshotExtension): class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2
def serialize( def serialize(
self, self,
data, data,
@ -58,7 +59,7 @@ class ResponseComparator(JSONSnapshotExtension):
return ( return (
token.id == other.id token.id == other.id
and token.text == other.text and token.text == other.text
and math.isclose(token.logprob, other.logprob, rel_tol=0.2) and math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
and token.special == other.special and token.special == other.special
) )
@ -68,7 +69,7 @@ class ResponseComparator(JSONSnapshotExtension):
prefill_token.id == other.id prefill_token.id == other.id
and prefill_token.text == other.text and prefill_token.text == other.text
and ( and (
math.isclose(prefill_token.logprob, other.logprob, rel_tol=0.2) math.isclose(prefill_token.logprob, other.logprob, rel_tol=self.rtol)
if prefill_token.logprob is not None if prefill_token.logprob is not None
else prefill_token.logprob == other.logprob else prefill_token.logprob == other.logprob
) )
@ -148,6 +149,10 @@ class ResponseComparator(JSONSnapshotExtension):
) )
class GenerousResponseComparator(ResponseComparator):
    """Snapshot comparator with a looser relative tolerance on logprobs.

    Only ``rtol`` is overridden here; everything else is inherited from
    ``ResponseComparator``, which passes ``self.rtol`` as ``rel_tol`` to
    ``math.isclose`` when comparing token and prefill-token logprobs.
    """

    # Needed for GPTQ with exllama which has serious numerical fluctuations.
    # Overrides the base class value of 0.2.
    rtol = 0.75
class LauncherHandle: class LauncherHandle:
def __init__(self, port: int): def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}") self.client = AsyncClient(f"http://localhost:{port}")
@ -193,6 +198,10 @@ class ProcessLauncherHandle(LauncherHandle):
def response_snapshot(snapshot): def response_snapshot(snapshot):
return snapshot.use_extension(ResponseComparator) return snapshot.use_extension(ResponseComparator)
@pytest.fixture
def generous_response_snapshot(snapshot):
    """Return ``snapshot`` switched to ``GenerousResponseComparator``.

    Counterpart of ``response_snapshot`` for tests that request the
    comparator with the larger ``rtol``.
    """
    comparator_cls = GenerousResponseComparator
    return snapshot.use_extension(comparator_cls)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def event_loop(): def event_loop():

View File

@ -15,20 +15,20 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot): async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
response = await flash_starcoder_gptq.generate( response = await flash_starcoder_gptq.generate(
"def geometric_mean(L: List[float]):", "def geometric_mean(L: List[float]):",
max_new_tokens=20, max_new_tokens=20,
decoder_input_details=True, decoder_input_details=True,
) )
assert response.details.generated_tokens == 20 assert response.details.generated_tokens == 20
assert response == response_snapshot assert response == generous_response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_starcoder_gptq_default_params( async def test_flash_starcoder_gptq_default_params(
flash_starcoder_gptq, response_snapshot flash_starcoder_gptq, generous_response_snapshot
): ):
response = await flash_starcoder_gptq.generate( response = await flash_starcoder_gptq.generate(
"def geometric_mean(L: List[float]):", "def geometric_mean(L: List[float]):",
@ -39,13 +39,13 @@ async def test_flash_starcoder_gptq_default_params(
seed=0, seed=0,
) )
assert response.details.generated_tokens == 20 assert response.details.generated_tokens == 20
assert response == response_snapshot assert response == generous_response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_starcoder_gptq_load( async def test_flash_starcoder_gptq_load(
flash_starcoder_gptq, generate_load, response_snapshot flash_starcoder_gptq, generate_load, generous_response_snapshot
): ):
responses = await generate_load( responses = await generate_load(
flash_starcoder_gptq, flash_starcoder_gptq,
@ -57,4 +57,4 @@ async def test_flash_starcoder_gptq_load(
assert len(responses) == 4 assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]) assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot assert responses == generous_response_snapshot