mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-03 16:12:12 +00:00)
Update the old test.
This commit is contained in:
parent 062be12812
commit 17a9bb962e
@@ -12,7 +12,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -32,7 +32,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -52,7 +52,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -72,7 +72,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338471,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -92,7 +92,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -112,7 +112,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -132,7 +132,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -152,7 +152,7 @@
         "logprobs": null
      }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -172,7 +172,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -192,16 +192,7 @@
         "logprobs": null
       }
     ],
-    "created": 1741274364,
-    "id": "",
-    "model": "meta-llama/Llama-3.1-8B-Instruct",
-    "object": "chat.completion.chunk",
-    "system_fingerprint": "3.1.2-dev0-native",
-    "usage": null
-  },
-  {
-    "choices": [],
-    "created": 1741274364,
+    "created": 1741338472,
     "id": "",
     "model": "meta-llama/Llama-3.1-8B-Instruct",
     "object": "chat.completion.chunk",
@@ -2,8 +2,9 @@ import pytest
 import requests
 import json
 from aiohttp import ClientSession
+from huggingface_hub import InferenceClient

-from text_generation.types import Completion, ChatCompletionChunk
+from text_generation.types import Completion


 @pytest.fixture(scope="module")
@@ -52,54 +53,35 @@ def test_flash_llama_completion_single_prompt(
 async def test_flash_llama_completion_stream_usage(
     flash_llama_completion, response_snapshot
 ):
-    url = f"{flash_llama_completion.base_url}/v1/chat/completions"
-    request = {
-        "model": "tgi",
-        "messages": [
+    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
+    stream = client.chat_completion(
+        model="tgi",
+        messages=[
             {
                 "role": "user",
                 "content": "What is Deep Learning?",
             }
         ],
-        "max_tokens": 10,
-        "temperature": 0.0,
-        "stream_options": {"include_usage": True},
-        "stream": True,
-    }
+        max_tokens=10,
+        temperature=0.0,
+        stream_options={"include_usage": True},
+        stream=True,
+    )
     string = ""
     chunks = []
     had_usage = False
-    async with ClientSession(headers=flash_llama_completion.headers) as session:
-        async with session.post(url, json=request) as response:
-            # iterate over the stream
-            async for chunk in response.content.iter_any():
-                # remove "data:"
-                chunk = chunk.decode().split("\n\n")
-                # remove "data:" if present
-                chunk = [c.replace("data:", "") for c in chunk]
-                # remove empty strings
-                chunk = [c for c in chunk if c]
-                # remove completion marking chunk
-                chunk = [c for c in chunk if c != " [DONE]"]
-                # parse json
-                chunk = [json.loads(c) for c in chunk]
+    for chunk in stream:
+        # remove "data:"
+        chunks.append(chunk)
+        print(f"Chunk {chunk}")
+        if len(chunk.choices) == 1:
+            index = chunk.choices[0].index
+            assert index == 0
+            string += chunk.choices[0].delta.content
+        if chunk.usage:
+            assert not had_usage
+            had_usage = True

-                for c in chunk:
-                    chunks.append(ChatCompletionChunk(**c))
-                    assert "choices" in c
-                    if len(c["choices"]) == 1:
-                        index = c["choices"][0]["index"]
-                        assert index == 0
-                        string += c["choices"][0]["delta"]["content"]
-
-                        has_usage = c["usage"] is not None
-                        assert not had_usage
-                        if has_usage:
-                            had_usage = True
-                    elif c["usage"]:
-                        had_usage = True
-                    else:
-                        raise RuntimeError(f"Expected different payload: {c}")
     assert had_usage
     assert (
         string
@@ -107,53 +89,29 @@ async def test_flash_llama_completion_stream_usage(
     )
     assert chunks == response_snapshot

-    request = {
-        "model": "tgi",
-        "messages": [
+    stream = client.chat_completion(
+        model="tgi",
+        messages=[
             {
                 "role": "user",
                 "content": "What is Deep Learning?",
             }
         ],
-        "max_tokens": 10,
-        "temperature": 0.0,
-        "stream": True,
-    }
+        max_tokens=10,
+        temperature=0.0,
+        # No usage
+        # stream_options={"include_usage": True},
+        stream=True,
+    )
     string = ""
     chunks = []
     had_usage = False
-    async with ClientSession(headers=flash_llama_completion.headers) as session:
-        async with session.post(url, json=request) as response:
-            # iterate over the stream
-            async for chunk in response.content.iter_any():
-                # remove "data:"
-                chunk = chunk.decode().split("\n\n")
-                # remove "data:" if present
-                chunk = [c.replace("data:", "") for c in chunk]
-                # remove empty strings
-                chunk = [c for c in chunk if c]
-                # remove completion marking chunk
-                chunk = [c for c in chunk if c != " [DONE]"]
-                # parse json
-                chunk = [json.loads(c) for c in chunk]
-
-                for c in chunk:
-                    chunks.append(ChatCompletionChunk(**c))
-                    assert "choices" in c
-                    if len(c["choices"]) == 1:
-                        index = c["choices"][0]["index"]
-                        assert index == 0
-                        string += c["choices"][0]["delta"]["content"]
-
-                        has_usage = c["usage"] is not None
-                        assert not had_usage
-                        if has_usage:
-                            had_usage = True
-                    elif c["usage"]:
-                        had_usage = True
-                    else:
-                        raise RuntimeError("Expected different payload")
-    assert not had_usage
+    for chunk in stream:
+        chunks.append(chunk)
+        assert chunk.usage is None
+        assert len(chunk.choices) == 1
+        assert chunk.choices[0].index == 0
+        string += chunk.choices[0].delta.content
     assert (
         string
         == "**Deep Learning: An Overview**\n=====================================\n\n"
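For reference, the updated test consumes the stream through huggingface_hub's InferenceClient instead of hand-parsing SSE bytes with aiohttp. Below is a minimal, self-contained sketch of that same pattern; the base URL and prompt are illustrative assumptions (any reachable TGI endpoint works), not part of this commit:

from huggingface_hub import InferenceClient

# Assumed endpoint of a locally running TGI server; adjust as needed.
client = InferenceClient(base_url="http://localhost:8080/v1")

stream = client.chat_completion(
    model="tgi",
    messages=[{"role": "user", "content": "What is Deep Learning?"}],
    max_tokens=10,
    temperature=0.0,
    stream_options={"include_usage": True},  # request a trailing usage chunk
    stream=True,
)

text = ""
for chunk in stream:
    # Content deltas arrive in chunks carrying exactly one choice.
    if len(chunk.choices) == 1 and chunk.choices[0].delta.content:
        text += chunk.choices[0].delta.content
    # With include_usage, the final chunk has empty choices and token counts.
    if chunk.usage:
        print("usage:", chunk.usage)
print(text)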