Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-23 07:52:06 +00:00)

Commit ffe8fc4699 ("Fixing few things"), parent dadbbc27d5
@@ -150,7 +150,6 @@ def download_weights(
         # Convert pytorch weights to safetensors
         utils.convert_files(local_pt_files, local_st_files)
 
-
 @app.command()
 def quantize(
     model_id: str,
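For context, utils.convert_files rewrites downloaded PyTorch .bin checkpoints as .safetensors files. A minimal sketch of that kind of conversion, assuming a flat {name: tensor} state dict and hypothetical paths (this is not TGI's actual implementation):

    import torch
    from safetensors.torch import save_file

    def convert_pt_to_safetensors(pt_path: str, st_path: str) -> None:
        # Load the PyTorch checkpoint on CPU; assumed to be a flat dict of tensors.
        state_dict = torch.load(pt_path, map_location="cpu")
        # Make each tensor a contiguous CPU copy before serializing.
        state_dict = {name: tensor.contiguous() for name, tensor in state_dict.items()}
        save_file(state_dict, st_path)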
@@ -158,8 +157,9 @@ def quantize(
     revision: Optional[str] = None,
     logger_level: str = "INFO",
     json_output: bool = False,
+    trust_remote_code: bool = False,
 ):
-    extension: str = (".safetensors",)
+    extension: str = ".safetensors",
     # Remove default handler
     logger.remove()
     logger.add(
@@ -171,15 +171,12 @@ def quantize(
         backtrace=True,
         diagnose=False,
     )
-    download_weights(
-        model_id=model_id,
-        revision=revision,
-        logger_level=logger_level,
-        json_output=json_output,
-    )
+    download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output)
     from text_generation_server.utils.gptq.quantize import quantize
 
-    quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir)
+    quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir, trust_remote_code=trust_remote_code)
 
 
 if __name__ == "__main__":
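The new trust_remote_code keyword is picked up by typer, which the CLI already uses, and surfaces as a boolean flag. A minimal standalone sketch of that mechanism (not the real cli.py; the echo body is illustrative only):

    import typer

    app = typer.Typer()

    @app.command()
    def quantize(
        model_id: str,
        output_dir: str,
        trust_remote_code: bool = False,  # parsed as --trust-remote-code / --no-trust-remote-code
    ):
        # The real command downloads weights and then runs GPTQ quantization.
        typer.echo(f"quantize {model_id} -> {output_dir} (trust_remote_code={trust_remote_code})")

    if __name__ == "__main__":
        app()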
@@ -304,11 +304,14 @@ class QuantLinearFunction(torch.autograd.Function):
 class QuantLinear(nn.Module):
     def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
         super().__init__()
-        self.qweight = qweight
-        self.qzeros = qzeros
-        self.scales = scales
-        self.g_idx = g_idx
-        self.bias = bias
+        self.qweight = self.register_buffer("qweight", qweight)
+        self.qzeros = self.register_buffer("qzeros", qzeros)
+        self.scales = self.register_buffer("scales", scales)
+        self.g_idx = self.register_buffer("g_idx", g_idx)
+        if bias is not None:
+            self.bias = self.register_buffer("bias", bias)
+        else:
+            self.bias = None
         if bits not in [2, 4, 8]:
             raise NotImplementedError("Only 2,4,8 bits are supported.")
         self.bits = bits
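A note on the buffer registration above: nn.Module.register_buffer returns None, so assigning its result back to the attribute does not keep a reference to the tensor. The conventional PyTorch idiom, shown as a minimal sketch (not the repository's final QuantLinear), is to call register_buffer for its side effect and read the buffer back through the attribute it creates:

    from typing import Optional

    import torch
    from torch import nn

    class QuantLinearSketch(nn.Module):
        def __init__(self, qweight: torch.Tensor, bias: Optional[torch.Tensor] = None):
            super().__init__()
            # register_buffer stores the tensor in state_dict and moves it with
            # .to()/.cuda(), but it returns None, so its result is not assigned.
            self.register_buffer("qweight", qweight)
            if bias is not None:
                self.register_buffer("bias", bias)
            else:
                self.bias = None
            # The registered buffer is then available as a plain attribute:
            assert self.qweight is qweight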
@@ -937,9 +937,9 @@ def pack(model, quantizers, bits, groupsize):
     # print('max memory(MiB):', max_memory)
 
 
-def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
+def quantize(model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool):
     print("loading model")
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0")
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="balanced_low_0", trust_remote_code=trust_remote_code)
     print("LOADED model")
     model.seqlen = 2048
 
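trust_remote_code is a standard transformers from_pretrained argument: it opts in to executing modeling code that ships with the checkpoint on the Hub (custom architectures) instead of only built-in model classes. A small illustration with a hypothetical model id:

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "some-org/custom-model",   # hypothetical repo that ships its own modeling code
        torch_dtype=torch.float16,
        trust_remote_code=True,    # required for Hub-provided modeling code to load
    )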
@@ -1002,8 +1002,8 @@ def quantize(model_id: str, bits: int, groupsize: int, output_dir: str):
     from transformers.modeling_utils import shard_checkpoint
     state_dict = model.state_dict()
     state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
-    state_dict["gptq_bits"] = torch.LongTensor(bits)
-    state_dict["gptq_groupsize"] = torch.LongTensor(groupsize)
+    state_dict["gptq_bits"] = torch.LongTensor([bits])
+    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
 
     max_shard_size = "10GB"
     shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors")
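The list wrapping matters because the legacy torch.LongTensor constructor interprets a bare integer as a size, not a value:

    import torch

    bits = 4
    size_only = torch.LongTensor(bits)   # uninitialized tensor with 4 elements
    value = torch.LongTensor([bits])     # tensor([4]) -- the actual value, as stored here
    print(size_only.shape, value)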