mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
fix: update tests for streaming tools
This commit is contained in:
parent 0fc7237380
commit b5cacca1dc
@@ -24,6 +24,7 @@ from text_generation.types import (
    BestOfSequence,
    Grammar,
    ChatComplete,
    ChatCompletionChunk,
)

DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -61,7 +62,15 @@ class ResponseComparator(JSONSnapshotExtension):
        def convert_data(data):
            data = json.loads(data)
            if isinstance(data, Dict) and "choices" in data:
                choices = data["choices"]
                if (
                    isinstance(choices, List)
                    and len(choices) >= 1
                    and "delta" in choices[0]
                ):
                    return ChatCompletionChunk(**data)
                return ChatComplete(**data)

            if isinstance(data, Dict):
                return Response(**data)
            if isinstance(data, List):
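For readers following the change, here is a minimal, self-contained sketch of the detection rule this hunk introduces: a streamed chunk carries a "delta" field in its first choice, while a finished completion carries a "message". The JSON payloads below are illustrative, not taken from the repository's snapshots.

import json

def looks_like_chunk(raw: str) -> bool:
    # Mirrors the check above: at least one choice, and the first one has a "delta".
    data = json.loads(raw)
    choices = data.get("choices", [])
    return isinstance(choices, list) and len(choices) >= 1 and "delta" in choices[0]

streamed = '{"choices": [{"index": 0, "delta": {"content": "Hel"}}]}'
complete = '{"choices": [{"index": 0, "message": {"content": "Hello"}}]}'

assert looks_like_chunk(streamed) is True
assert looks_like_chunk(complete) is False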
@@ -151,6 +160,11 @@ class ResponseComparator(JSONSnapshotExtension):
                response.choices[0].message.content == other.choices[0].message.content
            )

        def eq_chat_complete_chunk(
            response: ChatCompletionChunk, other: ChatCompletionChunk
        ) -> bool:
            return response.choices[0].delta.content == other.choices[0].delta.content

        def eq_response(response: Response, other: Response) -> bool:
            return response.generated_text == other.generated_text and eq_details(
                response.details, other.details
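Note that the new eq_chat_complete_chunk compares only the first choice's delta content. A hypothetical example of what that does (and does not) check, using stand-in dataclasses rather than the real client models:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Delta:
    content: Optional[str]

@dataclass
class Choice:
    delta: Delta
    finish_reason: Optional[str] = None

@dataclass
class Chunk:
    choices: List[Choice]

def eq_chunk(a: Chunk, b: Chunk) -> bool:
    # Same rule as eq_chat_complete_chunk: only the first choice's delta content.
    return a.choices[0].delta.content == b.choices[0].delta.content

a = Chunk([Choice(Delta("Hello"))])
b = Chunk([Choice(Delta("Hello"), finish_reason="length")])
assert eq_chunk(a, b)  # equal despite differing finish_reason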
@@ -169,6 +183,14 @@ class ResponseComparator(JSONSnapshotExtension):
                [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
            )

        if isinstance(serialized_data[0], ChatCompletionChunk):
            return len(snapshot_data) == len(serialized_data) and all(
                [
                    eq_chat_complete_chunk(r, o)
                    for r, o in zip(serialized_data, snapshot_data)
                ]
            )

        return len(snapshot_data) == len(serialized_data) and all(
            [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
        )
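The branch added to matches() follows the same pattern as the existing ones: pick a comparison function from the type of the first serialized element, then require equal lengths plus elementwise equality. A rough sketch of that pattern with hypothetical stand-in data; the explicit length check matters because zip() silently stops at the shorter list.

from typing import Any, Callable, List

def compare_lists(a: List[Any], b: List[Any], eq: Callable[[Any, Any], bool]) -> bool:
    # Same shape as the returns above: equal length and pairwise equality.
    return len(a) == len(b) and all(eq(x, y) for x, y in zip(a, b))

same = lambda x, y: x == y
assert compare_lists([1, 2], [1, 2], same)
assert not compare_lists([1, 2], [1, 2, 3], same)  # zip alone would miss the extra element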
@@ -234,7 +234,6 @@ async def test_flash_llama_grammar_tools_stream(

    count = 0
    async for response in responses:
        print(response)
        count += 1

    assert count == 20
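The streaming test simply drains the async iterator returned by the client and counts the chunks. A self-contained sketch of that pattern, with a fake generator standing in for the real client call and the expected count of 20 taken from the assertion above:

import asyncio

async def fake_stream(n: int):
    # Stand-in for the client's streaming chat response.
    for i in range(n):
        yield {"choices": [{"delta": {"content": f"tok{i}"}}]}

async def main():
    count = 0
    async for chunk in fake_stream(20):
        count += 1
    assert count == 20

asyncio.run(main())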