update conftest

This commit is contained in:
baptiste 2025-06-23 12:27:32 +00:00
parent a32025f931
commit ae7f3aeba1

View File

@ -51,6 +51,7 @@ from text_generation.types import (
ChatComplete,
ChatCompletionChunk,
ChatCompletionComplete,
Completion,
Details,
Grammar,
InputToken,
@ -160,6 +161,7 @@ class ResponseComparator(JSONSnapshotExtension):
or isinstance(data, ChatComplete)
or isinstance(data, ChatCompletionChunk)
or isinstance(data, ChatCompletionComplete)
or isinstance(data, Completion)
or isinstance(data, OAIChatCompletionChunk)
or isinstance(data, OAICompletion)
):
@ -216,6 +218,8 @@ class ResponseComparator(JSONSnapshotExtension):
if isinstance(choices, List) and len(choices) >= 1:
if "delta" in choices[0]:
return ChatCompletionChunk(**data)
if "text" in choices[0]:
return Completion(**data)
return ChatComplete(**data)
else:
return Response(**data)
@ -308,6 +312,9 @@ class ResponseComparator(JSONSnapshotExtension):
)
)
def eq_completion(response: Completion, other: Completion) -> bool:
return response.choices[0].text == other.choices[0].text
def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
return (
response.choices[0].message.content == other.choices[0].message.content
@ -352,6 +359,11 @@ class ResponseComparator(JSONSnapshotExtension):
if len(serialized_data) == 0:
return len(snapshot_data) == len(serialized_data)
if isinstance(serialized_data[0], Completion):
return len(snapshot_data) == len(serialized_data) and all(
[eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
)
if isinstance(serialized_data[0], ChatComplete):
return len(snapshot_data) == len(serialized_data) and all(
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]