Santacoder GPTQ support (quantized model seems awful, not sure if it's prompting or the quantization itself).
Nicolas Patry 2023-06-15 16:59:31 +02:00
parent 983c813f1d
commit 16d0fb04ae
4 changed files with 116 additions and 15 deletions


@@ -160,6 +160,8 @@ def quantize(
     json_output: bool = False,
     trust_remote_code: bool = False,
     upload_to_model_id: Optional[str] = None,
+    percdamp: float = 0.01,
+    act_order: bool = False,
 ):
     download_weights(
         model_id=model_id,
@@ -176,6 +178,8 @@ def quantize(
         output_dir=output_dir,
         trust_remote_code=trust_remote_code,
         upload_to_model_id=upload_to_model_id,
+        percdamp=percdamp,
+        act_order=act_order,
     )
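For reference, here is roughly what the two new options do in standard GPTQ (a hedged sketch, not code from this commit): percdamp adds a fraction of the mean Hessian diagonal before the inverse/Cholesky step so the factorization stays well conditioned, while act_order (see the later hunks in the quantization code) quantizes columns in order of decreasing Hessian diagonal.

import torch

# Sketch of the dampening controlled by percdamp (standard GPTQ, assumed):
# H stands for the per-layer Hessian estimate accumulated from calibration activations.
A = torch.rand(8, 8)
H = A @ A.T + torch.eye(8)                     # toy symmetric positive-definite Hessian
percdamp = 0.01
damp = percdamp * torch.mean(torch.diag(H))    # fraction of the average diagonal entry
idx = torch.arange(H.shape[0])
H[idx, idx] += damp                            # dampen before inverting / Cholesky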


@@ -21,6 +21,81 @@ from text_generation_server.utils.layers import (
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
+    if config.quantize == "gptq":
+        return _load_multi_mqa_gptq(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+    else:
+        return _load_multi_mqa(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+
+
+def _load_multi_mqa_gptq(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+    if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
+        world_size = weights.process_group.size()
+        rank = weights.process_group.rank()
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        qweight = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        scales = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert 2 * head_size % (32 // 4) == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
+        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+        bits = weights.get_tensor("gptq_bits").item()
+        groupsize = weights.get_tensor("gptq_groupsize").item()
+
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+
+        if bias:
+            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+            shape = slice_.get_shape()
+            block_size = (shape[0] - 2 * head_size) // world_size
+            assert (shape[0] - 2 * head_size) % world_size == 0
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            q_tensor = slice_[start:stop]
+            kv_tensor = slice_[-2 * head_size :]
+            bias = torch.cat([q_tensor, kv_tensor], dim=0)
+
+        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    else:
+        raise NotImplementedError("Gptq loading with santacoder is not implemented")
+
+
+def _load_multi_mqa(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
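A note on the qzeros arithmetic above (an explanatory sketch, not part of the commit): qweight and scales are laid out per output column, while qzeros packs 32 // bits zero-points into each int32, so the key/value slice at the end of c_attn is narrower by that packing factor.

# Illustration with assumed 4-bit quantization and an example head_size (not from the diff):
bits = 4
pack_factor = 32 // bits                 # 8 zero-points packed per int32 element
head_size = 128
kv_cols = 2 * head_size                  # key + value columns at the end of c_attn
kv_qzeros_cols = kv_cols * bits // 32    # the slice width used for qzeros above
assert kv_qzeros_cols == kv_cols // pack_factor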
@@ -93,7 +168,9 @@ def load_col(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_multi_weights_col(
+            [prefix], quantize=config.quantize, dim=0
+        )

     if bias:
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
@@ -106,7 +183,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
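For context on the dim arguments in load_col and load_row (an illustrative sketch, independent of this commit): column-parallel layers shard the output dimension of a (out, in) weight, row-parallel layers shard the input dimension, which is why the unquantized paths split on dim=0 and dim=1 respectively.

import torch

# Toy example: shard a (out_features, in_features) weight across 2 ranks.
world_size, rank = 2, 0
weight = torch.arange(24, dtype=torch.float32).reshape(6, 4)       # (out=6, in=4)

col_shard = torch.chunk(weight, world_size, dim=0)[rank]  # column-parallel: split outputs
row_shard = torch.chunk(weight, world_size, dim=1)[rank]  # row-parallel: split inputs
print(col_shard.shape, row_shard.shape)                   # torch.Size([3, 4]) torch.Size([6, 2])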


@@ -3,7 +3,7 @@ import torch.distributed

 from opentelemetry import trace
 from transformers import AutoConfig
-from transformers.models.llama import LlamaTokenizer
+from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
@@ -34,13 +34,22 @@ class FlashLlama(FlashCausalLM):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

-        tokenizer = LlamaTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
+        try:
+            tokenizer = LlamaTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+        except Exception:
+            tokenizer = LlamaTokenizerFast.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )

         config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
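The fallback presumably covers repositories that ship only the fast tokenizer files. A hypothetical helper (not part of this commit; list_repo_files comes from huggingface_hub) to check which path a given repo can take:

from typing import Optional
from huggingface_hub import list_repo_files

def tokenizer_assets(model_id: str, revision: Optional[str] = None) -> dict:
    """Report which Llama tokenizer implementations a repo can support."""
    files = set(list_repo_files(model_id, revision=revision))
    return {
        "slow_ok": "tokenizer.model" in files,  # sentencepiece model required by LlamaTokenizer
        "fast_ok": "tokenizer.json" in files,   # pre-serialized fast tokenizer loads directly
    }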


@@ -240,7 +240,7 @@ class GPTQ:
         print(table.draw().split("\n")[-2])

     def fasterquant(
-        self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
+        self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
     ):
         self.layer.to(self.dev)
@@ -263,7 +263,7 @@
         H[dead, dead] = 1
         W[:, dead] = 0

-        if actorder:
+        if act_order:
             perm = torch.argsort(torch.diag(H), descending=True)
             W = W[:, perm]
             H = H[perm][:, perm]
@@ -334,7 +334,7 @@
         groupsize = groupsize if groupsize != -1 else self.columns
         g_idx = [i // groupsize for i in range(self.columns)]
         g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
-        if actorder:
+        if act_order:
             invperm = torch.argsort(perm)
             Q = Q[:, invperm]
             g_idx = g_idx[invperm]
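Beyond the rename, act_order is what drives the permutation above: columns are quantized in order of decreasing Hessian diagonal, and g_idx records, after undoing the permutation, which quantization group each original column belongs to. A condensed sketch of that bookkeeping (toy sizes, not code from the commit):

import torch

columns, groupsize = 8, 4
H = torch.diag(torch.tensor([3.0, 1.0, 4.0, 1.5, 9.0, 2.0, 6.0, 5.0]))

perm = torch.argsort(torch.diag(H), descending=True)  # visit "important" columns first
invperm = torch.argsort(perm)                         # map back to the original order

g_idx = torch.tensor([i // groupsize for i in range(columns)], dtype=torch.int32)
g_idx = g_idx[invperm]  # group id of each original column, as stored in the checkpoint
print(perm.tolist(), g_idx.tolist())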
@@ -697,7 +697,7 @@ def sequential(
                scale, zero, g_idx, error = gptq[name].fasterquant(
                    percdamp=percdamp,
                    groupsize=groupsize,
-                   actorder=act_order,
+                   act_order=act_order,
                    name=name,
                )
                quantizers[f"{prefix}.{i}.{name}"] = (
@@ -775,6 +775,8 @@ def quantize(
     output_dir: str,
     trust_remote_code: bool,
     upload_to_model_id: Optional[str],
+    percdamp: float,
+    act_order: bool,
 ):
     print("loading model")
     model = AutoModelForCausalLM.from_pretrained(
@@ -795,7 +797,16 @@ def quantize(
     )

     tick = time.time()
-    quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
+    quantizers = sequential(
+        model,
+        dataloader,
+        DEV,
+        nsamples,
+        bits,
+        groupsize,
+        percdamp=percdamp,
+        act_order=act_order,
+    )
     print(time.time() - tick)

     pack(model, quantizers, bits, groupsize)
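One link back to the santacoder loader: _load_multi_mqa_gptq reads gptq_bits and gptq_groupsize with weights.get_tensor(), so the quantized checkpoint has to carry those scalars next to the packed tensors. A minimal sketch of that convention (example values only; the actual serialization code is outside this diff):

import torch
from safetensors.torch import save_file

# Example values; in practice bits/groupsize come from the quantize() arguments.
packed = {
    "gptq_bits": torch.LongTensor([4]),
    "gptq_groupsize": torch.LongTensor([128]),
    # ... plus qweight / qzeros / scales / g_idx for every quantized layer
}
save_file(packed, "model.safetensors")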