mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
Santacoder GPTQ support (quantized model seems awful, not sure if it's
prompting or the quantization itself).
This commit is contained in:
parent
983c813f1d
commit
16d0fb04ae
@ -160,6 +160,8 @@ def quantize(
|
||||
json_output: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
upload_to_model_id: Optional[str] = None,
|
||||
percdamp: float = 0.01,
|
||||
act_order: bool = False,
|
||||
):
|
||||
download_weights(
|
||||
model_id=model_id,
|
||||
@ -176,6 +178,8 @@ def quantize(
|
||||
output_dir=output_dir,
|
||||
trust_remote_code=trust_remote_code,
|
||||
upload_to_model_id=upload_to_model_id,
|
||||
percdamp=percdamp,
|
||||
act_order=act_order,
|
||||
)
|
||||
|
||||
|
||||
|
@ -21,6 +21,81 @@ from text_generation_server.utils.layers import (
|
||||
def load_multi_mqa(
|
||||
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
||||
):
|
||||
|
||||
if config.quantize == "gptq":
|
||||
return _load_multi_mqa_gptq(
|
||||
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
||||
)
|
||||
else:
|
||||
return _load_multi_mqa(
|
||||
config, prefix, weights, bias, head_size, num_heads, hidden_size
|
||||
)
|
||||
|
||||
|
||||
def _load_multi_mqa_gptq(
|
||||
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
||||
):
|
||||
if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
|
||||
world_size = weights.process_group.size()
|
||||
rank = weights.process_group.rank()
|
||||
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
|
||||
shape = slice_.get_shape()
|
||||
block_size = (shape[1] - 2 * head_size) // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
assert (shape[1] - 2 * head_size) % world_size == 0
|
||||
q_tensor = slice_[:, start:stop]
|
||||
kv_tensor = slice_[:, -2 * head_size :]
|
||||
qweight = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
|
||||
shape = slice_.get_shape()
|
||||
block_size = (shape[1] - 2 * head_size) // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
assert (shape[1] - 2 * head_size) % world_size == 0
|
||||
q_tensor = slice_[:, start:stop]
|
||||
kv_tensor = slice_[:, -2 * head_size :]
|
||||
scales = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
|
||||
shape = slice_.get_shape()
|
||||
block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
assert 2 * head_size % (32 // 4) == 0
|
||||
q_tensor = slice_[:, start:stop]
|
||||
kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
|
||||
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
|
||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
||||
bits = weights.get_tensor("gptq_bits").item()
|
||||
groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||
|
||||
if bias:
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
|
||||
shape = slice_.get_shape()
|
||||
block_size = (shape[0] - 2 * head_size) // world_size
|
||||
assert (shape[0] - 2 * head_size) % world_size == 0
|
||||
q_tensor = slice_[start:stop]
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
q_tensor = slice_[start:stop]
|
||||
kv_tensor = slice_[-2 * head_size :]
|
||||
bias = torch.cat([q_tensor, kv_tensor], dim=0)
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
else:
|
||||
raise NotImplementedError("Gptq loading with santacoder is not implemented")
|
||||
|
||||
|
||||
def _load_multi_mqa(
|
||||
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
||||
):
|
||||
|
||||
if any("c_attn" in k for k in weights.routing.keys()):
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
|
||||
shape = slice_.get_shape()
|
||||
@ -93,7 +168,9 @@ def load_col(config, prefix: str, weights, bias: bool):
|
||||
if config.transpose:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
weight = weights.get_multi_weights_col(
|
||||
[prefix], quantize=config.quantize, dim=0
|
||||
)
|
||||
|
||||
if bias:
|
||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
@ -106,7 +183,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
||||
if config.transpose:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
|
@ -3,7 +3,7 @@ import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoConfig
|
||||
from transformers.models.llama import LlamaTokenizer
|
||||
from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
@ -34,13 +34,22 @@ class FlashLlama(FlashCausalLM):
|
||||
else:
|
||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
try:
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
except Exception:
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
|
@ -240,7 +240,7 @@ class GPTQ:
|
||||
print(table.draw().split("\n")[-2])
|
||||
|
||||
def fasterquant(
|
||||
self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
|
||||
self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
|
||||
):
|
||||
self.layer.to(self.dev)
|
||||
|
||||
@ -263,7 +263,7 @@ class GPTQ:
|
||||
H[dead, dead] = 1
|
||||
W[:, dead] = 0
|
||||
|
||||
if actorder:
|
||||
if act_order:
|
||||
perm = torch.argsort(torch.diag(H), descending=True)
|
||||
W = W[:, perm]
|
||||
H = H[perm][:, perm]
|
||||
@ -334,7 +334,7 @@ class GPTQ:
|
||||
groupsize = groupsize if groupsize != -1 else self.columns
|
||||
g_idx = [i // groupsize for i in range(self.columns)]
|
||||
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
|
||||
if actorder:
|
||||
if act_order:
|
||||
invperm = torch.argsort(perm)
|
||||
Q = Q[:, invperm]
|
||||
g_idx = g_idx[invperm]
|
||||
@ -697,7 +697,7 @@ def sequential(
|
||||
scale, zero, g_idx, error = gptq[name].fasterquant(
|
||||
percdamp=percdamp,
|
||||
groupsize=groupsize,
|
||||
actorder=act_order,
|
||||
act_order=act_order,
|
||||
name=name,
|
||||
)
|
||||
quantizers[f"{prefix}.{i}.{name}"] = (
|
||||
@ -775,6 +775,8 @@ def quantize(
|
||||
output_dir: str,
|
||||
trust_remote_code: bool,
|
||||
upload_to_model_id: Optional[str],
|
||||
percdamp: float,
|
||||
act_order: bool,
|
||||
):
|
||||
print("loading model")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@ -795,7 +797,16 @@ def quantize(
|
||||
)
|
||||
|
||||
tick = time.time()
|
||||
quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
|
||||
quantizers = sequential(
|
||||
model,
|
||||
dataloader,
|
||||
DEV,
|
||||
nsamples,
|
||||
bits,
|
||||
groupsize,
|
||||
percdamp=percdamp,
|
||||
act_order=act_order,
|
||||
)
|
||||
print(time.time() - tick)
|
||||
|
||||
pack(model, quantizers, bits, groupsize)
|
||||
|
Loading…
Reference in New Issue
Block a user