diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index a9645153..158b9dde 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -24,6 +24,7 @@ from text_generation.types import ( BestOfSequence, Grammar, ChatComplete, + ChatCompletionChunk, ) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) @@ -61,7 +62,15 @@ class ResponseComparator(JSONSnapshotExtension): def convert_data(data): data = json.loads(data) if isinstance(data, Dict) and "choices" in data: + choices = data["choices"] + if ( + isinstance(choices, List) + and len(choices) >= 1 + and "delta" in choices[0] + ): + return ChatCompletionChunk(**data) return ChatComplete(**data) + if isinstance(data, Dict): return Response(**data) if isinstance(data, List): @@ -151,6 +160,11 @@ class ResponseComparator(JSONSnapshotExtension): response.choices[0].message.content == other.choices[0].message.content ) + def eq_chat_complete_chunk( + response: ChatCompletionChunk, other: ChatCompletionChunk + ) -> bool: + return response.choices[0].delta.content == other.choices[0].delta.content + def eq_response(response: Response, other: Response) -> bool: return response.generated_text == other.generated_text and eq_details( response.details, other.details @@ -169,6 +183,14 @@ class ResponseComparator(JSONSnapshotExtension): [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)] ) + if isinstance(serialized_data[0], ChatCompletionChunk): + return len(snapshot_data) == len(serialized_data) and all( + [ + eq_chat_complete_chunk(r, o) + for r, o in zip(serialized_data, snapshot_data) + ] + ) + return len(snapshot_data) == len(serialized_data) and all( [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)] ) diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 0901c7ac..38570c38 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -234,7 +234,6 @@ async def test_flash_llama_grammar_tools_stream( count = 0 async for response in responses: - print(response) count += 1 assert count == 20