mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
315 lines
11 KiB
Python
315 lines
11 KiB
Python
import pytest
|
|
from huggingface_hub import InferenceClient
|
|
|
|
# to be removed when the InferenceClient client supports latest parameters
|
|
import requests
|
|
|
|
@pytest.fixture(scope="module")
|
|
def flash_llama_grammar_tools_handle(launcher):
|
|
with launcher(
|
|
"meta-llama/Meta-Llama-3.1-8B-Instruct",
|
|
num_shard=2,
|
|
disable_grammar_support=False,
|
|
) as handle:
|
|
yield handle
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
|
|
await flash_llama_grammar_tools_handle.health(300)
|
|
return flash_llama_grammar_tools_handle.client
|
|
|
|
|
|
# All tests are based on the following model card
|
|
# https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.private
|
|
async def test_basic_gen(flash_llama_grammar_tools, response_snapshot):
|
|
client = InferenceClient(
|
|
base_url=flash_llama_grammar_tools.base_url + "/v1",
|
|
)
|
|
|
|
output = client.chat.completions.create(
|
|
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
|
messages=[
|
|
{
|
|
"role": "system",
|
|
"content": "You are a helpful assistant",
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": "What is the capital of France?",
|
|
},
|
|
],
|
|
stream=True,
|
|
seed=42,
|
|
max_tokens=20,
|
|
)
|
|
|
|
final_response = []
|
|
for chunk in output:
|
|
final_response.append(chunk.choices[0].delta.content)
|
|
resp = ''.join(final_response)
|
|
|
|
assert resp == "The capital of France is Paris."
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.private
|
|
async def test_code_interpreter_gen(flash_llama_grammar_tools, response_snapshot):
|
|
client = InferenceClient(
|
|
base_url=flash_llama_grammar_tools.base_url + "/v1",
|
|
)
|
|
output = client.chat.completions.create(
|
|
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
|
messages=[
|
|
{
|
|
"role": "system",
|
|
"content": "Environment: ipython",
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": "Write code to check if number is prime, use that to see if the number 7 is prime",
|
|
},
|
|
],
|
|
stream=True,
|
|
seed=42,
|
|
max_tokens=20,
|
|
)
|
|
|
|
final_response = []
|
|
for chunk in output:
|
|
final_response.append(chunk.choices[0].delta.content)
|
|
resp = ''.join(final_response)
|
|
|
|
assert resp == "def is_prime(n):\n if n <= 1:\n return False\n if n"
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.private
|
|
async def test_code_builtin_tools_gen(flash_llama_grammar_tools, response_snapshot):
|
|
url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"
|
|
|
|
payload = {
|
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
|
"messages": [
|
|
{
|
|
"role": "user",
|
|
"content": "What is the current weather in Menlo Park, California?",
|
|
}
|
|
],
|
|
"stream": False,
|
|
"seed": 42,
|
|
"max_tokens": 20,
|
|
"builtin_tools": ["brave_search", "wolfram_alpha"],
|
|
}
|
|
|
|
response = requests.request("POST", url, json=payload)
|
|
response = response.json()
|
|
resp = response.get("choices")[0].get("message").get("content")
|
|
assert resp == "brave_search.call(query=\"current weather in Menlo Park, California\")"
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.private
|
|
async def test_code_builtin_tools_explict_off_gen(flash_llama_grammar_tools, response_snapshot):
|
|
url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"
|
|
|
|
payload = {
|
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
|
"messages": [
|
|
{
|
|
"role": "user",
|
|
"content": "What is the current weather in Menlo Park, California?",
|
|
}
|
|
],
|
|
"stream": False,
|
|
"seed": 42,
|
|
"max_tokens": 20,
|
|
# "builtin_tools": ["brave_search", "wolfram_alpha"],
|
|
}
|
|
|
|
response = requests.request("POST", url, json=payload)
|
|
response = response.json()
|
|
resp = response.get("choices")[0].get("message").get("content")
|
|
assert resp == "I can't provide real-time weather information. However, I can encourage you to check a weather website"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.private
|
|
async def test_code_builtin_tools_two_gen(flash_llama_grammar_tools, response_snapshot):
|
|
url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"
|
|
|
|
payload = {
|
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
|
"messages": [
|
|
{
|
|
"role": "system",
|
|
"content": "You are a helpful assistant.",
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": "Can you help me solve this equation with wolfram_alpha: x^3 - 4x^2 + 6x - 24 = 0",
|
|
},
|
|
],
|
|
"stream": False,
|
|
"seed": 42,
|
|
"max_tokens": 50,
|
|
"builtin_tools": ["brave_search", "wolfram_alpha"],
|
|
}
|
|
|
|
response = requests.request("POST", url, json=payload)
|
|
response = response.json()
|
|
resp = response.get("choices")[0].get("message").get("content")
|
|
assert resp == "wolfram_alpha.call(query=\"solve x^3 - 4x^2 + 6x - 24 = 0\")"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.private
|
|
async def test_code_builtin_tools_function_response_gen(flash_llama_grammar_tools, response_snapshot):
|
|
url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"
|
|
|
|
payload = {
|
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
|
"messages": [
|
|
{
|
|
"role": "system",
|
|
"content": "You are a helpful assistant.",
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": "Can you help me solve this equation with wolfram_alpha: x^3 - 4x^2 + 6x - 24 = 0",
|
|
},
|
|
{
|
|
"role": "assistant",
|
|
"content": "wolfram_alpha.call(query=\"solve x^3 - 4x^2 + 6x - 24 = 0\")",
|
|
},
|
|
{
|
|
"role": "ipython",
|
|
"content": "{\"queryresult\": {\"success\": true, \"inputstring\": \"solve x^3 - 4x^2 + 6x - 24 = 0\", \"pods\": [{\"title\": \"Input interpretation\", \"subpods\": [{\"title\": \"\", \"plaintext\": \"solve x^3 - 4 x^2 + 6 x - 24 = 0\"}]}, {\"title\": \"Results\", \"primary\": true, \"subpods\": [{\"title\": \"\", \"plaintext\": \"x = 4\"}, {\"title\": \"\", \"plaintext\": \"x = \u00b1 (i sqrt(6))\"}]}, ... ]}}",
|
|
},
|
|
],
|
|
"stream": False,
|
|
"seed": 42,
|
|
"max_tokens": 50,
|
|
"builtin_tools": ["brave_search", "wolfram_alpha"],
|
|
}
|
|
|
|
response = requests.request("POST", url, json=payload)
|
|
response = response.json()
|
|
resp = response.get("choices")[0].get("message").get("content")
|
|
assert resp == "The solutions to the equation x^3 - 4x^2 + 6x - 24 = 0 are x = 4, x = i√6, and x = -i√6."
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.private
|
|
async def test_user_supplied_json_tool_gen(flash_llama_grammar_tools, response_snapshot):
|
|
client = InferenceClient(
|
|
base_url=flash_llama_grammar_tools.base_url + "/v1",
|
|
)
|
|
output = client.chat.completions.create(
|
|
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
|
messages=[
|
|
{
|
|
"role": "system",
|
|
"content": "You are a helpful assistant with tool calling capabilities"
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": "Question: what is the weather like in San Fransisco?"
|
|
},
|
|
],
|
|
tools=[
|
|
{
|
|
"type": "function",
|
|
"function": {
|
|
"name": "get_current_conditions",
|
|
"description": "Get the current weather conditions for a specific location",
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {
|
|
"location": {
|
|
"type": "string",
|
|
"description": "The city and state, e.g., San Francisco, CA"
|
|
},
|
|
"unit": {
|
|
"type": "string",
|
|
"enum": ["Celsius", "Fahrenheit"],
|
|
"description": "The temperature unit to use. Infer this from the user's location."
|
|
}
|
|
},
|
|
"required": ["location", "unit"]
|
|
}
|
|
}
|
|
}
|
|
],
|
|
stream=True,
|
|
seed=42,
|
|
max_tokens=50,
|
|
)
|
|
|
|
final_response = []
|
|
for chunk in output:
|
|
final_response.append(chunk.choices[0].delta.content)
|
|
resp = ''.join(final_response)
|
|
|
|
assert resp == "{\"name\": \"get_current_conditions\", \"parameters\": {\"location\": \"San Francisco, CA\", \"unit\": \"Fahrenheit\"}}"
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.private
|
|
async def test_user_supplied_json_tool_function_response_gen(flash_llama_grammar_tools, response_snapshot):
|
|
client = InferenceClient(
|
|
base_url=flash_llama_grammar_tools.base_url + "/v1",
|
|
)
|
|
output = client.chat.completions.create(
|
|
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
|
messages=[
|
|
{
|
|
"role": "system",
|
|
"content": "You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the orginal use question."
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": "Question: what is the weather like in San Fransisco?"
|
|
},
|
|
{
|
|
"role": "assistant",
|
|
"content": "{\"name\": \"get_current_conditions\", \"parameters\": {\"location\": \"San Francisco, CA\", \"unit\": \"Fahrenheit\"}}",
|
|
},
|
|
{
|
|
"role": "ipython",
|
|
"content": "{\"output\": \"Clouds giving way to sun Hi: 76° Tonight: Mainly clear early, then areas of low clouds forming Lo: 56°\"}",
|
|
},
|
|
],
|
|
tools=[
|
|
{
|
|
"type": "function",
|
|
"function": {
|
|
"name": "get_current_conditions",
|
|
"description": "Get the current weather conditions for a specific location",
|
|
"parameters": {
|
|
"type": "object",
|
|
"properties": {
|
|
"location": {
|
|
"type": "string",
|
|
"description": "The city and state, e.g., San Francisco, CA"
|
|
},
|
|
"unit": {
|
|
"type": "string",
|
|
"enum": ["Celsius", "Fahrenheit"],
|
|
"description": "The temperature unit to use. Infer this from the user's location."
|
|
}
|
|
},
|
|
"required": ["location", "unit"]
|
|
}
|
|
}
|
|
}
|
|
],
|
|
stream=True,
|
|
seed=42,
|
|
max_tokens=50,
|
|
)
|
|
|
|
final_response = []
|
|
for chunk in output:
|
|
final_response.append(chunk.choices[0].delta.content)
|
|
resp = ''.join(final_response)
|
|
assert resp == "The current weather conditions in San Francisco, CA are clouds giving way to sun with a high of 76°F and a low of 56°F." |