Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
Make GPTQ test less flaky.
commit cf7c17c66b
parent ba552e1a82
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -24,6 +24,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
 
 
 class ResponseComparator(JSONSnapshotExtension):
+    rtol = 0.2
     def serialize(
         self,
         data,
@@ -58,7 +59,7 @@ class ResponseComparator(JSONSnapshotExtension):
             return (
                 token.id == other.id
                 and token.text == other.text
-                and math.isclose(token.logprob, other.logprob, rel_tol=0.2)
+                and math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
                 and token.special == other.special
             )
 
@@ -68,7 +69,7 @@ class ResponseComparator(JSONSnapshotExtension):
                 prefill_token.id == other.id
                 and prefill_token.text == other.text
                 and (
-                    math.isclose(prefill_token.logprob, other.logprob, rel_tol=0.2)
+                    math.isclose(prefill_token.logprob, other.logprob, rel_tol=self.rtol)
                     if prefill_token.logprob is not None
                     else prefill_token.logprob == other.logprob
                 )
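For reference, math.isclose(a, b, rel_tol=r) passes exactly when abs(a - b) <= r * max(abs(a), abs(b)), so routing the literal 0.2 through self.rtol is what lets a subclass widen the band without touching the comparison logic. A minimal standalone sketch with made-up logprob values:

import math

# math.isclose(a, b, rel_tol=r) checks abs(a - b) <= r * max(abs(a), abs(b)).
a, b = -1.00, -1.15  # hypothetical logprobs for the same token across two runs

print(math.isclose(a, b, rel_tol=0.2))  # True:  0.15 <= 0.2 * 1.15
print(math.isclose(a, b, rel_tol=0.1))  # False: 0.15 >  0.1 * 1.15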
@@ -148,6 +149,10 @@ class ResponseComparator(JSONSnapshotExtension):
         )
 
 
+class GenerousResponseComparator(ResponseComparator):
+    # Needed for GPTQ with exllama which has serious numerical fluctuations.
+    rtol = 0.75
+
 class LauncherHandle:
     def __init__(self, port: int):
         self.client = AsyncClient(f"http://localhost:{port}")
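Because the parent comparator now reads self.rtol instead of a hard-coded 0.2, this subclass only has to redeclare one class attribute to loosen every token and prefill-token comparison it inherits. A minimal sketch of that override pattern (illustrative class names, not the real comparators):

import math


class Comparator:
    rtol = 0.2  # default relative tolerance

    def close(self, a: float, b: float) -> bool:
        # Looked up via self, so a subclass can override the attribute alone.
        return math.isclose(a, b, rel_tol=self.rtol)


class GenerousComparator(Comparator):
    rtol = 0.75  # wider band for numerically noisy backends


print(Comparator().close(-1.0, -1.5))          # False: 0.5 > 0.2 * 1.5
print(GenerousComparator().close(-1.0, -1.5))  # True:  0.5 <= 0.75 * 1.5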
@@ -193,6 +198,10 @@ class ProcessLauncherHandle(LauncherHandle):
 def response_snapshot(snapshot):
     return snapshot.use_extension(ResponseComparator)
 
+
+@pytest.fixture
+def generous_response_snapshot(snapshot):
+    return snapshot.use_extension(GenerousResponseComparator)
 
 @pytest.fixture(scope="module")
 def event_loop():
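syrupy's use_extension swaps the snapshot extension for the assertions made through that fixture, so the loose comparator gets its own fixture name and each test opts in explicitly. A hypothetical test consuming it (the model fixture and prompt here are made up):

import pytest


@pytest.mark.asyncio
async def test_some_gptq_model(some_gptq_model, generous_response_snapshot):
    # some_gptq_model is a placeholder fixture; the snapshot comparison now
    # tolerates logprob drift up to rtol=0.75 instead of the default 0.2.
    response = await some_gptq_model.generate("Test", max_new_tokens=10)
    assert response == generous_response_snapshot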
--- a/integration-tests/models/test_flash_starcoder_gptq.py
+++ b/integration-tests/models/test_flash_starcoder_gptq.py
@@ -15,20 +15,20 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
 
 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
+async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
         max_new_tokens=20,
         decoder_input_details=True,
     )
     assert response.details.generated_tokens == 20
-    assert response == response_snapshot
+    assert response == generous_response_snapshot
 
 
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_gptq_default_params(
-    flash_starcoder_gptq, response_snapshot
+    flash_starcoder_gptq, generous_response_snapshot
 ):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
@@ -39,13 +39,13 @@ async def test_flash_starcoder_gptq_default_params(
         seed=0,
     )
     assert response.details.generated_tokens == 20
-    assert response == response_snapshot
+    assert response == generous_response_snapshot
 
 
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_gptq_load(
-    flash_starcoder_gptq, generate_load, response_snapshot
+    flash_starcoder_gptq, generate_load, generous_response_snapshot
 ):
     responses = await generate_load(
         flash_starcoder_gptq,
@@ -57,4 +57,4 @@ async def test_flash_starcoder_gptq_load(
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
 
-    assert responses == response_snapshot
+    assert responses == generous_response_snapshot
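The load test leans on the harness's generate_load fixture; a rough sketch under the assumption that it simply fires identical requests concurrently (this is not the harness's actual code):

import asyncio


async def generate_load_sketch(client, prompt: str, max_new_tokens: int, n: int):
    # Assumed behavior: n identical concurrent requests, responses gathered in
    # order, so nondeterminism shows up as divergent generated_text values.
    tasks = [client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)]
    return await asyncio.gather(*tasks)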