fix: update tests for streaming tools

This commit is contained in:
drbh 2024-02-28 03:56:37 +00:00
parent 0fc7237380
commit b5cacca1dc
2 changed files with 22 additions and 1 deletions

View File

@ -24,6 +24,7 @@ from text_generation.types import (
BestOfSequence,
Grammar,
ChatComplete,
ChatCompletionChunk,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@ -61,7 +62,15 @@ class ResponseComparator(JSONSnapshotExtension):
def convert_data(data):
data = json.loads(data)
if isinstance(data, Dict) and "choices" in data:
choices = data["choices"]
if (
isinstance(choices, List)
and len(choices) >= 1
and "delta" in choices[0]
):
return ChatCompletionChunk(**data)
return ChatComplete(**data)
if isinstance(data, Dict):
return Response(**data)
if isinstance(data, List):
@ -151,6 +160,11 @@ class ResponseComparator(JSONSnapshotExtension):
response.choices[0].message.content == other.choices[0].message.content
)
def eq_chat_complete_chunk(
response: ChatCompletionChunk, other: ChatCompletionChunk
) -> bool:
return response.choices[0].delta.content == other.choices[0].delta.content
def eq_response(response: Response, other: Response) -> bool:
return response.generated_text == other.generated_text and eq_details(
response.details, other.details
@ -169,6 +183,14 @@ class ResponseComparator(JSONSnapshotExtension):
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
)
if isinstance(serialized_data[0], ChatCompletionChunk):
return len(snapshot_data) == len(serialized_data) and all(
[
eq_chat_complete_chunk(r, o)
for r, o in zip(serialized_data, snapshot_data)
]
)
return len(snapshot_data) == len(serialized_data) and all(
[eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
)

View File

@ -234,7 +234,6 @@ async def test_flash_llama_grammar_tools_stream(
count = 0
async for response in responses:
print(response)
count += 1
assert count == 20