Merge branch 'main' into prefer-chat-object-enum

Nicolas Patry 2024-07-01 14:10:45 +02:00 committed by GitHub
commit 153c8ae60f
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
11 changed files with 186 additions and 734 deletions

View File

@@ -1,84 +0,0 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6230469,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.046875,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1425781,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.9238281,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.076660156,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10821533,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}

View File

@@ -1,84 +0,0 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -2.2539062,
"special": false,
"text": "."
},
{
"id": 578,
"logprob": -0.15563965,
"special": false,
"text": " The"
},
{
"id": 3622,
"logprob": -0.8203125,
"special": false,
"text": " server"
},
{
"id": 706,
"logprob": 0.0,
"special": false,
"text": " has"
},
{
"id": 539,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 3686,
"logprob": 0.0,
"special": false,
"text": " yet"
},
{
"id": 3288,
"logprob": 0.0,
"special": false,
"text": " sent"
},
{
"id": 904,
"logprob": 0.0,
"special": false,
"text": " any"
},
{
"id": 828,
"logprob": 0.0,
"special": false,
"text": " data"
},
{
"id": 382,
"logprob": -1.5517578,
"special": false,
"text": ".\n\n"
}
],
"top_tokens": null
},
"generated_text": "Test request. The server has not yet sent any data.\n\n"
}

View File

@@ -1,338 +0,0 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}
]

View File

@@ -1,68 +0,0 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_gptq_marlin_handle(launcher):
with launcher(
"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
await flash_llama_gptq_marlin_handle.health(300)
return flash_llama_gptq_marlin_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
response = await flash_llama_gptq_marlin.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params(
flash_llama_gptq_marlin, response_snapshot
):
response = await flash_llama_gptq_marlin.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_load(
flash_llama_gptq_marlin, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@@ -898,13 +898,20 @@ enum LauncherError {
WebserverCannotStart,
}
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
fn download_convert_model(
model_id: &str,
revision: Option<&str>,
trust_remote_code: bool,
huggingface_hub_cache: Option<&str>,
weights_cache_override: Option<&str>,
running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
// Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
let mut download_args = vec![
"download-weights".to_string(),
args.model_id.to_string(),
model_id.to_string(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(),
@@ -913,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
];
// Model optional revision
if let Some(revision) = &args.revision {
if let Some(revision) = &revision {
download_args.push("--revision".to_string());
download_args.push(revision.to_string())
}
// Trust remote code for automatic peft fusion
if args.trust_remote_code {
if trust_remote_code {
download_args.push("--trust-remote-code".to_string());
}
@@ -934,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
// If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
@@ -952,7 +959,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
// If args.weights_cache_override is some, pass it to the download process
// Useful when running inside a HuggingFace Inference Endpoint
if let Some(weights_cache_override) = &args.weights_cache_override {
if let Some(weights_cache_override) = &weights_cache_override {
envs.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
@@ -960,7 +967,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
};
// Start process
tracing::info!("Starting download process.");
tracing::info!("Starting check and download process for {model_id}");
let mut download_process = match Command::new("text-generation-server")
.args(download_args)
.env_clear()
@@ -1002,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
loop {
if let Some(status) = download_process.try_wait().unwrap() {
if status.success() {
tracing::info!("Successfully downloaded weights.");
tracing::info!("Successfully downloaded weights for {model_id}");
break;
}
@@ -1557,7 +1564,28 @@ fn main() -> Result<(), LauncherError> {
.expect("Error setting Ctrl-C handler");
// Download and convert model weights
download_convert_model(&args, running.clone())?;
download_convert_model(
&args.model_id,
args.revision.as_deref(),
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
// Download and convert lora adapters if any
if let Some(lora_adapters) = &args.lora_adapters {
for adapter in lora_adapters.split(',') {
download_convert_model(
adapter,
None,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
}
}
if !running.load(Ordering::SeqCst) {
// Launcher was asked to stop

View File

@@ -309,7 +309,7 @@ async fn main() -> Result<(), RouterError> {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast") && tokenizer.get_post_processor().is_none() {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
@@ -577,7 +577,7 @@ pub fn create_post_processor(
if add_bos_token {
if let Some(bos) = bos_token {
single.push(format!("{}:1", bos.as_str()));
pair.push(format!("{}:1", bos.as_str()));
}
}

View File

@@ -7,6 +7,16 @@ from text_generation_server.utils.import_utils import (
)
@dataclass
class GPTQParams:
bits: int
checkpoint_format: Optional[str]
groupsize: int
desc_act: bool
quant_method: str
sym: bool
@dataclass
class GPTQWeight:
qweight: torch.Tensor

View File

@@ -166,12 +166,17 @@ def get_linear(weight, bias, quantize):
elif quantize == "gptq":
from text_generation_server.layers.gptq import GPTQWeight
if not isinstance(weight, GPTQWeight):
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
from text_generation_server.layers.marlin import (
GPTQMarlinLinear,
GPTQMarlinWeight,
)
if isinstance(weight, GPTQMarlinWeight):
linear = GPTQMarlinLinear(
weight=weight,
bias=bias,
)
elif isinstance(weight, GPTQWeight):
if weight.use_exllama:
try:
from text_generation_server.layers.gptq import (
@@ -195,6 +200,11 @@ def get_linear(weight, bias, quantize):
weight.bits,
weight.groupsize,
)
else:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
elif quantize == "awq":
from text_generation_server.layers.gptq import GPTQWeight
@@ -226,18 +236,11 @@ def get_linear(weight, bias, quantize):
from text_generation_server.layers.marlin import (
GPTQMarlin24Linear,
GPTQMarlin24Weight,
GPTQMarlinLinear,
GPTQMarlinWeight,
MarlinLinear,
MarlinWeight,
)
if isinstance(weight, GPTQMarlinWeight):
linear = GPTQMarlinLinear(
weight=weight,
bias=bias,
)
elif isinstance(weight, GPTQMarlin24Weight):
if isinstance(weight, GPTQMarlin24Weight):
linear = GPTQMarlin24Linear(
weight=weight,
bias=bias,

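Editor's note: the two hunks above interleave the removed and added branches, so here is a minimal sketch (not part of the diff) of the branch order get_linear now follows for quantize == "gptq". The helper name _pick_gptq_linear is hypothetical, and the exllama/QuantLinear branch is elided.

from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import GPTQMarlinLinear, GPTQMarlinWeight


def _pick_gptq_linear(weight, bias):
    if isinstance(weight, GPTQMarlinWeight):
        # GPTQ checkpoint already repacked into Marlin format: wrap it in the Marlin linear.
        return GPTQMarlinLinear(weight=weight, bias=bias)
    elif isinstance(weight, GPTQWeight):
        # Plain GPTQ checkpoint: get_linear selects exllama or the reference QuantLinear here.
        raise NotImplementedError("exllama/QuantLinear selection elided in this sketch")
    else:
        # Anything else is not a GPTQ-compatible weight.
        raise NotImplementedError(
            "The passed weight is not `gptq` compatible, loader needs to be updated."
        )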
View File

@@ -3,6 +3,8 @@ from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from text_generation_server.layers.gptq import GPTQParams
from text_generation_server.utils.import_utils import SYSTEM
try:
@@ -22,6 +24,19 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
return (
SYSTEM == "cuda"
and marlin_kernels is not None
and has_sm_8_0
and quantize == "gptq"
and gptq_params.quant_method == "gptq"
and gptq_params.bits in GPTQ_MARLIN_BITS
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
and gptq_params.sym
)
def _check_marlin_kernels():
if not (SYSTEM == "cuda" and has_sm_8_0):
raise NotImplementedError(

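Editor's note: as a usage sketch (not part of the commit), the new GPTQParams dataclass and the can_use_gptq_marlin helper combine as shown below. The field values are illustrative only, loosely mirroring the 4-bit GPTQ test model used in the integration test above.

from text_generation_server.layers.gptq import GPTQParams
from text_generation_server.layers.marlin import can_use_gptq_marlin

# Illustrative parameters: a 4-bit, group-size-128, symmetric GPTQ checkpoint.
gptq_params = GPTQParams(
    bits=4,
    checkpoint_format=None,
    groupsize=128,
    desc_act=False,
    quant_method="gptq",
    sym=True,
)

# True only on CUDA with SM >= 8.0 and the marlin kernels installed, with bits in
# GPTQ_MARLIN_BITS, groupsize in GPTQ_MARLIN_GROUP_SIZES, and symmetric quantization.
if can_use_gptq_marlin(gptq_params, quantize="gptq"):
    # The Weights loaders then take the repack_gptq_for_marlin path instead of exllama.
    pass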
View File

@@ -309,7 +309,9 @@ class LlamaMLP(nn.Module):
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
_custom_C.LLMM_Silu(
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
)
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)

View File

@@ -1,25 +1,15 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.layers.gptq import GPTQParams
from text_generation_server.utils.log import log_once
@dataclass
class _GPTQParams:
bits: int
checkpoint_format: Optional[str]
groupsize: int
desc_act: bool
quant_method: str
sym: bool
class Weights:
def __init__(
self,
@@ -212,6 +202,10 @@ class Weights:
"""
if quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = self.get_packed_sharded(
@@ -221,17 +215,28 @@ class Weights:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)
gptq_params = self._get_gptq_params()
qzeros = self.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
scales = self.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=self.dtype)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
g_idx = self.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
qzeros = self.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
if quantize == "gptq" and gptq_params.quant_method == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx")
elif quantize == "gptq" and gptq_params.quant_method == "awq":
@@ -269,7 +274,6 @@ class Weights:
repack_gptq_for_marlin,
)
quant_method = getattr(self, "quant_method", "marlin")
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
B = self.get_packed_sharded(
@@ -286,31 +290,6 @@ class Weights:
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
elif quant_method == "gptq":
gptq_params = self._get_gptq_params()
try:
qweight = self.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
scales = self.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
g_idx = self.get_tensor(f"{prefix}.g_idx")
weight = repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
else:
B = self.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
@@ -356,6 +335,10 @@ class Weights:
raise ValueError("get_multi_weights_col is not supported for exl2")
elif quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = torch.cat(
@@ -366,14 +349,31 @@ class Weights:
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA
@@ -425,10 +425,8 @@ class Weights:
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
repack_gptq_for_marlin,
)
quant_method = getattr(self, "quant_method", "marlin")
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
@@ -452,36 +450,6 @@ class Weights:
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
elif quant_method == "gptq":
gptq_params = self._get_gptq_params()
try:
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
dim=1,
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
weight = repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
else:
try:
B = torch.cat(
@@ -544,9 +512,41 @@ class Weights:
)
elif quantize == "gptq":
use_exllama = True
gptq_params = self._get_gptq_params()
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if gptq_params.desc_act or gptq_params.groupsize == -1:
scales = self.get_tensor(f"{prefix}.scales")
else:
scales = self.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = self.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=sharded_in_features,
)
use_exllama = True
if gptq_params.bits != 4:
use_exllama = False
@@ -672,10 +672,8 @@ class Weights:
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
repack_gptq_for_marlin,
)
quant_method = getattr(self, "quant_method", "marlin")
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
@@ -698,35 +696,6 @@ class Weights:
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
elif quant_method == "gptq":
log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
gptq_params = self._get_gptq_params()
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if gptq_params.desc_act or gptq_params.groupsize == -1:
scales = self.get_tensor(f"{prefix}.scales")
else:
scales = self.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = self.process_group.size() > 1
weight = repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=sharded_in_features,
)
else:
try:
B = self.get_sharded(f"{prefix}.B", dim=0)
@@ -743,18 +712,17 @@ class Weights:
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def _get_gptq_params(self) -> _GPTQParams:
def _get_gptq_params(self) -> GPTQParams:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = False
sym = True
sym = False
quant_method = "gptq"
except (SafetensorError, RuntimeError) as e:
try:
@@ -767,7 +735,7 @@ class Weights:
except Exception:
raise e
return _GPTQParams(
return GPTQParams(
bits=bits,
checkpoint_format=checkpoint_format,
desc_act=desc_act,