mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-22 15:50:17 +00:00
update conftest
This commit is contained in:
parent
a32025f931
commit
ae7f3aeba1
@ -51,6 +51,7 @@ from text_generation.types import (
|
||||
ChatComplete,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionComplete,
|
||||
Completion,
|
||||
Details,
|
||||
Grammar,
|
||||
InputToken,
|
||||
@ -160,6 +161,7 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
or isinstance(data, ChatComplete)
|
||||
or isinstance(data, ChatCompletionChunk)
|
||||
or isinstance(data, ChatCompletionComplete)
|
||||
or isinstance(data, Completion)
|
||||
or isinstance(data, OAIChatCompletionChunk)
|
||||
or isinstance(data, OAICompletion)
|
||||
):
|
||||
@ -216,6 +218,8 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
if isinstance(choices, List) and len(choices) >= 1:
|
||||
if "delta" in choices[0]:
|
||||
return ChatCompletionChunk(**data)
|
||||
if "text" in choices[0]:
|
||||
return Completion(**data)
|
||||
return ChatComplete(**data)
|
||||
else:
|
||||
return Response(**data)
|
||||
@ -308,6 +312,9 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
)
|
||||
)
|
||||
|
||||
def eq_completion(response: Completion, other: Completion) -> bool:
|
||||
return response.choices[0].text == other.choices[0].text
|
||||
|
||||
def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
|
||||
return (
|
||||
response.choices[0].message.content == other.choices[0].message.content
|
||||
@ -352,6 +359,11 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
if len(serialized_data) == 0:
|
||||
return len(snapshot_data) == len(serialized_data)
|
||||
|
||||
if isinstance(serialized_data[0], Completion):
|
||||
return len(snapshot_data) == len(serialized_data) and all(
|
||||
[eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
||||
)
|
||||
|
||||
if isinstance(serialized_data[0], ChatComplete):
|
||||
return len(snapshot_data) == len(serialized_data) and all(
|
||||
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
||||
|
Loading…
Reference in New Issue
Block a user