Fix version compatibility issue

This commit is contained in:
Cyril Vallez 2025-01-17 17:04:56 +00:00
parent 42ae6dea02
commit f01014de37
No known key found for this signature in database

View File

@@ -956,15 +956,22 @@ def quantize(
pack(model, quantizers, bits, groupsize)
from safetensors.torch import save_file
from transformers.modeling_utils import shard_checkpoint
from huggingface_hub import split_torch_state_dict_into_shards
state_dict = model.state_dict()
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
max_shard_size = "10GB"
shards, index = shard_checkpoint(
state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern="model.safetensors", max_shard_size=max_shard_size,
)
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
shards = state_dict_split.filename_to_tensors
os.makedirs(output_dir, exist_ok=True)
for shard_file, shard in shards.items():
save_file(