Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
skip exl2 tests on rocm
This commit is contained in:
parent b452620c04 · commit 73b067d193
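The diff below replaces the inline `if SYSTEM == "rocm": ...` / `pytest.skiptest(...)` guards with a `@require_backend_async(...)` decorator imported from `testing_utils`, which lists the backends a fixture or test is allowed to run on. The decorator's implementation is not part of this commit; a minimal sketch of how such a helper could work, assuming `SYSTEM` is a backend string like "cuda", "rocm", or "xpu", is:

import functools
import inspect

import pytest

SYSTEM = "cuda"  # hypothetical placeholder; the real value would come from testing_utils


def require_backend_async(*allowed_backends):
    """Skip the decorated fixture or test unless SYSTEM is one of the allowed backends."""

    def decorator(func):
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                if SYSTEM not in allowed_backends:
                    pytest.skip(f"requires one of {allowed_backends}, running on {SYSTEM}")
                return await func(*args, **kwargs)

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            if SYSTEM not in allowed_backends:
                pytest.skip(f"requires one of {allowed_backends}, running on {SYSTEM}")
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator

Because the decorator sits below `@pytest.fixture` / above the test function in the diff, the check runs when the fixture or test is invoked, so unsupported backends are reported as skips rather than failures.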
@@ -9,7 +9,7 @@ def flash_llama_awq_handle(launcher):
         # On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.
         quantize = "gptq"
     elif SYSTEM == "xpu":
-        pytest.skiptest("AWQ is not supported on xpu")
+        pytest.skip("AWQ is not supported on xpu")
     else:
         quantize = "awq"

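The one-word change in this hunk matters: `pytest.skiptest` is not part of the pytest API, so the old guard would have raised AttributeError instead of skipping, whereas `pytest.skip` raises the Skipped exception the runner reports as a skip. For reference (the test name here is illustrative):

import pytest


def test_awq_on_xpu_placeholder():
    # pytest.skip() marks the test as skipped with the given reason;
    # pytest.skiptest() does not exist, so the old call would have failed with AttributeError.
    pytest.skip("AWQ is not supported on xpu")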
@@ -4,12 +4,11 @@ from testing_utils import SYSTEM, is_flaky_async


 @pytest.fixture(scope="module")
+@require_backend_async("cuda", "rocm")
 def flash_llama_awq_handle_sharded(launcher):
     if SYSTEM == "rocm":
         # On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.
         quantize = "gptq"
-    elif SYSTEM == "xpu":
-        pytest.skiptest("AWQ is not supported on xpu")
     else:
         quantize = "awq"

@@ -22,6 +21,7 @@ def flash_llama_awq_handle_sharded(launcher):


 @pytest.fixture(scope="module")
+@require_backend_async("cuda", "rocm")
 async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
     await flash_llama_awq_handle_sharded.health(300)
     return flash_llama_awq_handle_sharded.client
@@ -29,6 +29,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):

 @is_flaky_async(max_attempts=5)
 @pytest.mark.asyncio
+@require_backend_async("cuda", "rocm")
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     response = await flash_llama_awq_sharded.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -47,14 +48,12 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
     assert response == response_snapshot


+@require_backend_async("cuda")
 @pytest.mark.asyncio
 async def test_flash_llama_awq_load_sharded(
     flash_llama_awq_sharded, generate_load, response_snapshot
 ):
-    if SYSTEM == "rocm":
-        pytest.skiptest(
-            "This test relies on ExllamaV2 on ROCm systems, which is highly non-determinstic (flaky)"
-        )
+    # TODO: This test is highly non-deterministic on ROCm.

     responses = await generate_load(
         flash_llama_awq_sharded, "What is Deep Learning?", max_new_tokens=10, n=4
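The sharded AWQ test above keeps its `@is_flaky_async(max_attempts=5)` marker, also imported from `testing_utils` and not defined in this diff. A plausible retry-style sketch, assuming the helper simply re-runs the coroutine on assertion failure up to the given number of attempts, is:

import functools


def is_flaky_async(max_attempts=5):
    """Retry an async test on AssertionError, failing only after max_attempts tries."""

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_error = None
            for _ in range(max_attempts):
                try:
                    return await func(*args, **kwargs)
                except AssertionError as error:
                    last_error = error
            raise last_error

        return wrapper

    return decorator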
@@ -1,7 +1,9 @@
 import pytest
+from testing_utils import require_backend_async


 @pytest.fixture(scope="module")
+@require_backend_async("cuda")
 def flash_llama_exl2_handle(launcher):
     with launcher(
         "turboderp/Llama-3-8B-Instruct-exl2",
@@ -16,11 +18,13 @@ def flash_llama_exl2_handle(launcher):


 @pytest.fixture(scope="module")
+@require_backend_async("cuda")
 async def flash_llama_exl2(flash_llama_exl2_handle):
     await flash_llama_exl2_handle.health(300)
     return flash_llama_exl2_handle.client


+@require_backend_async("cuda")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
@@ -32,6 +36,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
     assert response == ignore_logprob_response_snapshot


+@require_backend_async("cuda")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_all_params(
@@ -58,6 +63,7 @@ async def test_flash_llama_exl2_all_params(
     assert response == ignore_logprob_response_snapshot


+@require_backend_async("cuda")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_load(
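After this commit the exl2 tests follow one pattern: the backend is gated with `@require_backend_async("cuda")` and there are no `SYSTEM` checks inside the test body. A new test written against the same fixtures would look like the following sketch; the test name and body are illustrative (the generate call mirrors the AWQ test above), not part of the commit:

import pytest

from testing_utils import require_backend_async


@require_backend_async("cuda")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_example(flash_llama_exl2, ignore_logprob_response_snapshot):
    # Hypothetical test following the gating pattern introduced by this commit.
    response = await flash_llama_exl2.generate(
        "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
    )
    assert response == ignore_logprob_response_snapshot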