wip(test): adding test to ci

baptiste 2025-04-10 07:46:59 +00:00
parent 8f8819795f
commit 918b29a0af
8 changed files with 30 additions and 38 deletions

View File

@@ -129,9 +129,9 @@ jobs:
            export label_extension="-gaudi"
            export docker_volume="/mnt/cache"
            export docker_devices=""
-           export runs_on="ubuntu-latest"
+           export runs_on="aws-dl1-24xlarge"
            export platform=""
-           export extra_pytest=""
+           export extra_pytest="--gaudi"
            export target=""
          esac
          echo $dockerfile

View File

@@ -50,10 +50,9 @@ local-dev-install: install-dependencies
 # In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
 run-integration-tests:
-	uv pip install -r ${root_dir}/backends/gaudi/server/integration-tests/requirements.txt
 	DOCKER_VOLUME=${root_dir}/data \
 	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
-	uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests
+	pytest --durations=0 -s -vv integration-tests --gaudi

 # This is used to capture the expected outputs for the integration tests offering an easy way to add more models to the integration tests
 capture-expected-outputs-for-integration-tests:

View File

@@ -1,2 +0,0 @@
-[pytest]
-asyncio_mode = auto

View File

@@ -1,7 +0,0 @@
-pytest >= 8.3.5
-pytest-asyncio >= 0.26.0
-docker >= 7.1.0
-Levenshtein >= 0.27.1
-loguru >= 0.7.3
-aiohttp >= 3.11.14
-text-generation

View File

@@ -47,7 +47,6 @@ from text_generation.types import (
     ChatComplete,
     ChatCompletionChunk,
     ChatCompletionComplete,
-    Completion,
     Details,
     Grammar,
     InputToken,
@@ -68,6 +67,9 @@ def pytest_addoption(parser):
     parser.addoption(
         "--neuron", action="store_true", default=False, help="run neuron tests"
     )
+    parser.addoption(
+        "--gaudi", action="store_true", default=False, help="run gaudi tests"
+    )


 def pytest_configure(config):
@@ -84,6 +86,14 @@ def pytest_collection_modifyitems(config, items):
                item.add_marker(pytest.mark.skip(reason="need --release option to run"))

        selectors.append(skip_release)
+    if config.getoption("--gaudi"):
+
+        def skip_not_gaudi(item):
+            if "gaudi" not in item.keywords:
+                item.add_marker(pytest.mark.skip(reason="requires --gaudi to run"))
+
+        selectors.append(skip_not_gaudi)
+
    if config.getoption("--neuron"):

        def skip_not_neuron(item):
@@ -99,7 +109,12 @@ def pytest_collection_modifyitems(config, items):
            if "neuron" in item.keywords:
                item.add_marker(pytest.mark.skip(reason="requires --neuron to run"))

+        def skip_gaudi(item):
+            if "gaudi" in item.keywords:
+                item.add_marker(pytest.mark.skip(reason="requires --gaudi to run"))
+
        selectors.append(skip_neuron)
+        selectors.append(skip_gaudi)

    for item in items:
        for selector in selectors:
            selector(item)
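
Read together, the two hunks above make the Gaudi integration tests opt-in from the shared conftest.py: passing --gaudi selects tests carrying the gaudi keyword, while runs without the flag skip them via skip_gaudi. A minimal sketch of a test that would opt in (illustrative only, not part of this commit):

# Sketch only, not part of this commit: a test carrying the "gaudi" marker
# that the new --gaudi option selects; without the flag, the skip_gaudi
# selector added above marks it as skipped.
import pytest


@pytest.mark.gaudi
def test_runs_only_with_gaudi_flag():
    assert True

It would be run as in the Makefile change above, e.g. pytest integration-tests --gaudi.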
@@ -131,7 +146,6 @@ class ResponseComparator(JSONSnapshotExtension):
            or isinstance(data, ChatComplete)
            or isinstance(data, ChatCompletionChunk)
            or isinstance(data, ChatCompletionComplete)
-           or isinstance(data, Completion)
            or isinstance(data, OAIChatCompletionChunk)
            or isinstance(data, OAICompletion)
        ):
@@ -188,8 +202,6 @@ class ResponseComparator(JSONSnapshotExtension):
        if isinstance(choices, List) and len(choices) >= 1:
            if "delta" in choices[0]:
                return ChatCompletionChunk(**data)
-           if "text" in choices[0]:
-               return Completion(**data)
            return ChatComplete(**data)
        else:
            return Response(**data)
@@ -282,9 +294,6 @@ class ResponseComparator(JSONSnapshotExtension):
                )
            )

-       def eq_completion(response: Completion, other: Completion) -> bool:
-           return response.choices[0].text == other.choices[0].text
-
        def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
            return (
                response.choices[0].message.content == other.choices[0].message.content
@@ -329,11 +338,6 @@ class ResponseComparator(JSONSnapshotExtension):
        if len(serialized_data) == 0:
            return len(snapshot_data) == len(serialized_data)

-       if isinstance(serialized_data[0], Completion):
-           return len(snapshot_data) == len(serialized_data) and all(
-               [eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
-           )
-
        if isinstance(serialized_data[0], ChatComplete):
            return len(snapshot_data) == len(serialized_data) and all(
                [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]

View File

@@ -14,11 +14,18 @@ import docker
 import pytest
 from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
 from docker.errors import NotFound
-from loguru import logger
-from test_model import TEST_CONFIGS
+import logging
+from gaudi.test_generate import TEST_CONFIGS
 from text_generation import AsyncClient
 from text_generation.types import Response

+logging.basicConfig(
+    level=logging.INFO,
+    format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
+    stream=sys.stdout,
+)
+logger = logging.getLogger(__file__)
+
 # Use the latest image from the local docker build
 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
 DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
@@ -48,12 +55,6 @@ HABANA_RUN_ARGS = {
     "cap_add": ["sys_nice"],
 }

-logger.add(
-    sys.stderr,
-    format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
-    level="INFO",
-)
-

 def stream_container_logs(container, test_name):
     """Stream container logs in a separate thread."""

View File

@@ -3,7 +3,7 @@ import os
 from typing import Dict, Any, Generator

 import pytest
-from test_model import TEST_CONFIGS
+from test_generate import TEST_CONFIGS

 UNKNOWN_CONFIGS = {
     name: config

View File

@@ -2,7 +2,6 @@ from typing import Any, Dict
 from text_generation import AsyncClient

 import pytest
-from Levenshtein import distance as levenshtein_distance

 # The "args" config is not optimized for speed but only check that the inference is working for the different models architectures
 TEST_CONFIGS = {
@@ -271,6 +270,4 @@ async def test_model_multiple_requests(
     expected = expected_outputs["batch"]
     for r in responses:
         assert r.details.generated_tokens == 32
-        # Compute the similarity with the expectation using the levenshtein distance
-        # We should not have more than two substitutions or additions
-        assert levenshtein_distance(r.generated_text, expected) < 3
+        assert r.generated_text == expected
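
With the Levenshtein import gone, the batch check above now requires the generated text to match the captured expectation exactly, which assumes fully deterministic (greedy) decoding. If some tolerance were still wanted without the extra dependency, a standard-library alternative could look like this sketch (an assumption, not part of this commit):

# Sketch only, not part of this commit: an approximate comparison using only
# the standard library, mirroring the removed "at most a couple of edits" idea.
from difflib import SequenceMatcher


def roughly_equal(generated: str, expected: str, threshold: float = 0.98) -> bool:
    # ratio() is 1.0 for identical strings and decreases with each difference.
    return SequenceMatcher(None, generated, expected).ratio() >= threshold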