Enable Marlin for supported AWQ configurations by default

This makes the AWQ -> GPTQ repack test redundant, since we are now
testing this with the regular AWQ test.
This commit is contained in:
Daniël de Kok 2024-07-23 09:31:36 +00:00
parent 32794b1caa
commit 712729bc78
6 changed files with 9 additions and 576 deletions

View File

@@ -1,84 +0,0 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -12.2890625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 311,
"logprob": -2.5566406,
"special": false,
"text": " to"
},
{
"id": 279,
"logprob": -2.0117188,
"special": false,
"text": " the"
},
{
"id": 3622,
"logprob": -1.3105469,
"special": false,
"text": " server"
},
{
"id": 627,
"logprob": -2.1679688,
"special": false,
"text": ".\n"
},
{
"id": 262,
"logprob": -1.640625,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -1.1865234,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.007183075,
"special": false,
"text": " "
},
{
"id": 711,
"logprob": -1.7636719,
"special": false,
"text": " def"
},
{
"id": 1328,
"logprob": -1.0673828,
"special": false,
"text": " __"
},
{
"id": 2381,
"logprob": -0.018508911,
"special": false,
"text": "init"
}
],
"top_tokens": null
},
"generated_text": " to the server.\n \"\"\"\n def __init"
}

View File

@@ -1,84 +0,0 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -12.328125,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 311,
"logprob": -0.50927734,
"special": false,
"text": " to"
},
{
"id": 279,
"logprob": 0.0,
"special": false,
"text": " the"
},
{
"id": 3622,
"logprob": 0.0,
"special": false,
"text": " server"
},
{
"id": 627,
"logprob": -0.5107422,
"special": false,
"text": ".\n"
},
{
"id": 257,
"logprob": -1.5878906,
"special": false,
"text": " "
},
{
"id": 1235,
"logprob": -0.24499512,
"special": false,
"text": " *\n"
},
{
"id": 257,
"logprob": 0.0,
"special": false,
"text": " "
},
{
"id": 353,
"logprob": 0.0,
"special": false,
"text": " *"
},
{
"id": 571,
"logprob": 0.0,
"special": false,
"text": " @"
},
{
"id": 913,
"logprob": 0.0,
"special": false,
"text": "param"
}
],
"top_tokens": null
},
"generated_text": "Test request to the server.\n *\n * @param"
}

View File

@@ -1,338 +0,0 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -12.328125,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 311,
"logprob": -2.5585938,
"special": false,
"text": " to"
},
{
"id": 279,
"logprob": -2.0253906,
"special": false,
"text": " the"
},
{
"id": 3622,
"logprob": -1.3125,
"special": false,
"text": " server"
},
{
"id": 627,
"logprob": -2.171875,
"special": false,
"text": ".\n"
},
{
"id": 262,
"logprob": -1.6396484,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -1.1884766,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.0073013306,
"special": false,
"text": " "
},
{
"id": 711,
"logprob": -1.7568359,
"special": false,
"text": " def"
},
{
"id": 1328,
"logprob": -1.0595703,
"special": false,
"text": " __"
},
{
"id": 2381,
"logprob": -0.018676758,
"special": false,
"text": "init"
}
],
"top_tokens": null
},
"generated_text": " to the server.\n \"\"\"\n def __init"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -12.3046875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 311,
"logprob": -2.5527344,
"special": false,
"text": " to"
},
{
"id": 279,
"logprob": -2.0253906,
"special": false,
"text": " the"
},
{
"id": 3622,
"logprob": -1.3027344,
"special": false,
"text": " server"
},
{
"id": 627,
"logprob": -2.1757812,
"special": false,
"text": ".\n"
},
{
"id": 262,
"logprob": -1.6445312,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -1.1875,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.00730896,
"special": false,
"text": " "
},
{
"id": 711,
"logprob": -1.7587891,
"special": false,
"text": " def"
},
{
"id": 1328,
"logprob": -1.0605469,
"special": false,
"text": " __"
},
{
"id": 2381,
"logprob": -0.01890564,
"special": false,
"text": "init"
}
],
"top_tokens": null
},
"generated_text": " to the server.\n \"\"\"\n def __init"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -12.3125,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 311,
"logprob": -2.5585938,
"special": false,
"text": " to"
},
{
"id": 279,
"logprob": -2.0292969,
"special": false,
"text": " the"
},
{
"id": 3622,
"logprob": -1.3095703,
"special": false,
"text": " server"
},
{
"id": 627,
"logprob": -2.1816406,
"special": false,
"text": ".\n"
},
{
"id": 262,
"logprob": -1.6396484,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -1.1875,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.007194519,
"special": false,
"text": " "
},
{
"id": 711,
"logprob": -1.765625,
"special": false,
"text": " def"
},
{
"id": 1328,
"logprob": -1.0537109,
"special": false,
"text": " __"
},
{
"id": 2381,
"logprob": -0.018432617,
"special": false,
"text": "init"
}
],
"top_tokens": null
},
"generated_text": " to the server.\n \"\"\"\n def __init"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -12.296875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 311,
"logprob": -2.5585938,
"special": false,
"text": " to"
},
{
"id": 279,
"logprob": -2.0136719,
"special": false,
"text": " the"
},
{
"id": 3622,
"logprob": -1.3164062,
"special": false,
"text": " server"
},
{
"id": 627,
"logprob": -2.1601562,
"special": false,
"text": ".\n"
},
{
"id": 262,
"logprob": -1.6455078,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -1.1865234,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.0071907043,
"special": false,
"text": " "
},
{
"id": 711,
"logprob": -1.7568359,
"special": false,
"text": " def"
},
{
"id": 1328,
"logprob": -1.0605469,
"special": false,
"text": " __"
},
{
"id": 2381,
"logprob": -0.018585205,
"special": false,
"text": "init"
}
],
"top_tokens": null
},
"generated_text": " to the server.\n \"\"\"\n def __init"
}
]

View File

@@ -1,68 +0,0 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_awq_gptq_handle(launcher):
    """Launch a two-shard server for an AWQ checkpoint forced through the GPTQ path.

    `launcher` is a project-provided fixture; its context manager shuts the
    server down again once the module's tests have finished.
    """
    launch_kwargs = dict(num_shard=2, quantize="gptq")
    with launcher("casperhansen/llama-3-8b-instruct-awq", **launch_kwargs) as server:
        yield server
@pytest.fixture(scope="module")
async def flash_llama_awq_gptq(flash_llama_awq_gptq_handle):
    """Wait until the launched server reports healthy, then hand back its client."""
    handle = flash_llama_awq_gptq_handle
    # Allow up to 300 seconds for the model to load before any test runs.
    await handle.health(300)
    return handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_gptq(flash_llama_awq_gptq, response_snapshot):
    """Greedy generation of 10 tokens must match the stored snapshot."""
    result = await flash_llama_awq_gptq.generate(
        "Test request",
        max_new_tokens=10,
        decoder_input_details=True,
    )
    # Exactly the requested number of tokens should have been produced.
    assert result.details.generated_tokens == 10
    assert result == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_gptq_all_params(flash_llama_awq_gptq, response_snapshot):
    """Sampled generation with every decoding knob set must match the snapshot."""
    sampling_kwargs = dict(
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        # A fixed seed keeps the sampled output reproducible for snapshotting.
        seed=0,
    )
    result = await flash_llama_awq_gptq.generate("Test request", **sampling_kwargs)
    assert result.details.generated_tokens == 10
    assert result == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_gptq_load(
    flash_llama_awq_gptq, generate_load, response_snapshot
):
    """Four concurrent identical requests must agree with each other and the snapshot."""
    responses = await generate_load(
        flash_llama_awq_gptq, "Test request", max_new_tokens=10, n=4
    )
    assert len(responses) == 4
    # Every response came from the same request, so all texts must be identical.
    reference_text = responses[0].generated_text
    assert all(r.generated_text == reference_text for r in responses)
    assert responses == response_snapshot

View File

@@ -175,7 +175,7 @@ def can_use_gptq_marlin(
SYSTEM == "cuda"
and marlin_kernels is not None
and has_sm_8_0
and quantize == "gptq"
and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"}
and bits in GPTQ_MARLIN_BITS
and groupsize in GPTQ_MARLIN_GROUP_SIZES

View File

@@ -54,6 +54,7 @@ def _get_quantizer_config(model_id, revision):
if "zero_point" in data["quantization_config"]:
sym = not data["quantization_config"]["zero_point"]
quant_method = "awq"
elif "sym" in data["quantization_config"]:
sym = data["quantization_config"]["sym"]
@@ -76,7 +77,13 @@ def _get_quantizer_config(model_id, revision):
data = json.load(f)
bits = data["bits"]
groupsize = data["group_size"]
if "zero_point" in data:
sym = not data["zero_point"]
quant_method = "awq"
elif "sym" in data:
sym = data["sym"]
desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
quant_method = "awq"