fix: bump client tests for api changes

This commit is contained in:
drbh 2025-02-17 14:19:52 +00:00
parent 3f035fd8f2
commit 81786840d7
3 changed files with 12 additions and 17 deletions

View File

@ -5,13 +5,13 @@ from huggingface_hub.utils import build_hf_headers
@pytest.fixture
def flan_t5_xxl():
return "google/flan-t5-xxl"
def llama_7b():
return "meta-llama/Llama-2-7b-chat-hf"
@pytest.fixture
def llama_7b():
return "meta-llama/Llama-2-7b-chat-hf"
def llama_70b():
return "meta-llama/Llama-3.1-70B-Instruct"
@pytest.fixture
@ -21,7 +21,7 @@ def fake_model():
@pytest.fixture
def unsupported_model():
return "gpt2"
return "black-forest-labs/FLUX.1-dev"
@pytest.fixture
@ -34,11 +34,6 @@ def bloom_url(base_url, bloom_model):
return f"{base_url}/{bloom_model}"
@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
return f"{base_url}/{flan_t5_xxl}"
@pytest.fixture
def llama_7b_url(base_url, llama_7b):
return f"{base_url}/{llama_7b}"

View File

@ -10,8 +10,8 @@ from text_generation.errors import NotSupportedError, NotFoundError
from text_generation.inference_api import check_model_support, deployed_models
def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):
assert check_model_support(flan_t5_xxl)
def test_check_model_support(llama_70b, unsupported_model, fake_model):
assert check_model_support(llama_70b)
assert not check_model_support(unsupported_model)
with pytest.raises(NotFoundError):
@ -22,8 +22,8 @@ def test_deployed_models():
deployed_models()
def test_client(flan_t5_xxl):
client = InferenceAPIClient(flan_t5_xxl)
def test_client(llama_70b):
client = InferenceAPIClient(llama_70b)
assert isinstance(client, Client)
@ -32,8 +32,8 @@ def test_client_unsupported_model(unsupported_model):
InferenceAPIClient(unsupported_model)
def test_async_client(flan_t5_xxl):
client = InferenceAPIAsyncClient(flan_t5_xxl)
def test_async_client(llama_70b):
client = InferenceAPIAsyncClient(llama_70b)
assert isinstance(client, AsyncClient)

View File

@ -464,4 +464,4 @@ class DeployedModel(BaseModel):
# https://github.com/pydantic/pydantic/issues/9177
model_config = ConfigDict(protected_namespaces=())
model_id: str
sha: str
task: str