Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-03 16:02:13 +00:00)

commit 17a9bb962e
parent 062be12812

    Update the old test.

@@ -12,7 +12,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -32,7 +32,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -52,7 +52,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -72,7 +72,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -92,7 +92,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -112,7 +112,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -132,7 +132,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -152,7 +152,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -172,7 +172,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -192,16 +192,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
-    "id": "",
-    "model": "meta-llama/Llama-3.1-8B-Instruct",
-    "object": "chat.completion.chunk",
-    "system_fingerprint": "3.1.2-dev0-native",
-    "usage": null
-  },
-  {
-    "choices": [],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
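
The JSON hunks above update the recorded response snapshot that the streaming test compares against: the "created" timestamps move from 1741274364 to 1741338471/1741338472, and the recorded stream contains one fewer trailing chunk. As a rough, non-authoritative sketch of what one streamed chat.completion.chunk entry in that snapshot looks like (the "delta" payload below is an illustrative placeholder; only created/id/model/object/logprobs and, on the final chunk, system_fingerprint/usage/choices are actually visible in the diff):

    import json

    # Illustrative only: one streamed chunk shaped like the entries in the
    # snapshot above. The "delta" contents are placeholders.
    raw = """
    {
      "choices": [
        {
          "delta": {"role": "assistant", "content": " Deep"},
          "finish_reason": null,
          "index": 0,
          "logprobs": null
        }
      ],
      "created": 1741338471,
      "id": "",
      "model": "meta-llama/Llama-3.1-8B-Instruct",
      "object": "chat.completion.chunk",
      "system_fingerprint": "3.1.2-dev0-native",
      "usage": null
    }
    """

    chunk = json.loads(raw)
    # Fields the snapshot diff shows on every chunk.
    assert chunk["object"] == "chat.completion.chunk"
    assert chunk["model"] == "meta-llama/Llama-3.1-8B-Instruct"
    assert chunk["choices"][0]["index"] == 0
    assert chunk["choices"][0]["logprobs"] is None
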
@@ -2,8 +2,9 @@ import pytest
 import requests
 import json
 from aiohttp import ClientSession
+from huggingface_hub import InferenceClient

-from text_generation.types import Completion, ChatCompletionChunk
+from text_generation.types import Completion


 @pytest.fixture(scope="module")
@@ -52,54 +53,35 @@ def test_flash_llama_completion_single_prompt(
 async def test_flash_llama_completion_stream_usage(
     flash_llama_completion, response_snapshot
 ):
-    url = f"{flash_llama_completion.base_url}/v1/chat/completions"
-    request = {
-        "model": "tgi",
-        "messages": [
+    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
+    stream = client.chat_completion(
+        model="tgi",
+        messages=[
             {
                 "role": "user",
                 "content": "What is Deep Learning?",
             }
         ],
-        "max_tokens": 10,
-        "temperature": 0.0,
-        "stream_options": {"include_usage": True},
-        "stream": True,
-    }
+        max_tokens=10,
+        temperature=0.0,
+        stream_options={"include_usage": True},
+        stream=True,
+    )
     string = ""
     chunks = []
     had_usage = False
-    async with ClientSession(headers=flash_llama_completion.headers) as session:
-        async with session.post(url, json=request) as response:
-            # iterate over the stream
-            async for chunk in response.content.iter_any():
-                # remove "data:"
-                chunk = chunk.decode().split("\n\n")
-                # remove "data:" if present
-                chunk = [c.replace("data:", "") for c in chunk]
-                # remove empty strings
-                chunk = [c for c in chunk if c]
-                # remove completion marking chunk
-                chunk = [c for c in chunk if c != " [DONE]"]
-                # parse json
-                chunk = [json.loads(c) for c in chunk]
-
-                for c in chunk:
-                    chunks.append(ChatCompletionChunk(**c))
-                    assert "choices" in c
-                    if len(c["choices"]) == 1:
-                        index = c["choices"][0]["index"]
-                        assert index == 0
-                        string += c["choices"][0]["delta"]["content"]
-
-                        has_usage = c["usage"] is not None
-                        assert not had_usage
-                        if has_usage:
-                            had_usage = True
-                    elif c["usage"]:
-                        had_usage = True
-                    else:
-                        raise RuntimeError(f"Expected different payload: {c}")
+    for chunk in stream:
+        # remove "data:"
+        chunks.append(chunk)
+        print(f"Chunk {chunk}")
+        if len(chunk.choices) == 1:
+            index = chunk.choices[0].index
+            assert index == 0
+            string += chunk.choices[0].delta.content
+        if chunk.usage:
+            assert not had_usage
+            had_usage = True
+
     assert had_usage
     assert (
         string
@@ -107,53 +89,29 @@ async def test_flash_llama_completion_stream_usage(
     )
     assert chunks == response_snapshot

-    request = {
-        "model": "tgi",
-        "messages": [
+    stream = client.chat_completion(
+        model="tgi",
+        messages=[
             {
                 "role": "user",
                 "content": "What is Deep Learning?",
             }
         ],
-        "max_tokens": 10,
-        "temperature": 0.0,
-        "stream": True,
-    }
+        max_tokens=10,
+        temperature=0.0,
+        # No usage
+        # stream_options={"include_usage": True},
+        stream=True,
+    )
     string = ""
     chunks = []
     had_usage = False
-    async with ClientSession(headers=flash_llama_completion.headers) as session:
-        async with session.post(url, json=request) as response:
-            # iterate over the stream
-            async for chunk in response.content.iter_any():
-                # remove "data:"
-                chunk = chunk.decode().split("\n\n")
-                # remove "data:" if present
-                chunk = [c.replace("data:", "") for c in chunk]
-                # remove empty strings
-                chunk = [c for c in chunk if c]
-                # remove completion marking chunk
-                chunk = [c for c in chunk if c != " [DONE]"]
-                # parse json
-                chunk = [json.loads(c) for c in chunk]
-
-                for c in chunk:
-                    chunks.append(ChatCompletionChunk(**c))
-                    assert "choices" in c
-                    if len(c["choices"]) == 1:
-                        index = c["choices"][0]["index"]
-                        assert index == 0
-                        string += c["choices"][0]["delta"]["content"]
-
-                        has_usage = c["usage"] is not None
-                        assert not had_usage
-                        if has_usage:
-                            had_usage = True
-                    elif c["usage"]:
-                        had_usage = True
-                    else:
-                        raise RuntimeError("Expected different payload")
-    assert not had_usage
+    for chunk in stream:
+        chunks.append(chunk)
+        assert chunk.usage is None
+        assert len(chunk.choices) == 1
+        assert chunk.choices[0].index == 0
+        string += chunk.choices[0].delta.content
     assert (
         string
         == "**Deep Learning: An Overview**\n=====================================\n\n"
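
After this change the test drives the streaming endpoint through huggingface_hub's InferenceClient instead of hand-parsing the SSE stream with aiohttp. Outside the pytest fixtures, the same pattern looks roughly like the sketch below; the base URL is a placeholder for wherever the TGI server is listening, not something taken from the commit:

    from huggingface_hub import InferenceClient

    # Placeholder endpoint; the test builds this from the flash_llama_completion fixture.
    client = InferenceClient(base_url="http://localhost:8080/v1")

    stream = client.chat_completion(
        model="tgi",
        messages=[{"role": "user", "content": "What is Deep Learning?"}],
        max_tokens=10,
        temperature=0.0,
        stream_options={"include_usage": True},
        stream=True,
    )

    text = ""
    saw_usage = False
    for chunk in stream:
        # Content chunks carry a single choice; when include_usage is requested,
        # the final chunk reports usage and has an empty choices list.
        if len(chunk.choices) == 1:
            text += chunk.choices[0].delta.content
        if chunk.usage:
            saw_usage = True

    print(text)
    assert saw_usage

Dropping stream_options (as the second half of the test does) is expected to yield chunks whose usage field stays None throughout, which is exactly what the rewritten assertions check.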