diff --git a/integration-tests/fixtures/gaudi/service.py b/integration-tests/fixtures/gaudi/service.py
index 5c7d729b..f4f43691 100644
--- a/integration-tests/fixtures/gaudi/service.py
+++ b/integration-tests/fixtures/gaudi/service.py
@@ -15,7 +15,6 @@ import pytest
 from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
 from docker.errors import NotFound
 import logging
-from gaudi.test_gaudi_generate import TEST_CONFIGS
 from huggingface_hub import AsyncInferenceClient, TextGenerationOutput
 import huggingface_hub

@@ -166,7 +165,7 @@ def gaudi_launcher():
         model_id: str,
         test_name: str,
         tgi_args: List[str] = None,
-        env_config: dict = None
+        env_config: dict = None,
     ):
         logger.info(
             f"Starting docker launcher for model {model_id} and test {test_name}"
diff --git a/integration-tests/gaudi/test_gaudi_generate.py b/integration-tests/gaudi/test_gaudi_generate.py
index f5d71ab7..2b8b0c76 100644
--- a/integration-tests/gaudi/test_gaudi_generate.py
+++ b/integration-tests/gaudi/test_gaudi_generate.py
@@ -1,6 +1,6 @@
 from typing import Any, Dict, Generator
 from _pytest.fixtures import SubRequest
-from huggingface_hub import AsyncInferenceClient, TextGenerationOutput
+from huggingface_hub import AsyncInferenceClient
 import pytest


@@ -237,12 +237,14 @@ def input(test_config: Dict[str, Any]) -> str:


 @pytest.fixture(scope="module")
-def tgi_service(gaudi_launcher, model_id: str, test_name: str, test_config: Dict[str, Any]):
+def tgi_service(
+    gaudi_launcher, model_id: str, test_name: str, test_config: Dict[str, Any]
+):
     with gaudi_launcher(
-        model_id,
-        test_name,
+        model_id,
+        test_name,
         tgi_args=test_config.get("args", []),
-        env_config=test_config.get("env_config", {})
+        env_config=test_config.get("env_config", {}),
     ) as tgi_service:
         yield tgi_service