# text-generation-inference/integration-tests/models/test_completion_prompts.py


import pytest
import requests
from openai import OpenAI
from huggingface_hub import InferenceClient


@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
    await flash_llama_completion_handle.health(300)
    return flash_llama_completion_handle.client


# NOTE: since `v1/completions` is a deprecated interface/endpoint we do not provide a convenience
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
@pytest.mark.release
def test_flash_llama_completion_single_prompt(
    flash_llama_completion, response_snapshot
):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": "What is Deep Learning?",
            "max_tokens": 10,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 1
    assert (
        response["choices"][0]["text"]
        == " A Beginners Guide\nDeep learning is a subset"
    )
    assert response == response_snapshot
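

# With `include_usage` enabled, the stream must deliver exactly one usage chunk;
# the same request without `stream_options` must deliver none.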
@pytest.mark.release
async def test_flash_llama_completion_stream_usage(
    flash_llama_completion, response_snapshot
):
    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat_completion(
        model="tgi",
        messages=[
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        max_tokens=10,
        temperature=0.0,
        stream_options={"include_usage": True},
        stream=True,
    )
    string = ""
    chunks = []
    had_usage = False
    for chunk in stream:
        chunks.append(chunk)
        # The final usage chunk carries an empty `choices` list, so only
        # accumulate text from chunks that actually contain a choice.
        if len(chunk.choices) == 1:
            index = chunk.choices[0].index
            assert index == 0
            string += chunk.choices[0].delta.content
        if chunk.usage:
            assert not had_usage
            had_usage = True
    assert had_usage
    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )
    assert chunks == response_snapshot

    stream = client.chat_completion(
        model="tgi",
        messages=[
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        max_tokens=10,
        temperature=0.0,
        # No usage
        # stream_options={"include_usage": True},
        stream=True,
    )
    string = ""
    chunks = []
    for chunk in stream:
        chunks.append(chunk)
        assert chunk.usage is None
        assert len(chunk.choices) == 1
        assert chunk.choices[0].index == 0
        string += chunk.choices[0].delta.content
    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )
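

# A batched `v1/completions` request should return one choice per prompt,
# indexed in prompt order.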
@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": [
                "What is Deep Learning?",
                "Is water wet?",
                "What is the capital of France?",
                "def mai",
            ],
            "max_tokens": 10,
            "seed": 0,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 4

    all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
    all_indexes.sort()
    all_indices, all_strings = zip(*all_indexes)

    assert list(all_indices) == [0, 1, 2, 3]
    assert list(all_strings) == [
        " A Beginners Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n    """\n',
    ]
    assert response == response_snapshot
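

# The same batch of prompts, streamed through the OpenAI client: text deltas
# arrive interleaved across choices and must be reassembled per index.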
@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream(
    flash_llama_completion, response_snapshot
):
    client = OpenAI(api_key="xx", base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.completions.create(
        model="tgi",
        prompt=[
            "What is Deep Learning?",
            "Is water wet?",
            "What is the capital of France?",
            "def mai",
        ],
        max_tokens=10,
        seed=0,
        temperature=0.0,
        stream=True,
    )

    strings = [""] * 4
    chunks = []
    for chunk in stream:
        chunks.append(chunk)
        index = chunk.choices[0].index
        # Four prompts were sent, so valid choice indices are 0..3.
        assert 0 <= index < 4
        strings[index] += chunk.choices[0].text

    assert list(strings) == [
        " A Beginners Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n    """\n',
    ]
    assert chunks == response_snapshot
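

# With the OpenAI client and `include_usage`, every chunk except the last
# reports no usage, and the final chunk carries the usage totals.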
@pytest.mark.release
async def test_chat_openai_usage(flash_llama_completion, response_snapshot):
    client = OpenAI(api_key="xx", base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": True},
    )
    chunks = []
    for chunk in stream:
        chunks.append(chunk)

    for chunk in chunks[:-1]:
        assert chunk.usage is None
    for chunk in chunks[-1:]:
        assert chunk.usage is not None
    assert chunks == response_snapshot
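

# With `include_usage` explicitly disabled, no chunk should carry usage.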
@pytest.mark.release
async def test_chat_openai_nousage(flash_llama_completion, response_snapshot):
    client = OpenAI(api_key="xx", base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": False},
    )
    chunks = []
    for chunk in stream:
        assert chunk.usage is None
        chunks.append(chunk)
    assert chunks == response_snapshot
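

# The same usage contract as above, exercised through the huggingface_hub
# InferenceClient instead of the OpenAI client.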
@pytest.mark.release
async def test_chat_hfhub_usage(flash_llama_completion, response_snapshot):
    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat_completion(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": True},
    )
    chunks = []
    for chunk in stream:
        chunks.append(chunk)

    for chunk in chunks[:-1]:
        assert chunk.usage is None
    for chunk in chunks[-1:]:
        assert chunk.usage is not None
    assert chunks == response_snapshot
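

# And the huggingface_hub client with usage reporting disabled.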
@pytest.mark.release
async def test_chat_hfhub_nousage(flash_llama_completion, response_snapshot):
    client = InferenceClient(base_url=f"{flash_llama_completion.base_url}/v1")
    stream = client.chat_completion(
        model="tgi",
        messages=[{"role": "user", "content": "Say 'OK!'"}],
        stream=True,
        max_tokens=10,
        seed=42,
        stream_options={"include_usage": False},
    )
    chunks = []
    for chunk in stream:
        assert chunk.usage is None
        chunks.append(chunk)
    assert chunks == response_snapshot