Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 23:42:06 +00:00)
Santacoder GPTQ support (the quantized model seems awful; not sure whether it's the prompting or the quantization itself).
This commit is contained in:
parent 983c813f1d
commit 16d0fb04ae
@@ -160,6 +160,8 @@ def quantize(
     json_output: bool = False,
     trust_remote_code: bool = False,
     upload_to_model_id: Optional[str] = None,
+    percdamp: float = 0.01,
+    act_order: bool = False,
 ):
     download_weights(
         model_id=model_id,
@@ -176,6 +178,8 @@ def quantize(
         output_dir=output_dir,
         trust_remote_code=trust_remote_code,
         upload_to_model_id=upload_to_model_id,
+        percdamp=percdamp,
+        act_order=act_order,
     )
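Note (not part of the diff): `percdamp` is the standard GPTQ damping knob; it sets the fraction of the mean Hessian diagonal that gets added back to the diagonal before the Cholesky-based inverse, which keeps the factorization stable on near-singular layers. A minimal sketch of that step, assuming a symmetric Hessian H accumulated from calibration activations:

    import torch

    def damp_hessian(H: torch.Tensor, percdamp: float = 0.01) -> torch.Tensor:
        # Add percdamp * mean(diag(H)) to the diagonal so the Cholesky
        # factorization used inside fasterquant() does not blow up.
        damp = percdamp * torch.mean(torch.diag(H))
        idx = torch.arange(H.shape[0], device=H.device)
        H = H.clone()
        H[idx, idx] += damp
        return H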
@@ -21,6 +21,81 @@ from text_generation_server.utils.layers import (
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
+    if config.quantize == "gptq":
+        return _load_multi_mqa_gptq(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+    else:
+        return _load_multi_mqa(
+            config, prefix, weights, bias, head_size, num_heads, hidden_size
+        )
+
+
+def _load_multi_mqa_gptq(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
+    if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
+        world_size = weights.process_group.size()
+        rank = weights.process_group.rank()
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        qweight = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - 2 * head_size) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert (shape[1] - 2 * head_size) % world_size == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size :]
+        scales = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
+        shape = slice_.get_shape()
+        block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        assert 2 * head_size % (32 // 4) == 0
+        q_tensor = slice_[:, start:stop]
+        kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
+        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
+
+        g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+        bits = weights.get_tensor("gptq_bits").item()
+        groupsize = weights.get_tensor("gptq_groupsize").item()
+
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+
+        if bias:
+            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
+            shape = slice_.get_shape()
+            block_size = (shape[0] - 2 * head_size) // world_size
+            assert (shape[0] - 2 * head_size) % world_size == 0
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            q_tensor = slice_[start:stop]
+            kv_tensor = slice_[-2 * head_size :]
+            bias = torch.cat([q_tensor, kv_tensor], dim=0)
+
+        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+    else:
+        raise NotImplementedError("Gptq loading with santacoder is not implemented")
+
+
+def _load_multi_mqa(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
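Note (not part of the diff): the slicing above follows the multi-query attention layout of `c_attn`: the query projection columns are sharded across tensor-parallel ranks, while the single shared key/value block (the last `2 * head_size` columns) is replicated on every rank. For `qzeros` the widths are scaled by `4 // 32` because, at 4 bits, eight zero-points are packed into each int32 column. A toy sketch of the query/KV split with made-up shapes (not code from the repository):

    import torch

    world_size, rank = 2, 0
    head_size, num_heads = 4, 8
    q_width = head_size * num_heads                      # query projection width
    full = torch.randn(16, q_width + 2 * head_size)      # fake fused [Q | K | V] weight

    block = q_width // world_size                        # query columns owned by this rank
    q_part = full[:, rank * block : (rank + 1) * block]  # this rank's query shard
    kv_part = full[:, -2 * head_size :]                  # shared K/V columns, kept whole
    shard = torch.cat([q_part, kv_part], dim=1)
    assert shard.shape[1] == block + 2 * head_size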
@@ -93,7 +168,9 @@ def load_col(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_multi_weights_col(
+            [prefix], quantize=config.quantize, dim=0
+        )

     if bias:
         bias = weights.get_sharded(f"{prefix}.bias", dim=0)
@@ -106,7 +183,7 @@ def load_row(config, prefix: str, weights, bias: bool):
     if config.transpose:
         weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
     else:
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

     if bias and weights.process_group.rank() == 0:
         # Rank is only on the first rank process
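Note (not part of the diff): `get_multi_weights_col` / `get_multi_weights_row` return either a plain sharded tensor or, when `config.quantize == "gptq"`, the packed `(qweight, qzeros, scales, g_idx, bits, groupsize)` tuple seen above, and `get_linear` builds the matching layer. A simplified, illustrative dispatch (not the repository's `utils/layers.py`):

    import torch
    from torch import nn

    def get_linear_sketch(weight, bias, quantize):
        # Illustrative only: GPTQ weights arrive as a packed tuple rather
        # than a dense tensor, so layer construction has to branch.
        if quantize == "gptq":
            qweight, qzeros, scales, g_idx, bits, groupsize = weight
            raise NotImplementedError("plug in a GPTQ QuantLinear kernel here")
        out_features, in_features = weight.shape
        layer = nn.Linear(in_features, out_features, bias=bias is not None)
        with torch.no_grad():
            layer.weight.copy_(weight)
            if bias is not None:
                layer.bias.copy_(bias)
        return layer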
@@ -3,7 +3,7 @@ import torch.distributed

 from opentelemetry import trace
 from transformers import AutoConfig
-from transformers.models.llama import LlamaTokenizer
+from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
@@ -34,13 +34,22 @@ class FlashLlama(FlashCausalLM):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

-        tokenizer = LlamaTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
+        try:
+            tokenizer = LlamaTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+        except Exception:
+            tokenizer = LlamaTokenizerFast.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )

         config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
@@ -240,7 +240,7 @@ class GPTQ:
         print(table.draw().split("\n")[-2])

     def fasterquant(
-        self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
+        self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
     ):
         self.layer.to(self.dev)
@@ -263,7 +263,7 @@ class GPTQ:
         H[dead, dead] = 1
         W[:, dead] = 0

-        if actorder:
+        if act_order:
             perm = torch.argsort(torch.diag(H), descending=True)
             W = W[:, perm]
             H = H[perm][:, perm]
@@ -334,7 +334,7 @@ class GPTQ:
         groupsize = groupsize if groupsize != -1 else self.columns
         g_idx = [i // groupsize for i in range(self.columns)]
         g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
-        if actorder:
+        if act_order:
             invperm = torch.argsort(perm)
             Q = Q[:, invperm]
             g_idx = g_idx[invperm]
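Note (not part of the diff): with `act_order` enabled, columns are quantized in descending order of the Hessian diagonal and then mapped back to their original positions (for `Q` and `g_idx`) with the inverse permutation. The round trip is just two `argsort` calls:

    import torch

    d = torch.tensor([0.1, 3.0, 0.5, 2.0])      # stand-in for diag(H)
    perm = torch.argsort(d, descending=True)    # most "important" columns first
    invperm = torch.argsort(perm)               # inverse permutation

    x = torch.arange(4)
    assert torch.equal(x[perm][invperm], x)     # permute, then undo: identity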
@@ -697,7 +697,7 @@ def sequential(
                 scale, zero, g_idx, error = gptq[name].fasterquant(
                     percdamp=percdamp,
                     groupsize=groupsize,
-                    actorder=act_order,
+                    act_order=act_order,
                     name=name,
                 )
                 quantizers[f"{prefix}.{i}.{name}"] = (
@@ -775,6 +775,8 @@ def quantize(
     output_dir: str,
     trust_remote_code: bool,
     upload_to_model_id: Optional[str],
+    percdamp: float,
+    act_order: bool,
 ):
     print("loading model")
     model = AutoModelForCausalLM.from_pretrained(
@@ -795,7 +797,16 @@ def quantize(
     )

     tick = time.time()
-    quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
+    quantizers = sequential(
+        model,
+        dataloader,
+        DEV,
+        nsamples,
+        bits,
+        groupsize,
+        percdamp=percdamp,
+        act_order=act_order,
+    )
     print(time.time() - tick)

     pack(model, quantizers, bits, groupsize)
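Note (not part of the diff): with these changes the two knobs thread from the CLI-level `quantize` entry point down through `sequential` to `GPTQ.fasterquant`. An illustrative Python-level invocation, assuming the entry point lives in `text_generation_server/cli.py` and using an example model id and output path:

    # Illustrative only; argument values are examples, not recommendations.
    from text_generation_server.cli import quantize

    quantize(
        model_id="bigcode/santacoder",       # example model id
        output_dir="/data/santacoder-gptq",  # example output directory
        trust_remote_code=True,
        percdamp=0.01,      # Hessian damping fraction (new in this commit)
        act_order=False,    # quantize in activation (Hessian) order (new in this commit)
    )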