mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
fix decorators

This commit is contained in:
parent 4616c62914
commit 3de8f3647b
@@ -1,10 +1,10 @@
 import pytest
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda")
+@require_backend("cuda")
 def bloom_560_handle(launcher):
     with launcher("bigscience/bloom-560m") as handle:
         yield handle
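The fixture above is synchronous, so it needs the synchronous decorator. If require_backend_async wraps its target in an async wrapper, as its name suggests, then calling the decorated sync function merely builds a coroutine object: the backend check inside never executes and nothing ever awaits the body. A standalone sketch of that presumed failure mode, with SYSTEM stubbed and all names hypothetical:

import functools

SYSTEM = "rocm"  # stub standing in for the detected backend


def require_backend_async(*backends):
    # Assumed shape of the async variant: the wrapper itself is a coroutine function.
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            if SYSTEM not in backends:
                raise RuntimeError(f"skip: requires one of {backends}, got {SYSTEM}")
            return func(*args, **kwargs)

        return wrapper

    return decorator


@require_backend_async("cuda")
def sync_fixture():
    return "model-handle"


result = sync_fixture()
print(type(result))  # <class 'coroutine'>: the guard never ran, no skip was issued
result.close()  # silence the "coroutine ... was never awaited" warning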
@@ -1,10 +1,10 @@
 import pytest
 
-from testing_utils import SYSTEM, is_flaky_async, require_backend_async
+from testing_utils import SYSTEM, is_flaky_async, require_backend_async, require_backend
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda", "rocm")
+@require_backend("cuda", "rocm")
 def flash_llama_awq_handle_sharded(launcher):
     if SYSTEM == "rocm":
         # On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.
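SYSTEM, imported above, is what the decorator compares its arguments against, and it also drives runtime branches like the ROCm-specific kernel choice in this fixture. As a purely hypothetical sketch (the repository's actual detection logic is not shown in this diff), such a constant could be derived from the installed torch build:

import torch

# Hypothetical backend detection; testing_utils' real logic may differ.
if torch.version.hip is not None:
    SYSTEM = "rocm"
elif torch.cuda.is_available():
    SYSTEM = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    SYSTEM = "xpu"
else:
    SYSTEM = "cpu"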
@@ -1,12 +1,12 @@
 import pytest
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 # These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda", "xpu")
+@require_backend("cuda", "xpu")
 def flash_gemma_handle(launcher):
     with launcher("google/gemma-2b", num_shard=1) as handle:
         yield handle
@@ -1,10 +1,11 @@
 import pytest
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
+# These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda", "xpu")
+@require_backend("cuda", "xpu")
 def flash_gemma_gptq_handle(launcher):
     with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
         yield handle
@@ -1,9 +1,9 @@
 import pytest
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda")
+@require_backend("cuda")
 def flash_llama_exl2_handle(launcher):
     with launcher(
         "turboderp/Llama-3-8B-Instruct-exl2",
@@ -3,13 +3,13 @@ import requests
 import io
 import base64
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 # These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda", "xpu")
+@require_backend("cuda", "xpu")
 def flash_pali_gemma_handle(launcher):
     with launcher(
         "google/paligemma-3b-pt-224",
@@ -1,12 +1,12 @@
 import pytest
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 # These tests do not pass on ROCm, with different generations.
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda")
+@require_backend("cuda")
 def flash_phi_handle(launcher):
     with launcher("microsoft/phi-2", num_shard=1) as handle:
         yield handle
@@ -1,10 +1,10 @@
 import pytest
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda")
+@require_backend("cuda")
 def fused_kernel_mamba_handle(launcher):
     with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
         yield handle
@@ -51,6 +51,20 @@ def is_flaky_async(
 
     return decorator
 
+def require_backend(*args):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*wrapper_args, **wrapper_kwargs):
+            if SYSTEM not in args:
+                pytest.skip(
+                    f"Skipping as this test requires the backend {args} to be run, but current system is SYSTEM={SYSTEM}."
+                )
+            return func(*wrapper_args, **wrapper_kwargs)
+
+        return wrapper
+
+    return decorator
+
 
 def require_backend_async(*args):
     def decorator(func):
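The added require_backend mirrors require_backend_async but uses a plain (non-async) wrapper, so pytest.skip fires at call time for ordinary synchronous tests and fixtures. A minimal, self-contained pytest file showing the behavior (SYSTEM is stubbed to "cuda" here; in the repository it comes from testing_utils). Run it with pytest -rs to see the skip reason:

# test_require_backend_sketch.py -- standalone sketch; SYSTEM is stubbed below
import functools

import pytest

SYSTEM = "cuda"  # stub; the real value is detected by testing_utils


def require_backend(*args):
    # Same shape as the decorator added in this commit.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*wrapper_args, **wrapper_kwargs):
            if SYSTEM not in args:
                pytest.skip(
                    f"Skipping as this test requires the backend {args} to be run, "
                    f"but current system is SYSTEM={SYSTEM}."
                )
            return func(*wrapper_args, **wrapper_kwargs)

        return wrapper

    return decorator


@require_backend("rocm")
def test_rocm_only():
    assert True  # reported as SKIPPED, since SYSTEM == "cuda"


@require_backend("cuda", "rocm")
def test_cuda_or_rocm():
    assert True  # runs, since "cuda" is in the allowed set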