update conftest

This commit is contained in:
baptiste 2025-06-23 12:27:32 +00:00
parent a32025f931
commit ae7f3aeba1

View File

@ -51,6 +51,7 @@ from text_generation.types import (
ChatComplete, ChatComplete,
ChatCompletionChunk, ChatCompletionChunk,
ChatCompletionComplete, ChatCompletionComplete,
Completion,
Details, Details,
Grammar, Grammar,
InputToken, InputToken,
@ -160,6 +161,7 @@ class ResponseComparator(JSONSnapshotExtension):
or isinstance(data, ChatComplete) or isinstance(data, ChatComplete)
or isinstance(data, ChatCompletionChunk) or isinstance(data, ChatCompletionChunk)
or isinstance(data, ChatCompletionComplete) or isinstance(data, ChatCompletionComplete)
or isinstance(data, Completion)
or isinstance(data, OAIChatCompletionChunk) or isinstance(data, OAIChatCompletionChunk)
or isinstance(data, OAICompletion) or isinstance(data, OAICompletion)
): ):
@ -216,6 +218,8 @@ class ResponseComparator(JSONSnapshotExtension):
if isinstance(choices, List) and len(choices) >= 1: if isinstance(choices, List) and len(choices) >= 1:
if "delta" in choices[0]: if "delta" in choices[0]:
return ChatCompletionChunk(**data) return ChatCompletionChunk(**data)
if "text" in choices[0]:
return Completion(**data)
return ChatComplete(**data) return ChatComplete(**data)
else: else:
return Response(**data) return Response(**data)
@ -308,6 +312,9 @@ class ResponseComparator(JSONSnapshotExtension):
) )
) )
def eq_completion(response: Completion, other: Completion) -> bool:
return response.choices[0].text == other.choices[0].text
def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool: def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
return ( return (
response.choices[0].message.content == other.choices[0].message.content response.choices[0].message.content == other.choices[0].message.content
@ -352,6 +359,11 @@ class ResponseComparator(JSONSnapshotExtension):
if len(serialized_data) == 0: if len(serialized_data) == 0:
return len(snapshot_data) == len(serialized_data) return len(snapshot_data) == len(serialized_data)
if isinstance(serialized_data[0], Completion):
return len(snapshot_data) == len(serialized_data) and all(
[eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
)
if isinstance(serialized_data[0], ChatComplete): if isinstance(serialized_data[0], ChatComplete):
return len(snapshot_data) == len(serialized_data) and all( return len(snapshot_data) == len(serialized_data) and all(
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)] [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]