mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Merge branch 'main' into prefer-chat-object-enum
This commit is contained in:
commit
153c8ae60f
@ -1,84 +0,0 @@
|
|||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 198,
|
|
||||||
"logprob": -2.5742188,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -1.6230469,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3270,
|
|
||||||
"logprob": -2.046875,
|
|
||||||
"special": false,
|
|
||||||
"text": " \"\"\"\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -0.015281677,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 422,
|
|
||||||
"logprob": -2.1425781,
|
|
||||||
"special": false,
|
|
||||||
"text": " if"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -0.9238281,
|
|
||||||
"special": false,
|
|
||||||
"text": " request"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13204,
|
|
||||||
"logprob": -0.076660156,
|
|
||||||
"special": false,
|
|
||||||
"text": ".method"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 624,
|
|
||||||
"logprob": -0.021987915,
|
|
||||||
"special": false,
|
|
||||||
"text": " =="
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 364,
|
|
||||||
"logprob": -0.39208984,
|
|
||||||
"special": false,
|
|
||||||
"text": " '"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3019,
|
|
||||||
"logprob": -0.10821533,
|
|
||||||
"special": false,
|
|
||||||
"text": "POST"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
|
||||||
}
|
|
@ -1,84 +0,0 @@
|
|||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": 0,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 13,
|
|
||||||
"logprob": -2.2539062,
|
|
||||||
"special": false,
|
|
||||||
"text": "."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 578,
|
|
||||||
"logprob": -0.15563965,
|
|
||||||
"special": false,
|
|
||||||
"text": " The"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3622,
|
|
||||||
"logprob": -0.8203125,
|
|
||||||
"special": false,
|
|
||||||
"text": " server"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 706,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " has"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 539,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " not"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3686,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " yet"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3288,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " sent"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 904,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " any"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 828,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 382,
|
|
||||||
"logprob": -1.5517578,
|
|
||||||
"special": false,
|
|
||||||
"text": ".\n\n"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "Test request. The server has not yet sent any data.\n\n"
|
|
||||||
}
|
|
@ -1,338 +0,0 @@
|
|||||||
[
|
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 198,
|
|
||||||
"logprob": -2.5742188,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -1.6220703,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3270,
|
|
||||||
"logprob": -2.0410156,
|
|
||||||
"special": false,
|
|
||||||
"text": " \"\"\"\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -0.015281677,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 422,
|
|
||||||
"logprob": -2.1445312,
|
|
||||||
"special": false,
|
|
||||||
"text": " if"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -0.92333984,
|
|
||||||
"special": false,
|
|
||||||
"text": " request"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13204,
|
|
||||||
"logprob": -0.07672119,
|
|
||||||
"special": false,
|
|
||||||
"text": ".method"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 624,
|
|
||||||
"logprob": -0.021987915,
|
|
||||||
"special": false,
|
|
||||||
"text": " =="
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 364,
|
|
||||||
"logprob": -0.39208984,
|
|
||||||
"special": false,
|
|
||||||
"text": " '"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3019,
|
|
||||||
"logprob": -0.10638428,
|
|
||||||
"special": false,
|
|
||||||
"text": "POST"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 198,
|
|
||||||
"logprob": -2.5742188,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -1.6220703,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3270,
|
|
||||||
"logprob": -2.0410156,
|
|
||||||
"special": false,
|
|
||||||
"text": " \"\"\"\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -0.015281677,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 422,
|
|
||||||
"logprob": -2.1445312,
|
|
||||||
"special": false,
|
|
||||||
"text": " if"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -0.92333984,
|
|
||||||
"special": false,
|
|
||||||
"text": " request"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13204,
|
|
||||||
"logprob": -0.07672119,
|
|
||||||
"special": false,
|
|
||||||
"text": ".method"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 624,
|
|
||||||
"logprob": -0.021987915,
|
|
||||||
"special": false,
|
|
||||||
"text": " =="
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 364,
|
|
||||||
"logprob": -0.39208984,
|
|
||||||
"special": false,
|
|
||||||
"text": " '"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3019,
|
|
||||||
"logprob": -0.10638428,
|
|
||||||
"special": false,
|
|
||||||
"text": "POST"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 198,
|
|
||||||
"logprob": -2.5742188,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -1.6220703,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3270,
|
|
||||||
"logprob": -2.0410156,
|
|
||||||
"special": false,
|
|
||||||
"text": " \"\"\"\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -0.015281677,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 422,
|
|
||||||
"logprob": -2.1445312,
|
|
||||||
"special": false,
|
|
||||||
"text": " if"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -0.92333984,
|
|
||||||
"special": false,
|
|
||||||
"text": " request"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13204,
|
|
||||||
"logprob": -0.07672119,
|
|
||||||
"special": false,
|
|
||||||
"text": ".method"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 624,
|
|
||||||
"logprob": -0.021987915,
|
|
||||||
"special": false,
|
|
||||||
"text": " =="
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 364,
|
|
||||||
"logprob": -0.39208984,
|
|
||||||
"special": false,
|
|
||||||
"text": " '"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3019,
|
|
||||||
"logprob": -0.10638428,
|
|
||||||
"special": false,
|
|
||||||
"text": "POST"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 198,
|
|
||||||
"logprob": -2.5742188,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -1.6220703,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3270,
|
|
||||||
"logprob": -2.0410156,
|
|
||||||
"special": false,
|
|
||||||
"text": " \"\"\"\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -0.015281677,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 422,
|
|
||||||
"logprob": -2.1445312,
|
|
||||||
"special": false,
|
|
||||||
"text": " if"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -0.92333984,
|
|
||||||
"special": false,
|
|
||||||
"text": " request"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13204,
|
|
||||||
"logprob": -0.07672119,
|
|
||||||
"special": false,
|
|
||||||
"text": ".method"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 624,
|
|
||||||
"logprob": -0.021987915,
|
|
||||||
"special": false,
|
|
||||||
"text": " =="
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 364,
|
|
||||||
"logprob": -0.39208984,
|
|
||||||
"special": false,
|
|
||||||
"text": " '"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3019,
|
|
||||||
"logprob": -0.10638428,
|
|
||||||
"special": false,
|
|
||||||
"text": "POST"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
|
||||||
}
|
|
||||||
]
|
|
@ -1,68 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def flash_llama_gptq_marlin_handle(launcher):
|
|
||||||
with launcher(
|
|
||||||
"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
|
|
||||||
) as handle:
|
|
||||||
yield handle
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
|
|
||||||
await flash_llama_gptq_marlin_handle.health(300)
|
|
||||||
return flash_llama_gptq_marlin_handle.client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.release
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
|
|
||||||
response = await flash_llama_gptq_marlin.generate(
|
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
|
||||||
assert response == response_snapshot
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.release
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_gptq_marlin_all_params(
|
|
||||||
flash_llama_gptq_marlin, response_snapshot
|
|
||||||
):
|
|
||||||
response = await flash_llama_gptq_marlin.generate(
|
|
||||||
"Test request",
|
|
||||||
max_new_tokens=10,
|
|
||||||
repetition_penalty=1.2,
|
|
||||||
return_full_text=True,
|
|
||||||
temperature=0.5,
|
|
||||||
top_p=0.9,
|
|
||||||
top_k=10,
|
|
||||||
truncate=5,
|
|
||||||
typical_p=0.9,
|
|
||||||
watermark=True,
|
|
||||||
decoder_input_details=True,
|
|
||||||
seed=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
|
||||||
assert response == response_snapshot
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.release
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_gptq_marlin_load(
|
|
||||||
flash_llama_gptq_marlin, generate_load, response_snapshot
|
|
||||||
):
|
|
||||||
responses = await generate_load(
|
|
||||||
flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(responses) == 4
|
|
||||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
|
||||||
|
|
||||||
assert responses == response_snapshot
|
|
@ -898,13 +898,20 @@ enum LauncherError {
|
|||||||
WebserverCannotStart,
|
WebserverCannotStart,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
|
fn download_convert_model(
|
||||||
|
model_id: &str,
|
||||||
|
revision: Option<&str>,
|
||||||
|
trust_remote_code: bool,
|
||||||
|
huggingface_hub_cache: Option<&str>,
|
||||||
|
weights_cache_override: Option<&str>,
|
||||||
|
running: Arc<AtomicBool>,
|
||||||
|
) -> Result<(), LauncherError> {
|
||||||
// Enter download tracing span
|
// Enter download tracing span
|
||||||
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
||||||
|
|
||||||
let mut download_args = vec![
|
let mut download_args = vec![
|
||||||
"download-weights".to_string(),
|
"download-weights".to_string(),
|
||||||
args.model_id.to_string(),
|
model_id.to_string(),
|
||||||
"--extension".to_string(),
|
"--extension".to_string(),
|
||||||
".safetensors".to_string(),
|
".safetensors".to_string(),
|
||||||
"--logger-level".to_string(),
|
"--logger-level".to_string(),
|
||||||
@ -913,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
|||||||
];
|
];
|
||||||
|
|
||||||
// Model optional revision
|
// Model optional revision
|
||||||
if let Some(revision) = &args.revision {
|
if let Some(revision) = &revision {
|
||||||
download_args.push("--revision".to_string());
|
download_args.push("--revision".to_string());
|
||||||
download_args.push(revision.to_string())
|
download_args.push(revision.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trust remote code for automatic peft fusion
|
// Trust remote code for automatic peft fusion
|
||||||
if args.trust_remote_code {
|
if trust_remote_code {
|
||||||
download_args.push("--trust-remote-code".to_string());
|
download_args.push("--trust-remote-code".to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -934,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
|||||||
|
|
||||||
// If huggingface_hub_cache is set, pass it to the download process
|
// If huggingface_hub_cache is set, pass it to the download process
|
||||||
// Useful when running inside a docker container
|
// Useful when running inside a docker container
|
||||||
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
|
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
|
||||||
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -952,7 +959,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
|||||||
|
|
||||||
// If args.weights_cache_override is some, pass it to the download process
|
// If args.weights_cache_override is some, pass it to the download process
|
||||||
// Useful when running inside a HuggingFace Inference Endpoint
|
// Useful when running inside a HuggingFace Inference Endpoint
|
||||||
if let Some(weights_cache_override) = &args.weights_cache_override {
|
if let Some(weights_cache_override) = &weights_cache_override {
|
||||||
envs.push((
|
envs.push((
|
||||||
"WEIGHTS_CACHE_OVERRIDE".into(),
|
"WEIGHTS_CACHE_OVERRIDE".into(),
|
||||||
weights_cache_override.into(),
|
weights_cache_override.into(),
|
||||||
@ -960,7 +967,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Start process
|
// Start process
|
||||||
tracing::info!("Starting download process.");
|
tracing::info!("Starting check and download process for {model_id}");
|
||||||
let mut download_process = match Command::new("text-generation-server")
|
let mut download_process = match Command::new("text-generation-server")
|
||||||
.args(download_args)
|
.args(download_args)
|
||||||
.env_clear()
|
.env_clear()
|
||||||
@ -1002,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
|||||||
loop {
|
loop {
|
||||||
if let Some(status) = download_process.try_wait().unwrap() {
|
if let Some(status) = download_process.try_wait().unwrap() {
|
||||||
if status.success() {
|
if status.success() {
|
||||||
tracing::info!("Successfully downloaded weights.");
|
tracing::info!("Successfully downloaded weights for {model_id}");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1557,7 +1564,28 @@ fn main() -> Result<(), LauncherError> {
|
|||||||
.expect("Error setting Ctrl-C handler");
|
.expect("Error setting Ctrl-C handler");
|
||||||
|
|
||||||
// Download and convert model weights
|
// Download and convert model weights
|
||||||
download_convert_model(&args, running.clone())?;
|
download_convert_model(
|
||||||
|
&args.model_id,
|
||||||
|
args.revision.as_deref(),
|
||||||
|
args.trust_remote_code,
|
||||||
|
args.huggingface_hub_cache.as_deref(),
|
||||||
|
args.weights_cache_override.as_deref(),
|
||||||
|
running.clone(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Download and convert lora adapters if any
|
||||||
|
if let Some(lora_adapters) = &args.lora_adapters {
|
||||||
|
for adapter in lora_adapters.split(',') {
|
||||||
|
download_convert_model(
|
||||||
|
adapter,
|
||||||
|
None,
|
||||||
|
args.trust_remote_code,
|
||||||
|
args.huggingface_hub_cache.as_deref(),
|
||||||
|
args.weights_cache_override.as_deref(),
|
||||||
|
running.clone(),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !running.load(Ordering::SeqCst) {
|
if !running.load(Ordering::SeqCst) {
|
||||||
// Launcher was asked to stop
|
// Launcher was asked to stop
|
||||||
|
@ -309,7 +309,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
let mut tokenizer = Tokenizer::from_file(filename).ok();
|
let mut tokenizer = Tokenizer::from_file(filename).ok();
|
||||||
if let Some(tokenizer) = &mut tokenizer {
|
if let Some(tokenizer) = &mut tokenizer {
|
||||||
if let Some(class) = &tokenizer_config.tokenizer_class {
|
if let Some(class) = &tokenizer_config.tokenizer_class {
|
||||||
if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast") && tokenizer.get_post_processor().is_none() {
|
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
|
||||||
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
|
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
|
||||||
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
|
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
|
||||||
tokenizer.with_post_processor(post_processor);
|
tokenizer.with_post_processor(post_processor);
|
||||||
@ -577,7 +577,7 @@ pub fn create_post_processor(
|
|||||||
|
|
||||||
if add_bos_token {
|
if add_bos_token {
|
||||||
if let Some(bos) = bos_token {
|
if let Some(bos) = bos_token {
|
||||||
single.push(format!("{}:1", bos.as_str()));
|
pair.push(format!("{}:1", bos.as_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,6 +7,16 @@ from text_generation_server.utils.import_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GPTQParams:
|
||||||
|
bits: int
|
||||||
|
checkpoint_format: Optional[str]
|
||||||
|
groupsize: int
|
||||||
|
desc_act: bool
|
||||||
|
quant_method: str
|
||||||
|
sym: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GPTQWeight:
|
class GPTQWeight:
|
||||||
qweight: torch.Tensor
|
qweight: torch.Tensor
|
||||||
|
@ -166,12 +166,17 @@ def get_linear(weight, bias, quantize):
|
|||||||
|
|
||||||
elif quantize == "gptq":
|
elif quantize == "gptq":
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
|
from text_generation_server.layers.marlin import (
|
||||||
if not isinstance(weight, GPTQWeight):
|
GPTQMarlinLinear,
|
||||||
raise NotImplementedError(
|
GPTQMarlinWeight,
|
||||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(weight, GPTQMarlinWeight):
|
||||||
|
linear = GPTQMarlinLinear(
|
||||||
|
weight=weight,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
elif isinstance(weight, GPTQWeight):
|
||||||
if weight.use_exllama:
|
if weight.use_exllama:
|
||||||
try:
|
try:
|
||||||
from text_generation_server.layers.gptq import (
|
from text_generation_server.layers.gptq import (
|
||||||
@ -195,6 +200,11 @@ def get_linear(weight, bias, quantize):
|
|||||||
weight.bits,
|
weight.bits,
|
||||||
weight.groupsize,
|
weight.groupsize,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||||
|
)
|
||||||
|
|
||||||
elif quantize == "awq":
|
elif quantize == "awq":
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
|
|
||||||
@ -226,18 +236,11 @@ def get_linear(weight, bias, quantize):
|
|||||||
from text_generation_server.layers.marlin import (
|
from text_generation_server.layers.marlin import (
|
||||||
GPTQMarlin24Linear,
|
GPTQMarlin24Linear,
|
||||||
GPTQMarlin24Weight,
|
GPTQMarlin24Weight,
|
||||||
GPTQMarlinLinear,
|
|
||||||
GPTQMarlinWeight,
|
|
||||||
MarlinLinear,
|
MarlinLinear,
|
||||||
MarlinWeight,
|
MarlinWeight,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(weight, GPTQMarlinWeight):
|
if isinstance(weight, GPTQMarlin24Weight):
|
||||||
linear = GPTQMarlinLinear(
|
|
||||||
weight=weight,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
elif isinstance(weight, GPTQMarlin24Weight):
|
|
||||||
linear = GPTQMarlin24Linear(
|
linear = GPTQMarlin24Linear(
|
||||||
weight=weight,
|
weight=weight,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
|
@ -3,6 +3,8 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from text_generation_server.layers.gptq import GPTQParams
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -22,6 +24,19 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
|
|||||||
MARLIN_TILE_SIZE = 16
|
MARLIN_TILE_SIZE = 16
|
||||||
|
|
||||||
|
|
||||||
|
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
|
||||||
|
return (
|
||||||
|
SYSTEM == "cuda"
|
||||||
|
and marlin_kernels is not None
|
||||||
|
and has_sm_8_0
|
||||||
|
and quantize == "gptq"
|
||||||
|
and gptq_params.quant_method == "gptq"
|
||||||
|
and gptq_params.bits in GPTQ_MARLIN_BITS
|
||||||
|
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
|
||||||
|
and gptq_params.sym
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_marlin_kernels():
|
def _check_marlin_kernels():
|
||||||
if not (SYSTEM == "cuda" and has_sm_8_0):
|
if not (SYSTEM == "cuda" and has_sm_8_0):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
@ -309,7 +309,9 @@ class LlamaMLP(nn.Module):
|
|||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
_custom_C.LLMM_Silu(
|
||||||
|
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
|
||||||
|
)
|
||||||
return self.down_proj(out, adapter_data)
|
return self.down_proj(out, adapter_data)
|
||||||
else:
|
else:
|
||||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
|
@ -1,25 +1,15 @@
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Union
|
||||||
from safetensors import safe_open, SafetensorError
|
from safetensors import safe_open, SafetensorError
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
import json
|
import json
|
||||||
|
from text_generation_server.layers.gptq import GPTQParams
|
||||||
from text_generation_server.utils.log import log_once
|
from text_generation_server.utils.log import log_once
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class _GPTQParams:
|
|
||||||
bits: int
|
|
||||||
checkpoint_format: Optional[str]
|
|
||||||
groupsize: int
|
|
||||||
desc_act: bool
|
|
||||||
quant_method: str
|
|
||||||
sym: bool
|
|
||||||
|
|
||||||
|
|
||||||
class Weights:
|
class Weights:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -212,6 +202,10 @@ class Weights:
|
|||||||
"""
|
"""
|
||||||
if quantize in ["gptq", "awq"]:
|
if quantize in ["gptq", "awq"]:
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
|
from text_generation_server.layers.marlin import (
|
||||||
|
can_use_gptq_marlin,
|
||||||
|
repack_gptq_for_marlin,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = self.get_packed_sharded(
|
qweight = self.get_packed_sharded(
|
||||||
@ -221,17 +215,28 @@ class Weights:
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
||||||
)
|
)
|
||||||
|
|
||||||
gptq_params = self._get_gptq_params()
|
|
||||||
|
|
||||||
qzeros = self.get_packed_sharded(
|
|
||||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
|
||||||
)
|
|
||||||
scales = self.get_packed_sharded(
|
scales = self.get_packed_sharded(
|
||||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||||
)
|
)
|
||||||
scales = scales.to(dtype=self.dtype)
|
scales = scales.to(dtype=self.dtype)
|
||||||
|
|
||||||
|
gptq_params = self._get_gptq_params()
|
||||||
|
if can_use_gptq_marlin(gptq_params, quantize):
|
||||||
|
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=gptq_params.bits,
|
||||||
|
desc_act=gptq_params.desc_act,
|
||||||
|
groupsize=gptq_params.groupsize,
|
||||||
|
sym=gptq_params.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
qzeros = self.get_packed_sharded(
|
||||||
|
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
if quantize == "gptq" and gptq_params.quant_method == "gptq":
|
if quantize == "gptq" and gptq_params.quant_method == "gptq":
|
||||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||||
elif quantize == "gptq" and gptq_params.quant_method == "awq":
|
elif quantize == "gptq" and gptq_params.quant_method == "awq":
|
||||||
@ -269,7 +274,6 @@ class Weights:
|
|||||||
repack_gptq_for_marlin,
|
repack_gptq_for_marlin,
|
||||||
)
|
)
|
||||||
|
|
||||||
quant_method = getattr(self, "quant_method", "marlin")
|
|
||||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||||
if is_marlin_24:
|
if is_marlin_24:
|
||||||
B = self.get_packed_sharded(
|
B = self.get_packed_sharded(
|
||||||
@ -286,31 +290,6 @@ class Weights:
|
|||||||
weight = GPTQMarlin24Weight(
|
weight = GPTQMarlin24Weight(
|
||||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||||
)
|
)
|
||||||
elif quant_method == "gptq":
|
|
||||||
gptq_params = self._get_gptq_params()
|
|
||||||
try:
|
|
||||||
qweight = self.get_packed_sharded(
|
|
||||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
|
||||||
)
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
|
||||||
)
|
|
||||||
|
|
||||||
scales = self.get_packed_sharded(
|
|
||||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
|
||||||
)
|
|
||||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
|
||||||
weight = repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=gptq_params.bits,
|
|
||||||
desc_act=gptq_params.desc_act,
|
|
||||||
groupsize=gptq_params.groupsize,
|
|
||||||
sym=gptq_params.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
B = self.get_packed_sharded(
|
B = self.get_packed_sharded(
|
||||||
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||||
@ -356,6 +335,10 @@ class Weights:
|
|||||||
raise ValueError("get_multi_weights_col is not supported for exl2")
|
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||||
elif quantize in ["gptq", "awq"]:
|
elif quantize in ["gptq", "awq"]:
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
|
from text_generation_server.layers.marlin import (
|
||||||
|
can_use_gptq_marlin,
|
||||||
|
repack_gptq_for_marlin,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = torch.cat(
|
qweight = torch.cat(
|
||||||
@ -366,14 +349,31 @@ class Weights:
|
|||||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||||
)
|
)
|
||||||
|
|
||||||
qzeros = torch.cat(
|
|
||||||
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
|
||||||
)
|
|
||||||
scales = torch.cat(
|
scales = torch.cat(
|
||||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
gptq_params = self._get_gptq_params()
|
gptq_params = self._get_gptq_params()
|
||||||
|
if can_use_gptq_marlin(gptq_params, quantize):
|
||||||
|
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||||
|
for w2 in w[1:]:
|
||||||
|
torch.testing.assert_close(w2, w[0])
|
||||||
|
g_idx = w[0]
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=gptq_params.bits,
|
||||||
|
desc_act=gptq_params.desc_act,
|
||||||
|
groupsize=gptq_params.groupsize,
|
||||||
|
sym=gptq_params.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
qzeros = torch.cat(
|
||||||
|
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
||||||
|
|
||||||
@ -425,10 +425,8 @@ class Weights:
|
|||||||
from text_generation_server.layers.marlin import (
|
from text_generation_server.layers.marlin import (
|
||||||
GPTQMarlin24Weight,
|
GPTQMarlin24Weight,
|
||||||
MarlinWeight,
|
MarlinWeight,
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
quant_method = getattr(self, "quant_method", "marlin")
|
|
||||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||||
if is_marlin_24:
|
if is_marlin_24:
|
||||||
try:
|
try:
|
||||||
@ -452,36 +450,6 @@ class Weights:
|
|||||||
weight = GPTQMarlin24Weight(
|
weight = GPTQMarlin24Weight(
|
||||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||||
)
|
)
|
||||||
elif quant_method == "gptq":
|
|
||||||
gptq_params = self._get_gptq_params()
|
|
||||||
try:
|
|
||||||
qweight = torch.cat(
|
|
||||||
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
|
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
|
||||||
)
|
|
||||||
|
|
||||||
scales = torch.cat(
|
|
||||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
|
||||||
)
|
|
||||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
|
||||||
for w2 in w[1:]:
|
|
||||||
torch.testing.assert_close(w2, w[0])
|
|
||||||
g_idx = w[0]
|
|
||||||
|
|
||||||
weight = repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=gptq_params.bits,
|
|
||||||
desc_act=gptq_params.desc_act,
|
|
||||||
groupsize=gptq_params.groupsize,
|
|
||||||
sym=gptq_params.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
B = torch.cat(
|
B = torch.cat(
|
||||||
@ -544,9 +512,41 @@ class Weights:
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif quantize == "gptq":
|
elif quantize == "gptq":
|
||||||
use_exllama = True
|
from text_generation_server.layers.marlin import (
|
||||||
gptq_params = self._get_gptq_params()
|
can_use_gptq_marlin,
|
||||||
|
repack_gptq_for_marlin,
|
||||||
|
)
|
||||||
|
|
||||||
|
gptq_params = self._get_gptq_params()
|
||||||
|
if can_use_gptq_marlin(gptq_params, quantize):
|
||||||
|
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||||
|
try:
|
||||||
|
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
|
if gptq_params.desc_act or gptq_params.groupsize == -1:
|
||||||
|
scales = self.get_tensor(f"{prefix}.scales")
|
||||||
|
else:
|
||||||
|
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
|
|
||||||
|
sharded_in_features = self.process_group.size() > 1
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=gptq_params.bits,
|
||||||
|
desc_act=gptq_params.desc_act,
|
||||||
|
groupsize=gptq_params.groupsize,
|
||||||
|
sym=gptq_params.sym,
|
||||||
|
sharded_infeatures=sharded_in_features,
|
||||||
|
)
|
||||||
|
|
||||||
|
use_exllama = True
|
||||||
if gptq_params.bits != 4:
|
if gptq_params.bits != 4:
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
|
||||||
@ -672,10 +672,8 @@ class Weights:
|
|||||||
from text_generation_server.layers.marlin import (
|
from text_generation_server.layers.marlin import (
|
||||||
GPTQMarlin24Weight,
|
GPTQMarlin24Weight,
|
||||||
MarlinWeight,
|
MarlinWeight,
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
quant_method = getattr(self, "quant_method", "marlin")
|
|
||||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||||
if is_marlin_24:
|
if is_marlin_24:
|
||||||
try:
|
try:
|
||||||
@ -698,35 +696,6 @@ class Weights:
|
|||||||
weight = GPTQMarlin24Weight(
|
weight = GPTQMarlin24Weight(
|
||||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||||
)
|
)
|
||||||
elif quant_method == "gptq":
|
|
||||||
log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
|
|
||||||
gptq_params = self._get_gptq_params()
|
|
||||||
|
|
||||||
try:
|
|
||||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
|
||||||
)
|
|
||||||
|
|
||||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
|
||||||
if gptq_params.desc_act or gptq_params.groupsize == -1:
|
|
||||||
scales = self.get_tensor(f"{prefix}.scales")
|
|
||||||
else:
|
|
||||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
|
||||||
|
|
||||||
sharded_in_features = self.process_group.size() > 1
|
|
||||||
|
|
||||||
weight = repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=gptq_params.bits,
|
|
||||||
desc_act=gptq_params.desc_act,
|
|
||||||
groupsize=gptq_params.groupsize,
|
|
||||||
sym=gptq_params.sym,
|
|
||||||
sharded_infeatures=sharded_in_features,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
B = self.get_sharded(f"{prefix}.B", dim=0)
|
B = self.get_sharded(f"{prefix}.B", dim=0)
|
||||||
@ -743,18 +712,17 @@ class Weights:
|
|||||||
else:
|
else:
|
||||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||||
weight = MarlinWeight(B=B, s=s)
|
weight = MarlinWeight(B=B, s=s)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
def _get_gptq_params(self) -> _GPTQParams:
|
def _get_gptq_params(self) -> GPTQParams:
|
||||||
try:
|
try:
|
||||||
bits = self.get_tensor("gptq_bits").item()
|
bits = self.get_tensor("gptq_bits").item()
|
||||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||||
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
||||||
desc_act = False
|
desc_act = False
|
||||||
sym = True
|
sym = False
|
||||||
quant_method = "gptq"
|
quant_method = "gptq"
|
||||||
except (SafetensorError, RuntimeError) as e:
|
except (SafetensorError, RuntimeError) as e:
|
||||||
try:
|
try:
|
||||||
@ -767,7 +735,7 @@ class Weights:
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return _GPTQParams(
|
return GPTQParams(
|
||||||
bits=bits,
|
bits=bits,
|
||||||
checkpoint_format=checkpoint_format,
|
checkpoint_format=checkpoint_format,
|
||||||
desc_act=desc_act,
|
desc_act=desc_act,
|
||||||
|
Loading…
Reference in New Issue
Block a user