Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-20 14:22:08 +00:00
Add support for GPTQ Marlin kernels

GPTQ Marlin extends the Marlin kernels to support common GPTQ configurations:

- bits: 4 or 8
- groupsize: -1, 32, 64, or 128
- desc_act: true/false

Using the GPTQ Marlin kernels requires repacking the parameters into the Marlin quantizer format. The kernels were contributed by Neural Magic to vLLM; we vendor them here for convenience.
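As a quick illustration of what those constraints mean in practice, the sketch below checks a quantization config against the supported ranges, using the field names that GPTQ checkpoints typically ship in quantize_config.json (bits, group_size, desc_act). The helper itself is illustrative and not part of this repository:

# Minimal sketch: does a GPTQ quantize config fall within the shapes the
# GPTQ Marlin kernels support? Field names follow the usual
# quantize_config.json layout; the helper name is an assumption.

GPTQ_MARLIN_BITS = {4, 8}
GPTQ_MARLIN_GROUP_SIZES = {-1, 32, 64, 128}


def is_gptq_marlin_compatible(config: dict) -> bool:
    """Return True if the config matches a supported GPTQ Marlin shape."""
    return (
        config.get("bits") in GPTQ_MARLIN_BITS
        and config.get("group_size") in GPTQ_MARLIN_GROUP_SIZES
        # Both desc_act=True and desc_act=False are supported.
        and isinstance(config.get("desc_act"), bool)
    )


# Example: a 4-bit, group-size-128 checkpoint (values illustrative) is eligible.
assert is_gptq_marlin_compatible({"bits": 4, "group_size": 128, "desc_act": True})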
66 lines · 1.8 KiB · Python
import pytest


@pytest.fixture(scope="module")
def flash_llama_gptq_marlin_handle(launcher):
    # Launch the server with a GPTQ-quantized Llama 3 8B checkpoint on two
    # shards, selecting the Marlin kernels.
    with launcher(
        "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
    # Wait up to 300 seconds for the server to report healthy, then hand the
    # client to the tests.
    await flash_llama_gptq_marlin_handle.health(300)
    return flash_llama_gptq_marlin_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
    # Plain generation: check the token count and compare against the snapshot.
    response = await flash_llama_gptq_marlin.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params(
    flash_llama_gptq_marlin, response_snapshot
):
    # Exercise the full set of sampling parameters with a fixed seed so the
    # output stays reproducible.
    response = await flash_llama_gptq_marlin.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_load(
    flash_llama_gptq_marlin, generate_load, response_snapshot
):
    # Send four identical requests concurrently and check that the batched
    # results agree with each other and with the snapshot.
    responses = await generate_load(
        flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
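Outside the test harness, the same kind of request can be sent with the text_generation Python client once a TGI server is running with the Marlin quantize option, as in the fixture above. The endpoint URL is an assumption; adjust it to wherever the server is listening:

# Hedged usage sketch: query a locally running TGI server (URL assumed to be
# http://127.0.0.1:8080) with the `text_generation` async client.
import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )
    print(response.generated_text)
    print(response.details.generated_tokens)


if __name__ == "__main__":
    asyncio.run(main())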