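"""Integration tests for text-generation-inference's OpenAI-compatible
`/v1/completions` and `/v1/chat/completions` endpoints, exercised through raw
`requests`, the `openai` client, and `huggingface_hub.InferenceClient`."""
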
import pytest
import requests
from openai import OpenAI
from huggingface_hub import InferenceClient


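# Launch a TGI server once for the whole module, serving Llama 3.1 8B Instruct.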
@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ) as handle:
        yield handle


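# Block until the launched server reports healthy, then expose its client.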
@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
    await flash_llama_completion_handle.health(300)
    return flash_llama_completion_handle.client


# NOTE: since `v1/completions` is a deprecated interface/endpoint we do not provide a convenience
# method for it. Instead, we use the `requests` library to make the HTTP request directly.


@pytest.mark.release
def test_flash_llama_completion_single_prompt(
    flash_llama_completion, response_snapshot
):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": "What is Deep Learning?",
            "max_tokens": 10,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 1
    assert (
        response["choices"][0]["text"]
        == " A Beginner’s Guide\nDeep learning is a subset"
    )
    assert response == response_snapshot


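# With stream_options={"include_usage": True}, exactly one chunk should carry
# usage; the same request without stream_options should report none.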
@pytest.mark.release
async def test_flash_llama_completion_stream_usage(
    flash_llama_completion, response_snapshot
):
    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat_completion(
        model="tgi",
        messages=[
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        max_tokens=10,
        temperature=0.0,
        stream_options={"include_usage": True},
        stream=True,
    )
    string = ""
    chunks = []
    had_usage = False
    for chunk in stream:
        chunks.append(chunk)
        # Content chunks carry exactly one choice; the final usage chunk carries none.
        if len(chunk.choices) == 1:
            index = chunk.choices[0].index
            assert index == 0
            string += chunk.choices[0].delta.content
        if chunk.usage:
            assert not had_usage
            had_usage = True

    assert had_usage
    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )
    assert chunks == response_snapshot

    # No stream_options this time: usage should not be reported on any chunk.
    stream = client.chat_completion(
        model="tgi",
        messages=[
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        max_tokens=10,
        temperature=0.0,
        stream=True,
    )
    string = ""
    chunks = []
    for chunk in stream:
        chunks.append(chunk)
        assert chunk.usage is None
        assert len(chunk.choices) == 1
        assert chunk.choices[0].index == 0
        string += chunk.choices[0].delta.content
    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )


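# `/v1/completions` also accepts a batch of prompts, returning one choice per
# prompt tagged with its index.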
@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": [
                "What is Deep Learning?",
                "Is water wet?",
                "What is the capital of France?",
                "def mai",
            ],
            "max_tokens": 10,
            "seed": 0,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 4

    all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
    all_indexes.sort()
    all_indices, all_strings = zip(*all_indexes)
    assert list(all_indices) == [0, 1, 2, 3]
    assert list(all_strings) == [
        " A Beginner’s Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n    """\n',
    ]

    assert response == response_snapshot


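# Stream the same batch through the `openai` client and reassemble each
# completion from its choice index.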
@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream(
    flash_llama_completion, response_snapshot
):
    client = OpenAI(api_key="xx", base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.completions.create(
        model="tgi",
        prompt=[
            "What is Deep Learning?",
            "Is water wet?",
            "What is the capital of France?",
            "def mai",
        ],
        max_tokens=10,
        seed=0,
        temperature=0.0,
        stream=True,
    )

    strings = [""] * 4
    chunks = []
    for chunk in stream:
        chunks.append(chunk)
        index = chunk.choices[0].index
        # Four prompts were sent, so valid choice indices are 0..3.
        assert 0 <= index < 4
        strings[index] += chunk.choices[0].text

    assert list(strings) == [
        " A Beginner’s Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n    """\n',
    ]
    assert chunks == response_snapshot


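# Chat endpoint via the `openai` client: usage is expected only on the final chunk.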
@pytest.mark.release
async def test_chat_openai_usage(flash_llama_completion, response_snapshot):
    client = OpenAI(api_key="xx", base_url=f"{flash_llama_completion.base_url}/v1")

    stream = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": True},
    )

    chunks = []
    for chunk in stream:
        chunks.append(chunk)
    # Only the last chunk should carry usage information.
    for chunk in chunks[:-1]:
        assert chunk.usage is None
    for chunk in chunks[-1:]:
        assert chunk.usage is not None

    assert chunks == response_snapshot


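# With include_usage disabled, no chunk should carry usage.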
@pytest.mark.release
async def test_chat_openai_nousage(flash_llama_completion, response_snapshot):
    client = OpenAI(api_key="xx", base_url=f"{flash_llama_completion.base_url}/v1")

    stream = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": False},
    )

    chunks = []
    for chunk in stream:
        assert chunk.usage is None
        chunks.append(chunk)

    assert chunks == response_snapshot


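# The same usage checks, driven through `huggingface_hub.InferenceClient`.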
@pytest.mark.release
async def test_chat_hfhub_usage(flash_llama_completion, response_snapshot):
    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat_completion(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": True},
    )

    chunks = []
    for chunk in stream:
        chunks.append(chunk)

    for chunk in chunks[:-1]:
        assert chunk.usage is None
    for chunk in chunks[-1:]:
        assert chunk.usage is not None

    assert chunks == response_snapshot


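# And with include_usage disabled, usage stays None on every chunk.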
@pytest.mark.release
async def test_chat_hfhub_nousage(flash_llama_completion, response_snapshot):
    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat_completion(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": False},
    )

    chunks = []
    for chunk in stream:
        assert chunk.usage is None
        chunks.append(chunk)

    assert chunks == response_snapshot