Santacoder GPTQ support (quantized model seems awful, not sure if it's prompting or the quantization itself).
Nicolas Patry 2023-06-15 16:59:31 +02:00
parent 983c813f1d
commit 16d0fb04ae
4 changed files with 116 additions and 15 deletions


@@ -160,6 +160,8 @@ def quantize(
     json_output: bool = False,
     trust_remote_code: bool = False,
     upload_to_model_id: Optional[str] = None,
+    percdamp: float = 0.01,
+    act_order: bool = False,
 ):
     download_weights(
         model_id=model_id,
@@ -176,6 +178,8 @@ def quantize(
         output_dir=output_dir,
         trust_remote_code=trust_remote_code,
         upload_to_model_id=upload_to_model_id,
+        percdamp=percdamp,
+        act_order=act_order,
     )
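With that, the two new knobs are plumbed from the CLI down to the GPTQ routine further along in this diff. Assuming this is the usual typer-based text-generation-server entry point (the flag spelling and output path below are illustrative, not taken from the commit), an invocation would look roughly like:

text-generation-server quantize bigcode/santacoder /data/santacoder-gptq --percdamp 0.01 --act-order

Both values are simply forwarded unchanged; nothing else in the CLI layer interprets them.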


@@ -21,6 +21,81 @@ from text_generation_server.utils.layers import (
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
+    if config.quantize == "gptq":
+        return _load_multi_mqa_gptq(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+    else:
+        return _load_multi_mqa(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+
+
+def _load_multi_mqa_gptq(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+    if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
+        world_size = weights.process_group.size()
+        rank = weights.process_group.rank()
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        qweight = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        scales = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert 2 * head_size % (32 // 4) == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
+        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+        bits = weights.get_tensor("gptq_bits").item()
+        groupsize = weights.get_tensor("gptq_groupsize").item()
+
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+
+        if bias:
+            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+            shape = slice_.get_shape()
+            block_size = (shape[0] - 2 * head_size) // world_size
+            assert (shape[0] - 2 * head_size) % world_size == 0
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            q_tensor = slice_[start:stop]
+            kv_tensor = slice_[-2 * head_size :]
+            bias = torch.cat([q_tensor, kv_tensor], dim=0)
+
+        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    else:
+        raise NotImplementedError("Gptq loading with santacoder is not implemented")
+
+
+def _load_multi_mqa(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
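The slicing in _load_multi_mqa_gptq follows Santacoder's multi-query layout: the fused c_attn matrix holds the query columns first and a single shared key/value block of 2 * head_size columns at the end, so tensor parallelism shards only the query part while every rank keeps the full KV tail; for qzeros the same split happens on packed columns, since eight 4-bit zero-points live in each int32 (hence the * 4 // 32 factor). A small self-contained sketch of that arithmetic, with made-up sizes rather than anything from the commit:

# Illustrative only: how the query/KV column split works for a fused
# multi-query c_attn tensor (sizes are invented for the example).
hidden_size, head_size, world_size = 2048, 128, 2

qkv_cols = hidden_size + 2 * head_size              # fused c_attn output columns
q_block = (qkv_cols - 2 * head_size) // world_size  # query columns per rank
assert (qkv_cols - 2 * head_size) % world_size == 0

for rank in range(world_size):
    start, stop = rank * q_block, (rank + 1) * q_block
    # each rank concatenates its query slice with the shared KV tail
    print(f"rank {rank}: q cols [{start}:{stop}) + kv cols [{qkv_cols - 2 * head_size}:{qkv_cols})")

# qzeros packs eight 4-bit zero-points per int32, so the KV tail is only
# 2 * head_size * 4 // 32 packed columns wide.
kv_packed_cols = 2 * head_size * 4 // 32
assert 2 * head_size % (32 // 4) == 0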
@@ -93,7 +168,9 @@ def load_col(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_multi_weights_col(
+            [prefix], quantize=config.quantize, dim=0
+        )

     if bias:
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
@@ -106,7 +183,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
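Switching load_col/load_row from get_sharded to the quantize-aware get_multi_weights_* helpers is what makes these layers work for both float and GPTQ checkpoints: a GPTQ "weight" is not a single tensor but the packed (qweight, qzeros, scales, g_idx, bits, groupsize) bundle seen above, and get_linear builds the matching layer from whichever form it receives. A rough sketch of that dual return type (hypothetical helper, not the repository's actual weights code):

from typing import Tuple, Union
import torch

# Either a dense tensor or the packed GPTQ bundle.
GPTQWeight = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int]
Weight = Union[torch.Tensor, GPTQWeight]

def describe_weight(weight: Weight) -> str:
    # Downstream code only needs an isinstance check to decide between
    # building a dense linear layer and a quantized one.
    if isinstance(weight, tuple):
        qweight, qzeros, scales, g_idx, bits, groupsize = weight
        return f"gptq: {bits}-bit, groupsize={groupsize}, qweight {tuple(qweight.shape)}"
    return f"dense: {tuple(weight.shape)}"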


@@ -3,7 +3,7 @@ import torch.distributed

 from opentelemetry import trace
 from transformers import AutoConfig
-from transformers.models.llama import LlamaTokenizer
+from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
@@ -34,13 +34,22 @@ class FlashLlama(FlashCausalLM):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

-        tokenizer = LlamaTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
+        try:
+            tokenizer = LlamaTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+        except Exception:
+            tokenizer = LlamaTokenizerFast.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )

         config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code


@@ -240,7 +240,7 @@ class GPTQ:
         print(table.draw().split("\n")[-2])

     def fasterquant(
-        self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
+        self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
     ):
         self.layer.to(self.dev)
@@ -263,7 +263,7 @@
         H[dead, dead] = 1
         W[:, dead] = 0

-        if actorder:
+        if act_order:
             perm = torch.argsort(torch.diag(H), descending=True)
             W = W[:, perm]
             H = H[perm][:, perm]
@@ -334,7 +334,7 @@
         groupsize = groupsize if groupsize != -1 else self.columns
         g_idx = [i // groupsize for i in range(self.columns)]
         g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
-        if actorder:
+        if act_order:
             invperm = torch.argsort(perm)
             Q = Q[:, invperm]
             g_idx = g_idx[invperm]
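For context, act_order only changes the order in which columns are quantized: columns are visited by descending diag(H) (roughly, most-activated first), and the permutation is undone afterwards so the stored Q and g_idx match the original column order. A tiny round-trip illustration with toy tensors (not the commit's code):

import torch

torch.manual_seed(0)
W = torch.randn(4, 6)                                    # toy weight: [rows, columns]
H_diag = torch.tensor([0.1, 3.0, 0.5, 2.0, 0.2, 1.0])    # stand-in for diag(H)

perm = torch.argsort(H_diag, descending=True)            # most "important" columns first
W_perm = W[:, perm]

# ... column-by-column quantization would happen here on W_perm ...

invperm = torch.argsort(perm)                            # inverse permutation
W_back = W_perm[:, invperm]
assert torch.equal(W_back, W)                            # original column order restored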
@@ -697,7 +697,7 @@ def sequential(
                 scale, zero, g_idx, error = gptq[name].fasterquant(
                     percdamp=percdamp,
                     groupsize=groupsize,
-                    actorder=act_order,
+                    act_order=act_order,
                     name=name,
                 )
                 quantizers[f"{prefix}.{i}.{name}"] = (
@@ -775,6 +775,8 @@ def quantize(
     output_dir: str,
     trust_remote_code: bool,
     upload_to_model_id: Optional[str],
+    percdamp: float,
+    act_order: bool,
 ):
     print("loading model")
     model = AutoModelForCausalLM.from_pretrained(
@@ -795,7 +797,16 @@
     )

     tick = time.time()
-    quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
+    quantizers = sequential(
+        model,
+        dataloader,
+        DEV,
+        nsamples,
+        bits,
+        groupsize,
+        percdamp=percdamp,
+        act_order=act_order,
+    )
     print(time.time() - tick)

     pack(model, quantizers, bits, groupsize)
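A closing note on the two new parameters: act_order is the column-reordering behaviour shown above, and percdamp follows the usual GPTQ convention of adding a fraction of the mean Hessian diagonal back onto that diagonal so the factorization used during quantization stays numerically stable. A minimal sketch of that dampening step, using generic tensors rather than the repository's GPTQ class:

import torch

def damp_hessian(H: torch.Tensor, percdamp: float = 0.01) -> torch.Tensor:
    # Add percdamp * mean(diag(H)) to the diagonal before factorizing H,
    # the standard GPTQ trick to keep the decomposition well conditioned.
    damp = percdamp * torch.mean(torch.diag(H))
    idx = torch.arange(H.shape[0], device=H.device)
    H = H.clone()
    H[idx, idx] += damp
    return H

X = torch.randn(512, 64)
H = X.t() @ X                                               # toy Hessian proxy (X^T X over calibration data)
L = torch.linalg.cholesky(damp_hessian(H, percdamp=0.01))   # now safe to factorize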