mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix compatibility version issue
This commit is contained in:
parent
42ae6dea02
commit
f01014de37
@ -956,15 +956,22 @@ def quantize(
|
|||||||
|
|
||||||
pack(model, quantizers, bits, groupsize)
|
pack(model, quantizers, bits, groupsize)
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
from transformers.modeling_utils import shard_checkpoint
|
from huggingface_hub import split_torch_state_dict_into_shards
|
||||||
|
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
||||||
|
|
||||||
max_shard_size = "10GB"
|
max_shard_size = "10GB"
|
||||||
shards, index = shard_checkpoint(
|
state_dict_split = split_torch_state_dict_into_shards(
|
||||||
state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
|
state_dict, filename_pattern="model.safetensors", max_shard_size=max_shard_size,
|
||||||
)
|
)
|
||||||
|
index = None
|
||||||
|
if state_dict_split.is_sharded:
|
||||||
|
index = {
|
||||||
|
"metadata": state_dict_split.metadata,
|
||||||
|
"weight_map": state_dict_split.tensor_to_filename,
|
||||||
|
}
|
||||||
|
shards = state_dict_split.filename_to_tensors
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
for shard_file, shard in shards.items():
|
for shard_file, shard in shards.items():
|
||||||
save_file(
|
save_file(
|
||||||
|
Loading…
Reference in New Issue
Block a user