Update the old test.

Nicolas Patry 2025-03-07 10:11:13 +01:00
parent 062be12812
commit 17a9bb962e
2 changed files with 47 additions and 98 deletions

[File 1 of 2: the recorded response snapshot (JSON)]

@@ -12,7 +12,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -32,7 +32,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -52,7 +52,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -72,7 +72,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -92,7 +92,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -112,7 +112,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -132,7 +132,7 @@
         "logprobs": null
      }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -152,7 +152,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -172,7 +172,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -192,16 +192,7 @@
         "logprobs": null
      }
     ],
-    "created": 1741274364,
-    "id": "",
-    "model": "meta-llama/Llama-3.1-8B-Instruct",
-    "object": "chat.completion.chunk",
-    "system_fingerprint": "3.1.2-dev0-native",
-    "usage": null
-  },
-  {
-    "choices": [],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",

[File 2 of 2: the integration test (Python)]

@@ -2,8 +2,9 @@ import pytest
 import requests
 import json
 from aiohttp import ClientSession
+from huggingface_hub import InferenceClient

-from text_generation.types import Completion, ChatCompletionChunk
+from text_generation.types import Completion


 @pytest.fixture(scope="module")
@@ -52,54 +53,35 @@ def test_flash_llama_completion_single_prompt(
 async def test_flash_llama_completion_stream_usage(
     flash_llama_completion, response_snapshot
 ):
-    url = f"{flash_llama_completion.base_url}/v1/chat/completions"
-    request = {
-        "model": "tgi",
-        "messages": [
+    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
+    stream = client.chat_completion(
+        model="tgi",
+        messages=[
             {
                 "role": "user",
                 "content": "What is Deep Learning?",
             }
         ],
-        "max_tokens": 10,
-        "temperature": 0.0,
-        "stream_options": {"include_usage": True},
-        "stream": True,
-    }
+        max_tokens=10,
+        temperature=0.0,
+        stream_options={"include_usage": True},
+        stream=True,
+    )

     string = ""
     chunks = []
     had_usage = False
-    async with ClientSession(headers=flash_llama_completion.headers) as session:
-        async with session.post(url, json=request) as response:
-            # iterate over the stream
-            async for chunk in response.content.iter_any():
-                # remove "data:"
-                chunk = chunk.decode().split("\n\n")
-                # remove "data:" if present
-                chunk = [c.replace("data:", "") for c in chunk]
-                # remove empty strings
-                chunk = [c for c in chunk if c]
-                # remove completion marking chunk
-                chunk = [c for c in chunk if c != " [DONE]"]
-                # parse json
-                chunk = [json.loads(c) for c in chunk]
+    for chunk in stream:
+        # remove "data:"
+        chunks.append(chunk)
+        print(f"Chunk {chunk}")
+        if len(chunk.choices) == 1:
+            index = chunk.choices[0].index
+            assert index == 0
+            string += chunk.choices[0].delta.content
+        if chunk.usage:
+            assert not had_usage
+            had_usage = True

-                for c in chunk:
-                    chunks.append(ChatCompletionChunk(**c))
-                    assert "choices" in c
-                    if len(c["choices"]) == 1:
-                        index = c["choices"][0]["index"]
-                        assert index == 0
-                        string += c["choices"][0]["delta"]["content"]
-                        has_usage = c["usage"] is not None
-                        assert not had_usage
-                        if has_usage:
-                            had_usage = True
-                    elif c["usage"]:
-                        had_usage = True
-                    else:
-                        raise RuntimeError(f"Expected different payload: {c}")
     assert had_usage
     assert (
         string
@@ -107,53 +89,29 @@ async def test_flash_llama_completion_stream_usage(
     )
     assert chunks == response_snapshot

-    request = {
-        "model": "tgi",
-        "messages": [
+    stream = client.chat_completion(
+        model="tgi",
+        messages=[
             {
                 "role": "user",
                 "content": "What is Deep Learning?",
             }
         ],
-        "max_tokens": 10,
-        "temperature": 0.0,
-        "stream": True,
-    }
+        max_tokens=10,
+        temperature=0.0,
+        # No usage
+        # stream_options={"include_usage": True},
+        stream=True,
+    )

     string = ""
     chunks = []
     had_usage = False
-    async with ClientSession(headers=flash_llama_completion.headers) as session:
-        async with session.post(url, json=request) as response:
-            # iterate over the stream
-            async for chunk in response.content.iter_any():
-                # remove "data:"
-                chunk = chunk.decode().split("\n\n")
-                # remove "data:" if present
-                chunk = [c.replace("data:", "") for c in chunk]
-                # remove empty strings
-                chunk = [c for c in chunk if c]
-                # remove completion marking chunk
-                chunk = [c for c in chunk if c != " [DONE]"]
-                # parse json
-                chunk = [json.loads(c) for c in chunk]
-
-                for c in chunk:
-                    chunks.append(ChatCompletionChunk(**c))
-                    assert "choices" in c
-                    if len(c["choices"]) == 1:
-                        index = c["choices"][0]["index"]
-                        assert index == 0
-                        string += c["choices"][0]["delta"]["content"]
-                        has_usage = c["usage"] is not None
-                        assert not had_usage
-                        if has_usage:
-                            had_usage = True
-                    elif c["usage"]:
-                        had_usage = True
-                    else:
-                        raise RuntimeError("Expected different payload")
-    assert not had_usage
+    for chunk in stream:
+        chunks.append(chunk)
+        assert chunk.usage is None
+        assert len(chunk.choices) == 1
+        assert chunk.choices[0].index == 0
+        string += chunk.choices[0].delta.content
     assert (
         string
         == "**Deep Learning: An Overview**\n=====================================\n\n"