mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 22:32:07 +00:00
141 lines
4.5 KiB
Python
141 lines
4.5 KiB
Python
from typing import Any, Dict
|
||
|
||
from text_generation import AsyncClient
|
||
import pytest
|
||
from Levenshtein import distance as levenshtein_distance
|
||
|
||
TEST_CONFIGS = {
|
||
"llama3-8b-shared-gaudi1": {
|
||
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
|
||
"input": "What is Deep Learning?",
|
||
"expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
|
||
"expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use",
|
||
"args": [
|
||
"--sharded",
|
||
"true",
|
||
"--num-shard",
|
||
"8",
|
||
"--max-input-tokens",
|
||
"512",
|
||
"--max-total-tokens",
|
||
"1024",
|
||
"--max-batch-size",
|
||
"8",
|
||
"--max-batch-prefill-tokens",
|
||
"2048",
|
||
],
|
||
},
|
||
"llama3-8b-gaudi1": {
|
||
"model_id": "meta-llama/Llama-3.1-8B-Instruct",
|
||
"input": "What is Deep Learning?",
|
||
"expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
|
||
# "expected_sampling_output": " Part 2: Backpropagation\nIn my last post , I introduced the concept of deep learning and covered some high-level topics, including feedforward neural networks.",
|
||
"expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
|
||
"env_config": {},
|
||
"args": [
|
||
"--max-input-tokens",
|
||
"512",
|
||
"--max-total-tokens",
|
||
"1024",
|
||
"--max-batch-size",
|
||
"4",
|
||
"--max-batch-prefill-tokens",
|
||
"2048",
|
||
],
|
||
},
|
||
}
|
||
|
||
|
||
@pytest.fixture(scope="module", params=TEST_CONFIGS.keys())
|
||
def test_config(request) -> Dict[str, Any]:
|
||
"""Fixture that provides model configurations for testing."""
|
||
test_config = TEST_CONFIGS[request.param]
|
||
test_config["test_name"] = request.param
|
||
return test_config
|
||
|
||
|
||
@pytest.fixture(scope="module")
|
||
def model_id(test_config):
|
||
yield test_config["model_id"]
|
||
|
||
|
||
@pytest.fixture(scope="module")
|
||
def test_name(test_config):
|
||
yield test_config["test_name"]
|
||
|
||
|
||
@pytest.fixture(scope="module")
|
||
def expected_outputs(test_config):
|
||
return {
|
||
"greedy": test_config["expected_greedy_output"],
|
||
# "sampling": model_config["expected_sampling_output"],
|
||
"batch": test_config["expected_batch_output"],
|
||
}
|
||
|
||
|
||
@pytest.fixture(scope="module")
|
||
def input(test_config):
|
||
return test_config["input"]
|
||
|
||
|
||
@pytest.fixture(scope="module")
|
||
def tgi_service(launcher, model_id, test_name):
|
||
with launcher(model_id, test_name) as tgi_service:
|
||
yield tgi_service
|
||
|
||
|
||
@pytest.fixture(scope="module")
|
||
async def tgi_client(tgi_service) -> AsyncClient:
|
||
await tgi_service.health(1000)
|
||
return tgi_service.client
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_model_single_request(
|
||
tgi_client: AsyncClient, expected_outputs: Dict[str, Any], input: str
|
||
):
|
||
# Bounded greedy decoding without input
|
||
response = await tgi_client.generate(
|
||
input,
|
||
max_new_tokens=32,
|
||
)
|
||
assert response.details.generated_tokens == 32
|
||
assert response.generated_text == expected_outputs["greedy"]
|
||
|
||
# TO FIX: Add sampling tests later, the gaudi backend seems to be flaky with the sampling even with a fixed seed
|
||
|
||
# Sampling
|
||
# response = await tgi_client.generate(
|
||
# "What is Deep Learning?",
|
||
# do_sample=True,
|
||
# top_k=50,
|
||
# top_p=0.9,
|
||
# repetition_penalty=1.2,
|
||
# max_new_tokens=100,
|
||
# seed=42,
|
||
# decoder_input_details=True,
|
||
# )
|
||
|
||
# assert expected_outputs["sampling"] in response.generated_text
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_model_multiple_requests(
|
||
tgi_client, generate_load, expected_outputs, input
|
||
):
|
||
num_requests = 4
|
||
responses = await generate_load(
|
||
tgi_client,
|
||
input,
|
||
max_new_tokens=32,
|
||
n=num_requests,
|
||
)
|
||
|
||
assert len(responses) == 4
|
||
expected = expected_outputs["batch"]
|
||
for r in responses:
|
||
assert r.details.generated_tokens == 32
|
||
# Compute the similarity with the expectation using the levenshtein distance
|
||
# We should not have more than two substitutions or additions
|
||
assert levenshtein_distance(r.generated_text, expected) < 3
|