import datetime
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

import torch
from loguru import logger
from safetensors import safe_open
from safetensors.torch import save_file


def check_file_size(source_file: Path, target_file: Path):
    """
    Check that two files are close in size (within 1% of the source file)
    """
    source_file_size = source_file.stat().st_size
    target_file_size = target_file.stat().st_size

    if abs(source_file_size - target_file_size) / source_file_size > 0.01:
        raise RuntimeError(
            f"""The file size difference is more than 1%:
         - {source_file}: {source_file_size}
         - {target_file}: {target_file_size}
         """
        )


def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
    """
    For a dict of tensors, find groups of tensors that point to the same
    underlying memory and keep only the first tensor of each group
    """
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)

    # Iterate over all found memory addresses
    for ptr, names in ptrs.items():
        if len(names) > 1:
            # Multiple tensors point to the same memory
            # Only keep the first tensor
            for name in names[1:]:
                tensors.pop(name)


def convert_file(pt_file: Path, sf_file: Path):
    """
    Convert a PyTorch checkpoint file to a safetensors file
    """
    logger.info(f"Convert {pt_file} to {sf_file}.")

    pt_state = torch.load(pt_file, map_location="cpu")
    if "state_dict" in pt_state:
        pt_state = pt_state["state_dict"]

    remove_shared_pointers(pt_state)

    # Tensors need to be contiguous before they can be serialized
    pt_state = {k: v.contiguous() for k, v in pt_state.items()}

    sf_file.parent.mkdir(parents=True, exist_ok=True)
    save_file(pt_state, str(sf_file), metadata={"format": "pt"})

    # Check that both files are close in size
    check_file_size(pt_file, sf_file)

    # Reload the safetensors file and check that every tensor matches
    # the original state dict
    with safe_open(str(sf_file), framework="pt") as f:
        for k, pt_tensor in pt_state.items():
            sf_tensor = f.get_tensor(k)
            if not torch.equal(pt_tensor, sf_tensor):
                raise RuntimeError(f"The output tensors do not match for key {k}")


def convert_files(pt_files: List[Path], sf_files: List[Path]):
    assert len(pt_files) == len(sf_files)

    N = len(pt_files)
    # We do this instead of using tqdm because we want to parse the logs with the launcher
    for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
        start = datetime.datetime.now()
        convert_file(pt_file, sf_file)
        elapsed = datetime.datetime.now() - start
        logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")
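
# Minimal usage sketch (illustrative, not part of the original script): pair
# each ".bin" checkpoint with a sibling ".safetensors" target and convert them
# in one pass. The "checkpoints" directory and the "*.bin" naming scheme are
# assumptions for demonstration only.
if __name__ == "__main__":
    pt_files = sorted(Path("checkpoints").glob("*.bin"))  # hypothetical layout
    sf_files = [p.with_suffix(".safetensors") for p in pt_files]
    convert_files(pt_files, sf_files)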