Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-24 10:00:16 +00:00
fix broken test
parent 1f03afe94d
commit 0295bf243f
@@ -50,8 +50,6 @@ local-dev-install: install-dependencies
 
 # In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
 run-integration-tests:
-	pip install -U pip uv
-	uv pip install -r ${root_dir}/backends/gaudi/server/integration-tests/requirements.txt
 	DOCKER_VOLUME=${root_dir}/data \
 	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
 	pytest --durations=0 -s -vv ${root_dir}/integration-tests --gaudi
@@ -99,6 +99,11 @@ curl 127.0.0.1:8080/generate \
 
 ### Integration tests
 
+Install the dependencies:
+```bash
+pip install -r integration-tests/requirements.txt
+```
+
 To run the integration tests, you need to first build the image:
 ```bash
 make -C backends/gaudi image
@@ -16,8 +16,7 @@ from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
 from docker.errors import NotFound
 import logging
 from gaudi.test_gaudi_generate import TEST_CONFIGS
-from text_generation import AsyncClient
-from text_generation.types import Response
+from huggingface_hub import AsyncInferenceClient, TextGenerationOutput
 import huggingface_hub
 
 logging.basicConfig(
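The import swap above is the heart of the fix: the standalone `text_generation` client is replaced by `huggingface_hub`'s `AsyncInferenceClient`. A minimal sketch of the new call shape, assuming a TGI server already listening on an illustrative localhost URL:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    # The TGI endpoint URL is passed as `model`; no separate client package is needed.
    client = AsyncInferenceClient(model="http://localhost:8080")
    # With details=True the call returns a TextGenerationOutput (not a bare str),
    # exposing token-level fields such as details.generated_tokens.
    out = await client.text_generation(
        "What is Deep Learning?", max_new_tokens=32, details=True
    )
    print(out.generated_text, out.details.generated_tokens)


asyncio.run(main())
```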
@@ -71,9 +70,15 @@ def stream_container_logs(container, test_name):
         logger.error(f"Error streaming container logs: {str(e)}")
 
 
+class TestClient(AsyncInferenceClient):
+    def __init__(self, service_name: str, base_url: str):
+        super().__init__(model=base_url)
+        self.service_name = service_name
+
+
 class LauncherHandle:
-    def __init__(self, port: int):
-        self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)
+    def __init__(self, service_name: str, port: int):
+        self.client = TestClient(service_name, f"http://localhost:{port}")
 
     def _inner_health(self):
         raise NotImplementedError
@@ -89,7 +94,7 @@ class LauncherHandle:
                 raise RuntimeError("Launcher crashed")
 
             try:
-                await self.client.generate("test")
+                await self.client.text_generation("test", max_new_tokens=1)
                 elapsed = time.time() - start_time
                 logger.info(f"Health check passed after {elapsed:.1f}s")
                 return
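The health probe now issues a one-token `text_generation` call instead of `AsyncClient.generate`, which keeps the check cheap. The retry loop itself sits outside the hunk, so the shape below is an assumption; a sketch of the polling pattern the handle relies on:

```python
import asyncio
import time


async def wait_until_healthy(client, timeout: float = 60.0) -> None:
    """Poll the server with a minimal one-token generation until it responds."""
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            await client.text_generation("test", max_new_tokens=1)
            return  # server answered: it is healthy
        except Exception:
            await asyncio.sleep(1)  # not ready yet, retry
    raise RuntimeError("Health check timed out")
```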
@@ -113,7 +118,8 @@ class LauncherHandle:
 
 class ContainerLauncherHandle(LauncherHandle):
     def __init__(self, docker_client, container_name, port: int):
-        super(ContainerLauncherHandle, self).__init__(port)
+        service_name = container_name  # Use container name as service name
+        super(ContainerLauncherHandle, self).__init__(service_name, port)
         self.docker_client = docker_client
         self.container_name = container_name
 
@@ -134,7 +140,8 @@ class ContainerLauncherHandle(LauncherHandle):
 
 class ProcessLauncherHandle(LauncherHandle):
     def __init__(self, process, port: int):
-        super(ProcessLauncherHandle, self).__init__(port)
+        service_name = "process"  # Use generic name for process launcher
+        super(ProcessLauncherHandle, self).__init__(service_name, port)
         self.process = process
 
     def _inner_health(self) -> bool:
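Both launcher handles now derive a `service_name` before delegating to the base class, and `TestClient` stores it, which makes per-service log attribution possible. How the two constructors resolve it, with arguments exactly as in the hunks above (`docker_client`, `process`, and the names/ports are placeholders for illustration):

```python
# Container launcher: the Docker container name doubles as the service name.
handle = ContainerLauncherHandle(docker_client, "tgi-gaudi-test", port=8080)
assert handle.client.service_name == "tgi-gaudi-test"

# Process launcher: a generic name is used since there is no container.
proc_handle = ProcessLauncherHandle(process, port=8081)
assert proc_handle.client.service_name == "process"
```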
@@ -153,11 +160,13 @@ def data_volume():
 
 
 @pytest.fixture(scope="module")
-def gaudi_launcher(event_loop):
+def gaudi_launcher():
     @contextlib.contextmanager
     def docker_launcher(
         model_id: str,
         test_name: str,
+        tgi_args: List[str] = None,
+        env_config: dict = None
     ):
         logger.info(
             f"Starting docker launcher for model {model_id} and test {test_name}"
@@ -185,22 +194,29 @@ def gaudi_launcher(event_loop):
             )
             container.stop()
             container.wait()
+            container.remove()
+            logger.info(f"Removed existing container {container_name}")
         except NotFound:
             pass
         except Exception as e:
             logger.error(f"Error handling existing container: {str(e)}")
 
-        tgi_args = TEST_CONFIGS[test_name]["args"].copy()
+        if tgi_args is None:
+            tgi_args = []
+        else:
+            tgi_args = tgi_args.copy()
 
         env = BASE_ENV.copy()
 
         # Add model_id to env
         env["MODEL_ID"] = model_id
 
-        # Add env config that is definied in the fixture parameter
-        if "env_config" in TEST_CONFIGS[test_name]:
-            env.update(TEST_CONFIGS[test_name]["env_config"].copy())
+        # Add env config that is defined in the fixture parameter
+        if env_config is not None:
+            env.update(env_config.copy())
 
-        volumes = [f"{DOCKER_VOLUME}:/data"]
+        volumes = []
+        if DOCKER_VOLUME:
+            volumes = [f"{DOCKER_VOLUME}:/data"]
         logger.debug(f"Using volume {volumes}")
 
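With `tgi_args` and `env_config` as plain parameters, `docker_launcher` no longer reads the global `TEST_CONFIGS` dict; each test passes its configuration explicitly, which is what allows the old config module to be deleted at the bottom of this commit. A hedged sketch of a direct call, usable only inside a test that receives the `gaudi_launcher` fixture (the config values are illustrative, shaped like the old `TEST_CONFIGS` entries):

```python
# Hypothetical per-test configuration; SOME_HPU_FLAG is a made-up variable.
config = {
    "args": ["--max-input-tokens", "512", "--max-total-tokens", "1024"],
    "env_config": {"SOME_HPU_FLAG": "1"},
}

with gaudi_launcher(
    "meta-llama/Llama-3.1-8B-Instruct",
    "my-test",
    tgi_args=config.get("args", []),
    env_config=config.get("env_config", {}),
) as service:
    ...  # issue requests through service.client
```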
@@ -276,13 +292,14 @@ def gaudi_launcher(event_loop):
 @pytest.fixture(scope="module")
 def gaudi_generate_load():
     async def generate_load_inner(
-        client: AsyncClient, prompt: str, max_new_tokens: int, n: int
-    ) -> List[Response]:
+        client: AsyncInferenceClient, prompt: str, max_new_tokens: int, n: int
+    ) -> List[TextGenerationOutput]:
         try:
             futures = [
-                client.generate(
+                client.text_generation(
                     prompt,
                     max_new_tokens=max_new_tokens,
+                    details=True,
                     decoder_input_details=True,
                 )
                 for _ in range(n)
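The hunk ends inside the list comprehension, so how the coroutines are awaited is not visible here; a typical completion gathers them concurrently. A self-contained sketch under that assumption:

```python
import asyncio
from typing import List

from huggingface_hub import AsyncInferenceClient, TextGenerationOutput


async def generate_load(
    client: AsyncInferenceClient, prompt: str, max_new_tokens: int, n: int
) -> List[TextGenerationOutput]:
    # Fire n identical requests concurrently and collect the detailed outputs,
    # mirroring the inner function above; the gather is an assumed continuation.
    futures = [
        client.text_generation(
            prompt,
            max_new_tokens=max_new_tokens,
            details=True,
            decoder_input_details=True,
        )
        for _ in range(n)
    ]
    return await asyncio.gather(*futures)
```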
@@ -1,7 +1,6 @@
 from typing import Any, Dict, Generator
 from _pytest.fixtures import SubRequest
-from text_generation import AsyncClient
+from huggingface_hub import AsyncInferenceClient, TextGenerationOutput
 import pytest
 
 
@@ -238,13 +237,18 @@ def input(test_config: Dict[str, Any]) -> str:
 
 
 @pytest.fixture(scope="module")
-def tgi_service(gaudi_launcher, model_id: str, test_name: str):
-    with gaudi_launcher(model_id, test_name) as tgi_service:
+def tgi_service(gaudi_launcher, model_id: str, test_name: str, test_config: Dict[str, Any]):
+    with gaudi_launcher(
+        model_id,
+        test_name,
+        tgi_args=test_config.get("args", []),
+        env_config=test_config.get("env_config", {})
+    ) as tgi_service:
         yield tgi_service
 
 
 @pytest.fixture(scope="module")
-async def tgi_client(tgi_service) -> AsyncClient:
+async def tgi_client(tgi_service) -> AsyncInferenceClient:
     await tgi_service.health(1000)
     return tgi_service.client
 
@@ -252,12 +256,14 @@ async def tgi_client(tgi_service) -> AsyncClient:
 @pytest.mark.asyncio
 @pytest.mark.all_models
 async def test_model_single_request(
-    tgi_client: AsyncClient, expected_outputs: Dict[str, str], input: str
+    tgi_client: AsyncInferenceClient, expected_outputs: Dict[str, str], input: str
 ):
     # Bounded greedy decoding without input
-    response = await tgi_client.generate(
+    response = await tgi_client.text_generation(
         input,
         max_new_tokens=32,
+        details=True,
+        decoder_input_details=True,
     )
     assert response.details.generated_tokens == 32
     assert response.generated_text == expected_outputs["greedy"]
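Passing `details=True` is what keeps these assertions valid after the client swap: by default `text_generation` returns a plain string, and only with `details=True` does it return a `TextGenerationOutput` carrying `.details.generated_tokens`. A self-contained sketch of the shape being asserted on (URL illustrative, requires a running TGI server):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def check(base_url: str) -> None:
    client = AsyncInferenceClient(model=base_url)  # e.g. "http://localhost:8080"
    # Default return type is str; details=True switches to TextGenerationOutput.
    text = await client.text_generation("What is Deep Learning?", max_new_tokens=32)
    assert isinstance(text, str)
    out = await client.text_generation(
        "What is Deep Learning?", max_new_tokens=32, details=True
    )
    assert out.details.generated_tokens == 32


# asyncio.run(check("http://localhost:8080"))
```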
@@ -266,7 +272,7 @@ async def test_model_single_request(
 @pytest.mark.asyncio
 @pytest.mark.all_models
 async def test_model_multiple_requests(
-    tgi_client: AsyncClient,
+    tgi_client: AsyncInferenceClient,
     gaudi_generate_load,
     expected_outputs: Dict[str, str],
     input: str,
@@ -1,259 +0,0 @@
-from typing import Any, Dict
-
-from text_generation import AsyncClient
-import pytest
-
-# The "args" config is not optimized for speed but only check that the inference is working for the different models architectures
-TEST_CONFIGS = {
-    "meta-llama/Llama-3.1-8B-Instruct-shared": {
-        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
-        "input": "What is Deep Learning?",
-        "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
-        "expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
-        "args": [
-            "--sharded",
-            "true",
-            "--num-shard",
-            "8",
-            "--max-input-tokens",
-            "512",
-            "--max-total-tokens",
-            "1024",
-            "--max-batch-size",
-            "8",
-            "--max-batch-prefill-tokens",
-            "2048",
-        ],
-    },
-    "meta-llama/Llama-3.1-8B-Instruct": {
-        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
-        "input": "What is Deep Learning?",
-        "expected_greedy_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
-        "expected_batch_output": " A Beginner’s Guide\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of",
-        "env_config": {},
-        "args": [
-            "--max-input-tokens",
-            "512",
-            "--max-total-tokens",
-            "1024",
-            "--max-batch-size",
-            "4",
-            "--max-batch-prefill-tokens",
-            "2048",
-        ],
-    },
-    "meta-llama/Llama-2-7b-chat-hf": {
-        "model_id": "meta-llama/Llama-2-7b-chat-hf",
-        "input": "What is Deep Learning?",
-        "expected_greedy_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
-        "expected_batch_output": "\n\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\u2014specific",
-        "args": [
-            "--max-input-tokens",
-            "512",
-            "--max-total-tokens",
-            "1024",
-            "--max-batch-size",
-            "4",
-            "--max-batch-prefill-tokens",
-            "2048",
-        ],
-    },
-    "mistralai/Mistral-7B-Instruct-v0.3": {
-        "model_id": "mistralai/Mistral-7B-Instruct-v0.3",
-        "input": "What is Deep Learning?",
-        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
-        "expected_batch_output": "\n\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured",
-        "args": [
-            "--max-input-tokens",
-            "512",
-            "--max-total-tokens",
-            "1024",
-            "--max-batch-size",
-            "4",
-            "--max-batch-prefill-tokens",
-            "2048",
-        ],
-    },
-    "bigcode/starcoder2-3b": {
-        "model_id": "bigcode/starcoder2-3b",
-        "input": "What is Deep Learning?",
-        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
-        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\n\nNeural networks are a type of machine learning algorithm that",
-        "args": [
-            "--max-input-tokens",
-            "512",
-            "--max-total-tokens",
-            "1024",
-            "--max-batch-size",
-            "4",
-            "--max-batch-prefill-tokens",
-            "2048",
-        ],
-    },
-    "google/gemma-7b-it": {
-        "model_id": "google/gemma-7b-it",
-        "input": "What is Deep Learning?",
-        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
-        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Neural networks are inspired by the structure and function of",
-        "args": [
-            "--max-input-tokens",
-            "512",
-            "--max-total-tokens",
-            "1024",
-            "--max-batch-size",
-            "4",
-            "--max-batch-prefill-tokens",
-            "2048",
-        ],
-    },
-    "Qwen/Qwen2-0.5B-Instruct": {
-        "model_id": "Qwen/Qwen2-0.5B-Instruct",
-        "input": "What is Deep Learning?",
-        "expected_greedy_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
-        "expected_batch_output": " Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models",
-        "args": [
-            "--max-input-tokens",
-            "512",
-            "--max-total-tokens",
-            "1024",
-            "--max-batch-size",
-            "4",
-            "--max-batch-prefill-tokens",
-            "2048",
-        ],
-    },
-    "tiiuae/falcon-7b-instruct": {
-        "model_id": "tiiuae/falcon-7b-instruct",
-        "input": "What is Deep Learning?",
-        "expected_greedy_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
-        "expected_batch_output": "\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a",
-        "args": [
-            "--max-input-tokens",
-            "512",
-            "--max-total-tokens",
-            "1024",
-            "--max-batch-size",
-            "4",
-        ],
-    },
-    "microsoft/phi-1_5": {
-        "model_id": "microsoft/phi-1_5",
-        "input": "What is Deep Learning?",
-        "expected_greedy_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
-        "expected_batch_output": "\n\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large",
-        "args": [
-            "--max-input-tokens",
-            "512",
-            "--max-total-tokens",
-            "1024",
-            "--max-batch-size",
-            "4",
-        ],
-    },
-    "openai-community/gpt2": {
-        "model_id": "openai-community/gpt2",
-        "input": "What is Deep Learning?",
-        "expected_greedy_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
-        "expected_batch_output": "\n\nDeep learning is a new field of research that has been around for a long time. It is a new field of research that has been around for a",
-        "args": [
-            "--max-input-tokens",
-            "512",
-            "--max-total-tokens",
-            "1024",
-            "--max-batch-size",
-            "4",
-        ],
-    },
-    "EleutherAI/gpt-j-6b": {
-        "model_id": "EleutherAI/gpt-j-6b",
-        "input": "What is Deep Learning?",
-        "expected_greedy_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
-        "expected_batch_output": "\n\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by",
-        "args": [
-            "--max-input-tokens",
-            "512",
-            "--max-total-tokens",
-            "1024",
-            "--max-batch-size",
-            "4",
-        ],
-    },
-}
-
-print(f"Testing {len(TEST_CONFIGS)} models")
-
-
-@pytest.fixture(scope="module", params=TEST_CONFIGS.keys())
-def test_config(request) -> Dict[str, Any]:
-    """Fixture that provides model configurations for testing."""
-    test_config = TEST_CONFIGS[request.param]
-    test_config["test_name"] = request.param
-    return test_config
-
-
-@pytest.fixture(scope="module")
-def model_id(test_config):
-    yield test_config["model_id"]
-
-
-@pytest.fixture(scope="module")
-def test_name(test_config):
-    yield test_config["test_name"]
-
-
-@pytest.fixture(scope="module")
-def expected_outputs(test_config):
-    return {
-        "greedy": test_config["expected_greedy_output"],
-        # "sampling": model_config["expected_sampling_output"],
-        "batch": test_config["expected_batch_output"],
-    }
-
-
-@pytest.fixture(scope="module")
-def input(test_config):
-    return test_config["input"]
-
-
-@pytest.fixture(scope="module")
-def tgi_service(launcher, model_id, test_name):
-    with launcher(model_id, test_name) as tgi_service:
-        yield tgi_service
-
-
-@pytest.fixture(scope="module")
-async def tgi_client(tgi_service) -> AsyncClient:
-    await tgi_service.health(1000)
-    return tgi_service.client
-
-
-@pytest.mark.asyncio
-async def test_model_single_request(
-    tgi_client: AsyncClient, expected_outputs: Dict[str, Any], input: str
-):
-    # Bounded greedy decoding without input
-    response = await tgi_client.generate(
-        input,
-        max_new_tokens=32,
-    )
-    assert response.details.generated_tokens == 32
-    assert response.generated_text == expected_outputs["greedy"]
-
-
-@pytest.mark.asyncio
-async def test_model_multiple_requests(
-    tgi_client, generate_load, expected_outputs, input
-):
-    num_requests = 4
-    responses = await generate_load(
-        tgi_client,
-        input,
-        max_new_tokens=32,
-        n=num_requests,
-    )
-
-    assert len(responses) == 4
-    expected = expected_outputs["batch"]
-    for r in responses:
-        assert r.details.generated_tokens == 32
-        assert r.generated_text == expected