Update the old test.

Nicolas Patry 2025-03-07 10:11:13 +01:00
parent 062be12812
commit 17a9bb962e
2 changed files with 47 additions and 98 deletions

View File

@@ -12,7 +12,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -32,7 +32,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -52,7 +52,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -72,7 +72,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -92,7 +92,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -112,7 +112,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -132,7 +132,7 @@
         "logprobs": null
      }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -152,7 +152,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -172,7 +172,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -192,16 +192,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
-    "id": "",
-    "model": "meta-llama/Llama-3.1-8B-Instruct",
-    "object": "chat.completion.chunk",
-    "system_fingerprint": "3.1.2-dev0-native",
-    "usage": null
-  },
-  {
-    "choices": [],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",

View File

@@ -2,8 +2,9 @@ import pytest
 import requests
 import json
 from aiohttp import ClientSession
-from text_generation.types import Completion, ChatCompletionChunk
+from huggingface_hub import InferenceClient
+from text_generation.types import Completion


 @pytest.fixture(scope="module")
@@ -52,54 +53,35 @@ def test_flash_llama_completion_single_prompt(
 async def test_flash_llama_completion_stream_usage(
     flash_llama_completion, response_snapshot
 ):
-    url = f"{flash_llama_completion.base_url}/v1/chat/completions"
-    request = {
-        "model": "tgi",
-        "messages": [
+    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
+    stream = client.chat_completion(
+        model="tgi",
+        messages=[
             {
                 "role": "user",
                 "content": "What is Deep Learning?",
             }
         ],
-        "max_tokens": 10,
-        "temperature": 0.0,
-        "stream_options": {"include_usage": True},
-        "stream": True,
-    }
+        max_tokens=10,
+        temperature=0.0,
+        stream_options={"include_usage": True},
+        stream=True,
+    )
     string = ""
     chunks = []
     had_usage = False
-    async with ClientSession(headers=flash_llama_completion.headers) as session:
-        async with session.post(url, json=request) as response:
-            # iterate over the stream
-            async for chunk in response.content.iter_any():
-                # remove "data:"
-                chunk = chunk.decode().split("\n\n")
-                # remove "data:" if present
-                chunk = [c.replace("data:", "") for c in chunk]
-                # remove empty strings
-                chunk = [c for c in chunk if c]
-                # remove completion marking chunk
-                chunk = [c for c in chunk if c != " [DONE]"]
-                # parse json
-                chunk = [json.loads(c) for c in chunk]
-                for c in chunk:
-                    chunks.append(ChatCompletionChunk(**c))
-                    assert "choices" in c
-                    if len(c["choices"]) == 1:
-                        index = c["choices"][0]["index"]
-                        assert index == 0
-                        string += c["choices"][0]["delta"]["content"]
-                        has_usage = c["usage"] is not None
-                        assert not had_usage
-                        if has_usage:
-                            had_usage = True
-                    elif c["usage"]:
-                        had_usage = True
-                    else:
-                        raise RuntimeError(f"Expected different payload: {c}")
+    for chunk in stream:
+        chunks.append(chunk)
+        print(f"Chunk {chunk}")
+        if len(chunk.choices) == 1:
+            index = chunk.choices[0].index
+            assert index == 0
+            string += chunk.choices[0].delta.content
+        if chunk.usage:
+            assert not had_usage
+            had_usage = True
     assert had_usage
     assert (
         string
@@ -107,53 +89,29 @@ async def test_flash_llama_completion_stream_usage(
     )
     assert chunks == response_snapshot

-    request = {
-        "model": "tgi",
-        "messages": [
+    stream = client.chat_completion(
+        model="tgi",
+        messages=[
             {
                 "role": "user",
                 "content": "What is Deep Learning?",
             }
         ],
-        "max_tokens": 10,
-        "temperature": 0.0,
-        "stream": True,
-    }
+        max_tokens=10,
+        temperature=0.0,
+        # No usage
+        # stream_options={"include_usage": True},
+        stream=True,
+    )
     string = ""
     chunks = []
     had_usage = False
-    async with ClientSession(headers=flash_llama_completion.headers) as session:
-        async with session.post(url, json=request) as response:
-            # iterate over the stream
-            async for chunk in response.content.iter_any():
-                # remove "data:"
-                chunk = chunk.decode().split("\n\n")
-                # remove "data:" if present
-                chunk = [c.replace("data:", "") for c in chunk]
-                # remove empty strings
-                chunk = [c for c in chunk if c]
-                # remove completion marking chunk
-                chunk = [c for c in chunk if c != " [DONE]"]
-                # parse json
-                chunk = [json.loads(c) for c in chunk]
-                for c in chunk:
-                    chunks.append(ChatCompletionChunk(**c))
-                    assert "choices" in c
-                    if len(c["choices"]) == 1:
-                        index = c["choices"][0]["index"]
-                        assert index == 0
-                        string += c["choices"][0]["delta"]["content"]
-                        has_usage = c["usage"] is not None
-                        assert not had_usage
-                        if has_usage:
-                            had_usage = True
-                    elif c["usage"]:
-                        had_usage = True
-                    else:
-                        raise RuntimeError("Expected different payload")
-    assert not had_usage
+    for chunk in stream:
+        chunks.append(chunk)
+        assert chunk.usage is None
+        assert len(chunk.choices) == 1
+        assert chunk.choices[0].index == 0
+        string += chunk.choices[0].delta.content
     assert (
         string
         == "**Deep Learning: An Overview**\n=====================================\n\n"