fix: update tests for streaming tools

This commit is contained in:
drbh 2024-02-28 03:56:37 +00:00
parent 0fc7237380
commit b5cacca1dc
2 changed files with 22 additions and 1 deletions

View File

@ -24,6 +24,7 @@ from text_generation.types import (
BestOfSequence, BestOfSequence,
Grammar, Grammar,
ChatComplete, ChatComplete,
ChatCompletionChunk,
) )
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@ -61,7 +62,15 @@ class ResponseComparator(JSONSnapshotExtension):
def convert_data(data): def convert_data(data):
data = json.loads(data) data = json.loads(data)
if isinstance(data, Dict) and "choices" in data: if isinstance(data, Dict) and "choices" in data:
choices = data["choices"]
if (
isinstance(choices, List)
and len(choices) >= 1
and "delta" in choices[0]
):
return ChatCompletionChunk(**data)
return ChatComplete(**data) return ChatComplete(**data)
if isinstance(data, Dict): if isinstance(data, Dict):
return Response(**data) return Response(**data)
if isinstance(data, List): if isinstance(data, List):
@ -151,6 +160,11 @@ class ResponseComparator(JSONSnapshotExtension):
response.choices[0].message.content == other.choices[0].message.content response.choices[0].message.content == other.choices[0].message.content
) )
def eq_chat_complete_chunk(
response: ChatCompletionChunk, other: ChatCompletionChunk
) -> bool:
return response.choices[0].delta.content == other.choices[0].delta.content
def eq_response(response: Response, other: Response) -> bool: def eq_response(response: Response, other: Response) -> bool:
return response.generated_text == other.generated_text and eq_details( return response.generated_text == other.generated_text and eq_details(
response.details, other.details response.details, other.details
@ -169,6 +183,14 @@ class ResponseComparator(JSONSnapshotExtension):
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)] [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
) )
if isinstance(serialized_data[0], ChatCompletionChunk):
return len(snapshot_data) == len(serialized_data) and all(
[
eq_chat_complete_chunk(r, o)
for r, o in zip(serialized_data, snapshot_data)
]
)
return len(snapshot_data) == len(serialized_data) and all( return len(snapshot_data) == len(serialized_data) and all(
[eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)] [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
) )

View File

@ -234,7 +234,6 @@ async def test_flash_llama_grammar_tools_stream(
count = 0 count = 0
async for response in responses: async for response in responses:
print(response)
count += 1 count += 1
assert count == 20 assert count == 20