From ae7f3aeba1ccfa38d9da86bcdb398f71afc99f41 Mon Sep 17 00:00:00 2001 From: baptiste Date: Mon, 23 Jun 2025 12:27:32 +0000 Subject: [PATCH] update conftest --- integration-tests/conftest.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 534aaaea..9cc33416 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -51,6 +51,7 @@ from text_generation.types import ( ChatComplete, ChatCompletionChunk, ChatCompletionComplete, + Completion, Details, Grammar, InputToken, @@ -160,6 +161,7 @@ class ResponseComparator(JSONSnapshotExtension): or isinstance(data, ChatComplete) or isinstance(data, ChatCompletionChunk) or isinstance(data, ChatCompletionComplete) + or isinstance(data, Completion) or isinstance(data, OAIChatCompletionChunk) or isinstance(data, OAICompletion) ): @@ -216,6 +218,8 @@ class ResponseComparator(JSONSnapshotExtension): if isinstance(choices, List) and len(choices) >= 1: if "delta" in choices[0]: return ChatCompletionChunk(**data) + if "text" in choices[0]: + return Completion(**data) return ChatComplete(**data) else: return Response(**data) @@ -308,6 +312,9 @@ class ResponseComparator(JSONSnapshotExtension): ) ) + def eq_completion(response: Completion, other: Completion) -> bool: + return response.choices[0].text == other.choices[0].text + def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool: return ( response.choices[0].message.content == other.choices[0].message.content @@ -352,6 +359,11 @@ class ResponseComparator(JSONSnapshotExtension): if len(serialized_data) == 0: return len(snapshot_data) == len(serialized_data) + if isinstance(serialized_data[0], Completion): + return len(snapshot_data) == len(serialized_data) and all( + [eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)] + ) + if isinstance(serialized_data[0], ChatComplete): return len(snapshot_data) == len(serialized_data) and all( [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]