mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix: update tests for streaming tools
This commit is contained in:
parent
0fc7237380
commit
b5cacca1dc
@ -24,6 +24,7 @@ from text_generation.types import (
|
|||||||
BestOfSequence,
|
BestOfSequence,
|
||||||
Grammar,
|
Grammar,
|
||||||
ChatComplete,
|
ChatComplete,
|
||||||
|
ChatCompletionChunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||||
@ -61,7 +62,15 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||||||
def convert_data(data):
|
def convert_data(data):
|
||||||
data = json.loads(data)
|
data = json.loads(data)
|
||||||
if isinstance(data, Dict) and "choices" in data:
|
if isinstance(data, Dict) and "choices" in data:
|
||||||
|
choices = data["choices"]
|
||||||
|
if (
|
||||||
|
isinstance(choices, List)
|
||||||
|
and len(choices) >= 1
|
||||||
|
and "delta" in choices[0]
|
||||||
|
):
|
||||||
|
return ChatCompletionChunk(**data)
|
||||||
return ChatComplete(**data)
|
return ChatComplete(**data)
|
||||||
|
|
||||||
if isinstance(data, Dict):
|
if isinstance(data, Dict):
|
||||||
return Response(**data)
|
return Response(**data)
|
||||||
if isinstance(data, List):
|
if isinstance(data, List):
|
||||||
@ -151,6 +160,11 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||||||
response.choices[0].message.content == other.choices[0].message.content
|
response.choices[0].message.content == other.choices[0].message.content
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def eq_chat_complete_chunk(
|
||||||
|
response: ChatCompletionChunk, other: ChatCompletionChunk
|
||||||
|
) -> bool:
|
||||||
|
return response.choices[0].delta.content == other.choices[0].delta.content
|
||||||
|
|
||||||
def eq_response(response: Response, other: Response) -> bool:
|
def eq_response(response: Response, other: Response) -> bool:
|
||||||
return response.generated_text == other.generated_text and eq_details(
|
return response.generated_text == other.generated_text and eq_details(
|
||||||
response.details, other.details
|
response.details, other.details
|
||||||
@ -169,6 +183,14 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||||||
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(serialized_data[0], ChatCompletionChunk):
|
||||||
|
return len(snapshot_data) == len(serialized_data) and all(
|
||||||
|
[
|
||||||
|
eq_chat_complete_chunk(r, o)
|
||||||
|
for r, o in zip(serialized_data, snapshot_data)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return len(snapshot_data) == len(serialized_data) and all(
|
return len(snapshot_data) == len(serialized_data) and all(
|
||||||
[eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
[eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
||||||
)
|
)
|
||||||
|
@ -234,7 +234,6 @@ async def test_flash_llama_grammar_tools_stream(
|
|||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
async for response in responses:
|
async for response in responses:
|
||||||
print(response)
|
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
assert count == 20
|
assert count == 20
|
||||||
|
Loading…
Reference in New Issue
Block a user