text-generation-inference/integration-tests/models/test_completion_prompts.py
drbh dc5f05f8e6
Pr 3003 ci branch (#3007)
* change ChatCompletionChunk to align with "OpenAI Chat Completions streaming API"

Moving after tool_calls2

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

add in Buffering..

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

fix: handle usage outside of stream state and add tests

Simplifying everything quite a bit.

Remove the unused model_dump.

Clippy.

Clippy ?

Ruff.

Upgrade the flake for latest transformers.

Upgrade after rebase.

Remove potential footgun.

Fix completion test.

* Clippy.

* Tweak for multi prompt.

* Ruff.

* Update the snapshot a bit.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2025-03-10 17:56:19 +01:00
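The behavior this change pins down, and which the tests in this file exercise: when a streaming request sets `stream_options={"include_usage": True}`, token usage arrives as one extra final chunk whose `choices` list is empty, matching the OpenAI Chat Completions streaming API; without that option, `usage` stays `None` on every chunk. Below is a minimal sketch of consuming that contract outside the test harness; the base URL and the running TGI server are assumptions for illustration, not part of this file.

    from openai import OpenAI

    # Hypothetical endpoint; the tests below obtain the real base_url from the launcher fixture.
    client = OpenAI(api_key="xx", base_url="http://localhost:8080/v1")

    stream = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": "What is Deep Learning?"}],
        max_tokens=10,
        stream=True,
        stream_options={"include_usage": True},
    )

    for chunk in stream:
        if chunk.usage is not None:
            # Final chunk: usage statistics only, no choices.
            assert not chunk.choices
            print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
        elif chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")

Keeping usage out of the per-token chunks means clients can accumulate deltas without special-casing, and only the final, choice-less chunk needs to be inspected for accounting.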


import pytest
import requests
from openai import OpenAI
from huggingface_hub import InferenceClient


@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
    await flash_llama_completion_handle.health(300)
    return flash_llama_completion_handle.client


# NOTE: since `v1/completions` is a deprecated interface/endpoint we do not provide a convenience
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
@pytest.mark.release
def test_flash_llama_completion_single_prompt(
    flash_llama_completion, response_snapshot
):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": "What is Deep Learning?",
            "max_tokens": 10,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()

    assert len(response["choices"]) == 1
    assert (
        response["choices"][0]["text"]
        == " A Beginners Guide\nDeep learning is a subset"
    )
    assert response == response_snapshot


@pytest.mark.release
async def test_flash_llama_completion_stream_usage(
    flash_llama_completion, response_snapshot
):
    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat_completion(
        model="tgi",
        messages=[
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        max_tokens=10,
        temperature=0.0,
        stream_options={"include_usage": True},
        stream=True,
    )
    string = ""
    chunks = []
    had_usage = False
    for chunk in stream:
        chunks.append(chunk)
        # content chunks carry exactly one choice; the final usage-only chunk carries none
        if len(chunk.choices) == 1:
            index = chunk.choices[0].index
            assert index == 0
            string += chunk.choices[0].delta.content
        if chunk.usage:
            # usage must be reported exactly once, on the last chunk
            assert not had_usage
            had_usage = True
    assert had_usage

    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )
    assert chunks == response_snapshot

    stream = client.chat_completion(
        model="tgi",
        messages=[
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        max_tokens=10,
        temperature=0.0,
        # No usage
        # stream_options={"include_usage": True},
        stream=True,
    )
    string = ""
    chunks = []
    had_usage = False
    for chunk in stream:
        chunks.append(chunk)
        assert chunk.usage is None
        assert len(chunk.choices) == 1
        assert chunk.choices[0].index == 0
        string += chunk.choices[0].delta.content

    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )


@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": [
                "What is Deep Learning?",
                "Is water wet?",
                "What is the capital of France?",
                "def mai",
            ],
            "max_tokens": 10,
            "seed": 0,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()

    assert len(response["choices"]) == 4

    all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
    all_indexes.sort()
    all_indices, all_strings = zip(*all_indexes)

    assert list(all_indices) == [0, 1, 2, 3]
    assert list(all_strings) == [
        " A Beginners Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n """\n',
    ]
    assert response == response_snapshot


@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream(
    flash_llama_completion, response_snapshot
):
    client = OpenAI(api_key="xx", base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.completions.create(
        model="tgi",
        prompt=[
            "What is Deep Learning?",
            "Is water wet?",
            "What is the capital of France?",
            "def mai",
        ],
        max_tokens=10,
        seed=0,
        temperature=0.0,
        stream=True,
    )

    strings = [""] * 4
    chunks = []
    for chunk in stream:
        chunks.append(chunk)
        index = chunk.choices[0].index
        assert 0 <= index <= 4
        strings[index] += chunk.choices[0].text

    assert list(strings) == [
        " A Beginners Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n """\n',
    ]
    assert chunks == response_snapshot


@pytest.mark.release
async def test_chat_openai_usage(flash_llama_completion, response_snapshot):
    client = OpenAI(api_key="xx", base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": True},
    )
    chunks = []
    for chunk in stream:
        chunks.append(chunk)

    for chunk in chunks[:-1]:
        assert chunk.usage is None
    for chunk in chunks[-1:]:
        assert chunk.usage is not None
    assert chunks == response_snapshot


@pytest.mark.release
async def test_chat_openai_nousage(flash_llama_completion, response_snapshot):
    client = OpenAI(api_key="xx", base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": False},
    )
    chunks = []
    for chunk in stream:
        assert chunk.usage is None
        chunks.append(chunk)
    assert chunks == response_snapshot


@pytest.mark.release
async def test_chat_hfhub_usage(flash_llama_completion, response_snapshot):
    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat_completion(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": True},
    )
    chunks = []
    for chunk in stream:
        chunks.append(chunk)

    for chunk in chunks[:-1]:
        assert chunk.usage is None
    for chunk in chunks[-1:]:
        assert chunk.usage is not None
    assert chunks == response_snapshot


@pytest.mark.release
async def test_chat_hfhub_nousage(flash_llama_completion, response_snapshot):
    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat_completion(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": False},
    )
    chunks = []
    for chunk in stream:
        assert chunk.usage is None
        chunks.append(chunk)
    assert chunks == response_snapshot