mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
fix tests
This commit is contained in:
parent
dadfff621e
commit
7c7470542d
@ -1,20 +1,26 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from testing_utils import require_backend_async
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda")
|
||||||
def bloom_560_handle(launcher):
|
def bloom_560_handle(launcher):
|
||||||
with launcher("bigscience/bloom-560m") as handle:
|
with launcher("bigscience/bloom-560m") as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def bloom_560(bloom_560_handle):
|
async def bloom_560(bloom_560_handle):
|
||||||
await bloom_560_handle.health(240)
|
await bloom_560_handle.health(240)
|
||||||
return bloom_560_handle.client
|
return bloom_560_handle.client
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def test_bloom_560m(bloom_560, response_snapshot):
|
async def test_bloom_560m(bloom_560, response_snapshot):
|
||||||
|
# The generated text is different on MI300X, and for what it is worth also different on H100.
|
||||||
response = await bloom_560.generate(
|
response = await bloom_560.generate(
|
||||||
"Pour déguster un ortolan, il faut tout d'abord",
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
@ -28,7 +34,9 @@ async def test_bloom_560m(bloom_560, response_snapshot):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def test_bloom_560m_all_params(bloom_560, response_snapshot):
|
async def test_bloom_560m_all_params(bloom_560, response_snapshot):
|
||||||
|
# The generated text is different on MI300X, and for what it is worth also different on H100.
|
||||||
response = await bloom_560.generate(
|
response = await bloom_560.generate(
|
||||||
"Pour déguster un ortolan, il faut tout d'abord",
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
@ -50,7 +58,9 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot):
|
async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot):
|
||||||
|
# The generated text is different on MI300X, and for what it is worth also different on H100.
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
bloom_560,
|
bloom_560,
|
||||||
"Pour déguster un ortolan, il faut tout d'abord",
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from testing_utils import require_backend_async
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def bloom_560m_sharded_handle(launcher):
|
def bloom_560m_sharded_handle(launcher):
|
||||||
@ -14,7 +16,9 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
|
async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
|
||||||
|
# The generated text is different on MI300X, and for what it is worth also different on H100.
|
||||||
response = await bloom_560m_sharded.generate(
|
response = await bloom_560m_sharded.generate(
|
||||||
"Pour déguster un ortolan, il faut tout d'abord",
|
"Pour déguster un ortolan, il faut tout d'abord",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
@ -1,13 +1,19 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from testing_utils import require_backend_async
|
||||||
|
|
||||||
|
# These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
def flash_gemma_handle(launcher):
|
def flash_gemma_handle(launcher):
|
||||||
with launcher("google/gemma-2b", num_shard=1) as handle:
|
with launcher("google/gemma-2b", num_shard=1) as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
async def flash_gemma(flash_gemma_handle):
|
async def flash_gemma(flash_gemma_handle):
|
||||||
await flash_gemma_handle.health(300)
|
await flash_gemma_handle.health(300)
|
||||||
return flash_gemma_handle.client
|
return flash_gemma_handle.client
|
||||||
@ -15,6 +21,7 @@ async def flash_gemma(flash_gemma_handle):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
async def test_flash_gemma(flash_gemma, response_snapshot):
|
async def test_flash_gemma(flash_gemma, response_snapshot):
|
||||||
response = await flash_gemma.generate(
|
response = await flash_gemma.generate(
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
@ -26,6 +33,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
|
async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
|
||||||
response = await flash_gemma.generate(
|
response = await flash_gemma.generate(
|
||||||
"Test request",
|
"Test request",
|
||||||
@ -49,6 +57,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
|
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
|
||||||
responses = await generate_load(flash_gemma, "Test request", max_new_tokens=10, n=4)
|
responses = await generate_load(flash_gemma, "Test request", max_new_tokens=10, n=4)
|
||||||
|
|
||||||
|
@ -1,13 +1,17 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from testing_utils import require_backend_async
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
def flash_gemma_gptq_handle(launcher):
|
def flash_gemma_gptq_handle(launcher):
|
||||||
with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
|
with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
async def flash_gemma_gptq(flash_gemma_gptq_handle):
|
async def flash_gemma_gptq(flash_gemma_gptq_handle):
|
||||||
await flash_gemma_gptq_handle.health(300)
|
await flash_gemma_gptq_handle.health(300)
|
||||||
return flash_gemma_gptq_handle.client
|
return flash_gemma_gptq_handle.client
|
||||||
@ -15,6 +19,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
|
async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
|
||||||
response = await flash_gemma_gptq.generate(
|
response = await flash_gemma_gptq.generate(
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
@ -28,6 +33,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
async def test_flash_gemma_gptq_all_params(
|
async def test_flash_gemma_gptq_all_params(
|
||||||
flash_gemma_gptq, ignore_logprob_response_snapshot
|
flash_gemma_gptq, ignore_logprob_response_snapshot
|
||||||
):
|
):
|
||||||
@ -53,6 +59,7 @@ async def test_flash_gemma_gptq_all_params(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
async def test_flash_gemma_gptq_load(
|
async def test_flash_gemma_gptq_load(
|
||||||
flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot
|
flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot
|
||||||
):
|
):
|
||||||
|
@ -3,8 +3,13 @@ import requests
|
|||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
from testing_utils import require_backend_async
|
||||||
|
|
||||||
|
# These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
def flash_pali_gemma_handle(launcher):
|
def flash_pali_gemma_handle(launcher):
|
||||||
with launcher(
|
with launcher(
|
||||||
"google/paligemma-3b-pt-224",
|
"google/paligemma-3b-pt-224",
|
||||||
@ -17,6 +22,7 @@ def flash_pali_gemma_handle(launcher):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
async def flash_pali_gemma(flash_pali_gemma_handle):
|
async def flash_pali_gemma(flash_pali_gemma_handle):
|
||||||
await flash_pali_gemma_handle.health(300)
|
await flash_pali_gemma_handle.health(300)
|
||||||
return flash_pali_gemma_handle.client
|
return flash_pali_gemma_handle.client
|
||||||
@ -30,6 +36,7 @@ def get_cow_beach():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
|
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
|
||||||
cow = get_cow_beach()
|
cow = get_cow_beach()
|
||||||
inputs = f"Where is the cow standing?\n"
|
inputs = f"Where is the cow standing?\n"
|
||||||
|
@ -1,19 +1,26 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from testing_utils import require_backend_async
|
||||||
|
|
||||||
|
# These tests do not pass on ROCm, with different generations.
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda")
|
||||||
def flash_phi_handle(launcher):
|
def flash_phi_handle(launcher):
|
||||||
with launcher("microsoft/phi-2", num_shard=1) as handle:
|
with launcher("microsoft/phi-2", num_shard=1) as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def flash_phi(flash_phi_handle):
|
async def flash_phi(flash_phi_handle):
|
||||||
await flash_phi_handle.health(300)
|
await flash_phi_handle.health(300)
|
||||||
return flash_phi_handle.client
|
return flash_phi_handle.client
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def test_flash_phi(flash_phi, response_snapshot):
|
async def test_flash_phi(flash_phi, response_snapshot):
|
||||||
response = await flash_phi.generate(
|
response = await flash_phi.generate(
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
@ -25,6 +32,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def test_flash_phi_all_params(flash_phi, response_snapshot):
|
async def test_flash_phi_all_params(flash_phi, response_snapshot):
|
||||||
response = await flash_phi.generate(
|
response = await flash_phi.generate(
|
||||||
"Test request",
|
"Test request",
|
||||||
@ -48,6 +56,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
|
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
|
||||||
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
|
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
|
||||||
|
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from testing_utils import require_backend_async
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def flash_santacoder_handle(launcher):
|
def flash_santacoder_handle(launcher):
|
||||||
@ -14,7 +16,9 @@ async def flash_santacoder(flash_santacoder_handle):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda", "xpu")
|
||||||
async def test_flash_santacoder(flash_santacoder, response_snapshot):
|
async def test_flash_santacoder(flash_santacoder, response_snapshot):
|
||||||
|
# TODO: This test does not pass on ROCm although it should. To be investigated.
|
||||||
response = await flash_santacoder.generate(
|
response = await flash_santacoder.generate(
|
||||||
"def print_hello", max_new_tokens=10, decoder_input_details=True
|
"def print_hello", max_new_tokens=10, decoder_input_details=True
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from testing_utils import SYSTEM, is_flaky_async, require_backend_async
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def flash_starcoder_gptq_handle(launcher):
|
def flash_starcoder_gptq_handle(launcher):
|
||||||
@ -14,6 +16,7 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@is_flaky_async(max_attempts=10)
|
||||||
async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
|
async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
|
||||||
response = await flash_starcoder_gptq.generate(
|
response = await flash_starcoder_gptq.generate(
|
||||||
"def geometric_mean(L: List[float]):",
|
"def geometric_mean(L: List[float]):",
|
||||||
@ -21,10 +24,17 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
|
|||||||
decoder_input_details=True,
|
decoder_input_details=True,
|
||||||
)
|
)
|
||||||
assert response.details.generated_tokens == 20
|
assert response.details.generated_tokens == 20
|
||||||
|
assert (
|
||||||
|
response.generated_text
|
||||||
|
== '\n """\n Calculate the geometric mean of a list of numbers.\n\n :param L: List'
|
||||||
|
)
|
||||||
|
|
||||||
|
if SYSTEM != "rocm":
|
||||||
assert response == generous_response_snapshot
|
assert response == generous_response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@is_flaky_async(max_attempts=10)
|
||||||
async def test_flash_starcoder_gptq_default_params(
|
async def test_flash_starcoder_gptq_default_params(
|
||||||
flash_starcoder_gptq, generous_response_snapshot
|
flash_starcoder_gptq, generous_response_snapshot
|
||||||
):
|
):
|
||||||
@ -37,13 +47,21 @@ async def test_flash_starcoder_gptq_default_params(
|
|||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
assert response.details.generated_tokens == 20
|
assert response.details.generated_tokens == 20
|
||||||
|
assert (
|
||||||
|
response.generated_text == "\n return reduce(lambda x, y: x * y, L) ** (1.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
if SYSTEM != "rocm":
|
||||||
assert response == generous_response_snapshot
|
assert response == generous_response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def test_flash_starcoder_gptq_load(
|
async def test_flash_starcoder_gptq_load(
|
||||||
flash_starcoder_gptq, generate_load, generous_response_snapshot
|
flash_starcoder_gptq, generate_load, generous_response_snapshot
|
||||||
):
|
):
|
||||||
|
# TODO: exllamav2 gptq kernel is highly non-deterministic on ROCm.
|
||||||
|
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
flash_starcoder_gptq,
|
flash_starcoder_gptq,
|
||||||
"def geometric_mean(L: List[float]):",
|
"def geometric_mean(L: List[float]):",
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
from testing_utils import SYSTEM
|
||||||
|
|
||||||
|
|
||||||
# TODO fix the server parsser to count inline image tokens correctly
|
# TODO fix the server parsser to count inline image tokens correctly
|
||||||
def get_chicken():
|
def get_chicken():
|
||||||
@ -81,4 +83,6 @@ async def test_flash_llava_next_load(
|
|||||||
assert len(generated_texts) == 4
|
assert len(generated_texts) == 4
|
||||||
assert all([r.generated_text == generated_texts[0] for r in responses])
|
assert all([r.generated_text == generated_texts[0] for r in responses])
|
||||||
|
|
||||||
|
if SYSTEM != "rocm":
|
||||||
|
# Logprobs are not strictly identical on AMD GPUs.
|
||||||
assert responses == response_snapshot
|
assert responses == response_snapshot
|
||||||
|
@ -1,19 +1,24 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from testing_utils import require_backend_async
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda")
|
||||||
def fused_kernel_mamba_handle(launcher):
|
def fused_kernel_mamba_handle(launcher):
|
||||||
with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
|
with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def fused_kernel_mamba(fused_kernel_mamba_handle):
|
async def fused_kernel_mamba(fused_kernel_mamba_handle):
|
||||||
await fused_kernel_mamba_handle.health(300)
|
await fused_kernel_mamba_handle.health(300)
|
||||||
return fused_kernel_mamba_handle.client
|
return fused_kernel_mamba_handle.client
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def test_mamba(fused_kernel_mamba, response_snapshot):
|
async def test_mamba(fused_kernel_mamba, response_snapshot):
|
||||||
response = await fused_kernel_mamba.generate(
|
response = await fused_kernel_mamba.generate(
|
||||||
"What is Deep Learning?", max_new_tokens=10
|
"What is Deep Learning?", max_new_tokens=10
|
||||||
@ -25,6 +30,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
||||||
response = await fused_kernel_mamba.generate(
|
response = await fused_kernel_mamba.generate(
|
||||||
"blue, red, yellow, ",
|
"blue, red, yellow, ",
|
||||||
@ -51,6 +57,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@require_backend_async("cuda")
|
||||||
async def test_mamba_load(
|
async def test_mamba_load(
|
||||||
fused_kernel_mamba, generate_load, generous_response_snapshot
|
fused_kernel_mamba, generate_load, generous_response_snapshot
|
||||||
):
|
):
|
||||||
|
@ -3,7 +3,8 @@ import pytest
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def mt0_base_handle(launcher):
|
def mt0_base_handle(launcher):
|
||||||
with launcher("bigscience/mt0-base") as handle:
|
# We use TP=1 as this model is loaded with AutoModel (sharding not supported).
|
||||||
|
with launcher("bigscience/mt0-base", num_shard=1) as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ class FastLinearROCm(torch.nn.Module):
|
|||||||
out = F.linear(inp, weight)
|
out = F.linear(inp, weight)
|
||||||
|
|
||||||
if batched:
|
if batched:
|
||||||
out.view(*inp_shape[:-1], out.shape[-1])
|
out = out.view(*inp_shape[:-1], out.shape[-1])
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
out = out + bias
|
out = out + bias
|
||||||
|
@ -105,11 +105,13 @@ if FLASH_ATTENTION:
|
|||||||
__all__.append(FlashCohere)
|
__all__.append(FlashCohere)
|
||||||
|
|
||||||
MAMBA_AVAILABLE = True
|
MAMBA_AVAILABLE = True
|
||||||
|
MAMBA_IMPORT_ERROR = None
|
||||||
try:
|
try:
|
||||||
from text_generation_server.models.mamba import Mamba
|
from text_generation_server.models.mamba import Mamba
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Could not import Mamba: {e}")
|
logger.warning(f"Could not import Mamba: {e}")
|
||||||
MAMBA_AVAILABLE = False
|
MAMBA_AVAILABLE = False
|
||||||
|
MAMBA_IMPORT_ERROR = e
|
||||||
|
|
||||||
if MAMBA_AVAILABLE:
|
if MAMBA_AVAILABLE:
|
||||||
__all__.append(Mamba)
|
__all__.append(Mamba)
|
||||||
@ -424,6 +426,11 @@ def get_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_type == MAMBA:
|
if model_type == MAMBA:
|
||||||
|
if not MAMBA_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
f"Mamba is not available on the current {SYSTEM} system, with the following error: {MAMBA_IMPORT_ERROR}"
|
||||||
|
)
|
||||||
|
|
||||||
return Mamba(
|
return Mamba(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch BLOOM model."""
|
"""PyTorch BLOOM model."""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
Loading…
Reference in New Issue
Block a user