From 17a9bb962e3838819125b1d97e35028eef4b104e Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 7 Mar 2025 10:11:13 +0100
Subject: [PATCH] Update the old test.

---
 ...t_flash_llama_completion_stream_usage.json |  29 ++---
 .../models/test_completion_prompts.py         | 116 ++++++------------
 2 files changed, 47 insertions(+), 98 deletions(-)

diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json
index 8564e8ce..fbb3669f 100644
--- a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json
+++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json
@@ -12,7 +12,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -32,7 +32,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -52,7 +52,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -72,7 +72,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -92,7 +92,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -112,7 +112,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -132,7 +132,7 @@
         "logprobs": null
      }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -152,7 +152,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -172,7 +172,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -192,16 +192,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
-    "id": "",
-    "model": "meta-llama/Llama-3.1-8B-Instruct",
-    "object": "chat.completion.chunk",
-    "system_fingerprint": "3.1.2-dev0-native",
-    "usage": null
-  },
-  {
-    "choices": [],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py
index e39e16a1..27988ef9 100644
--- a/integration-tests/models/test_completion_prompts.py
+++ b/integration-tests/models/test_completion_prompts.py
@@ -2,8 +2,9 @@ import pytest
 import requests
 import json
 from aiohttp import ClientSession
+from huggingface_hub import InferenceClient
 
-from text_generation.types import Completion, ChatCompletionChunk
+from text_generation.types import Completion
 
 
 @pytest.fixture(scope="module")
@@ -52,54 +53,35 @@ def test_flash_llama_completion_single_prompt(
 async def test_flash_llama_completion_stream_usage(
     flash_llama_completion, response_snapshot
 ):
-    url = f"{flash_llama_completion.base_url}/v1/chat/completions"
-    request = {
-        "model": "tgi",
-        "messages": [
+    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
+    stream = client.chat_completion(
+        model="tgi",
+        messages=[
             {
                 "role": "user",
                 "content": "What is Deep Learning?",
             }
         ],
-        "max_tokens": 10,
-        "temperature": 0.0,
-        "stream_options": {"include_usage": True},
-        "stream": True,
-    }
+        max_tokens=10,
+        temperature=0.0,
+        stream_options={"include_usage": True},
+        stream=True,
+    )
 
     string = ""
     chunks = []
     had_usage = False
-    async with ClientSession(headers=flash_llama_completion.headers) as session:
-        async with session.post(url, json=request) as response:
-            # iterate over the stream
-            async for chunk in response.content.iter_any():
-                # remove "data:"
-                chunk = chunk.decode().split("\n\n")
-                # remove "data:" if present
-                chunk = [c.replace("data:", "") for c in chunk]
-                # remove empty strings
-                chunk = [c for c in chunk if c]
-                # remove completion marking chunk
-                chunk = [c for c in chunk if c != " [DONE]"]
-                # parse json
-                chunk = [json.loads(c) for c in chunk]
+    for chunk in stream:
+        # collect every chunk for the snapshot comparison below
+        chunks.append(chunk)
+        print(f"Chunk {chunk}")
+        if len(chunk.choices) == 1:
+            index = chunk.choices[0].index
+            assert index == 0
+            string += chunk.choices[0].delta.content
+        if chunk.usage:
+            assert not had_usage
+            had_usage = True
 
-                for c in chunk:
-                    chunks.append(ChatCompletionChunk(**c))
-                    assert "choices" in c
-                    if len(c["choices"]) == 1:
-                        index = c["choices"][0]["index"]
-                        assert index == 0
-                        string += c["choices"][0]["delta"]["content"]
-
-                        has_usage = c["usage"] is not None
-                        assert not had_usage
-                        if has_usage:
-                            had_usage = True
-                    elif c["usage"]:
-                        had_usage = True
-                    else:
-                        raise RuntimeError(f"Expected different payload: {c}")
     assert had_usage
     assert (
         string
@@ -107,53 +89,29 @@ async def test_flash_llama_completion_stream_usage(
     )
     assert chunks == response_snapshot
 
-    request = {
-        "model": "tgi",
-        "messages": [
+    stream = client.chat_completion(
+        model="tgi",
+        messages=[
             {
                 "role": "user",
                 "content": "What is Deep Learning?",
             }
         ],
-        "max_tokens": 10,
-        "temperature": 0.0,
-        "stream": True,
-    }
+        max_tokens=10,
+        temperature=0.0,
+        # No usage
+        # stream_options={"include_usage": True},
+        stream=True,
+    )
 
     string = ""
     chunks = []
     had_usage = False
-    async with ClientSession(headers=flash_llama_completion.headers) as session:
-        async with session.post(url, json=request) as response:
-            # iterate over the stream
-            async for chunk in response.content.iter_any():
-                # remove "data:"
-                chunk = chunk.decode().split("\n\n")
-                # remove "data:" if present
-                chunk = [c.replace("data:", "") for c in chunk]
-                # remove empty strings
-                chunk = [c for c in chunk if c]
-                # remove completion marking chunk
-                chunk = [c for c in chunk if c != " [DONE]"]
-                # parse json
-                chunk = [json.loads(c) for c in chunk]
-
-                for c in chunk:
-                    chunks.append(ChatCompletionChunk(**c))
-                    assert "choices" in c
-                    if len(c["choices"]) == 1:
-                        index = c["choices"][0]["index"]
-                        assert index == 0
-                        string += c["choices"][0]["delta"]["content"]
-
-                        has_usage = c["usage"] is not None
-                        assert not had_usage
-                        if has_usage:
-                            had_usage = True
-                    elif c["usage"]:
-                        had_usage = True
-                    else:
-                        raise RuntimeError("Expected different payload")
-    assert not had_usage
+    for chunk in stream:
+        chunks.append(chunk)
+        assert chunk.usage is None
+        assert len(chunk.choices) == 1
+        assert chunk.choices[0].index == 0
+        string += chunk.choices[0].delta.content
 
     assert (
         string
"**Deep Learning: An Overview**\n=====================================\n\n"