Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 11:54:52 +00:00
Changing convert logic.

Should be more robust to shared tensors (OK when using `from_pretrained`), but it forces us to add new checks in our loading code, since the chosen key to keep might differ from the one `transformers` keeps.
parent c9c65ab323
commit 34eadb54e9
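For context, the "shared tensors" the new logic handles are state-dict entries that alias the same underlying storage (typically tied weights), which `safetensors` refuses to serialize as-is. A minimal sketch of the situation, using a hypothetical toy model rather than anything from this repo:

import torch
from torch import nn

class TiedLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # weight tying

state_dict = TiedLM().state_dict()
# Both keys alias one storage; safetensors' save_file() rejects such a
# state dict, so the converter must drop all but one name per storage.
assert (
    state_dict["embed.weight"].data_ptr()
    == state_dict["lm_head.weight"].data_ptr()
)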
@@ -1,76 +1,45 @@
 import datetime
 import torch
+import os
 
-from collections import defaultdict
 from loguru import logger
 from pathlib import Path
-from safetensors.torch import save_file
-from safetensors import safe_open
-from typing import Dict, List
+from safetensors.torch import save_file, _remove_duplicate_names, load_file
+from typing import List
 
 
-def check_file_size(source_file: Path, target_file: Path):
-    """
-    Check that two files are close in size
-    """
-    source_file_size = source_file.stat().st_size
-    target_file_size = target_file.stat().st_size
-
-    if (source_file_size - target_file_size) / source_file_size > 0.05:
-        raise RuntimeError(
-            f"""The file size difference is more than 5%:
-         - {source_file}: {source_file_size}
-         - {target_file}: {target_file_size}
-         """
-        )
-
-
-def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
-    """
-    For a Dict of tensors, check if two or more tensors point to the same underlying memory and
-    remove them
-    """
-    ptrs = defaultdict(list)
-    for k, v in tensors.items():
-        ptrs[v.data_ptr()].append(k)
-
-    # Iterate over all found memory addresses
-    for ptr, names in ptrs.items():
-        if len(names) > 1:
-            # Multiple tensors point to the same memory
-            # Only keep the first tensor
-            for name in names[1:]:
-                tensors.pop(name)
-
-
 def convert_file(pt_file: Path, sf_file: Path):
     """
     Convert a pytorch file to a safetensors file
+    This will remove duplicate tensors from the file.
+
+    Unfortunately, this might not respect *transformers* convention,
+    forcing us to check for potentially different keys during load when looking
+    for specific tensors (making tensor sharing explicit).
     """
-    logger.info(f"Convert {pt_file} to {sf_file}.")
-
-    pt_state = torch.load(pt_file, map_location="cpu")
-    if "state_dict" in pt_state:
-        pt_state = pt_state["state_dict"]
-
-    remove_shared_pointers(pt_state)
-
-    # Tensors need to be contiguous
-    pt_state = {k: v.contiguous() for k, v in pt_state.items()}
-
-    sf_file.parent.mkdir(parents=True, exist_ok=True)
-    save_file(pt_state, str(sf_file), metadata={"format": "pt"})
-
-    # Check that both files are close in size
-    check_file_size(pt_file, sf_file)
-
-    # Load safetensors state
-    for k in pt_state:
-        pt_tensor = pt_state[k]
-        with safe_open(sf_file, framework="pt") as f:
-            sf_tensor = f.get_tensor(k)
-            if not torch.equal(pt_tensor, sf_tensor):
-                raise RuntimeError(f"The output tensors do not match for key {k}")
+    loaded = torch.load(pt_file, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+    to_removes = _remove_duplicate_names(loaded)
+
+    metadata = {"format": "pt"}
+    for kept_name, to_remove_group in to_removes.items():
+        for to_remove in to_remove_group:
+            if to_remove not in metadata:
+                metadata[to_remove] = kept_name
+            del loaded[to_remove]
+    # Force tensors to be contiguous
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+    dirname = os.path.dirname(sf_file)
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_file, metadata=metadata)
+    reloaded = load_file(sf_file)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
 def convert_files(pt_files: List[Path], sf_files: List[Path]):
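The load-side consequence mentioned in the commit message: `_remove_duplicate_names` may keep a different alias than `transformers` would have, so `convert_file` records each removed-name-to-kept-name pair in the file metadata. A hedged sketch of a loader-side lookup (hypothetical helper, not the actual TGI loading code) that follows that mapping:

from safetensors import safe_open

def get_tensor_resolving_aliases(sf_file: str, name: str):
    # Open lazily; safe_open exposes keys(), metadata() and get_tensor().
    with safe_open(sf_file, framework="pt") as f:
        keys = set(f.keys())
        metadata = f.metadata() or {}
        # If the requested key was deduplicated away at convert time,
        # the metadata points to the name that was kept instead.
        if name not in keys and name in metadata:
            name = metadata[name]
        if name not in keys:
            raise KeyError(f"{name} not found in {sf_file}")
        return f.get_tensor(name)

Storing the alias map in metadata keeps tensor sharing explicit: the file holds each storage exactly once, and callers decide how to re-tie weights after loading.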