From 34eadb54e9b54fc207694c87dad38e91d09cf486 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 22 Jun 2023 14:52:47 +0200
Subject: [PATCH] Changing convert logic.

Should be more robust to shared tensors (ok when using `from_pretrained`).
But forcing us to add new checks in our loading code (since the chosen key
to keep might be different from `transformers`).
---
 .../text_generation_server/utils/convert.py | 89 ++++++-------------
 1 file changed, 29 insertions(+), 60 deletions(-)

diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py
index c4e79432..0e4adaba 100644
--- a/server/text_generation_server/utils/convert.py
+++ b/server/text_generation_server/utils/convert.py
@@ -1,76 +1,45 @@
 import datetime
 import torch
+import os
 
-from collections import defaultdict
 from loguru import logger
 from pathlib import Path
-from safetensors.torch import save_file
-from safetensors import safe_open
-from typing import Dict, List
-
-
-def check_file_size(source_file: Path, target_file: Path):
-    """
-    Check that two files are close in size
-    """
-    source_file_size = source_file.stat().st_size
-    target_file_size = target_file.stat().st_size
-
-    if (source_file_size - target_file_size) / source_file_size > 0.05:
-        raise RuntimeError(
-            f"""The file size different is more than 5%:
-         - {source_file}: {source_file_size}
-         - {target_file}: {target_file_size}
-         """
-        )
-
-
-def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
-    """
-    For a Dict of tensors, check if two or more tensors point to the same underlying memory and
-    remove them
-    """
-    ptrs = defaultdict(list)
-    for k, v in tensors.items():
-        ptrs[v.data_ptr()].append(k)
-
-    # Iterate over all found memory addresses
-    for ptr, names in ptrs.items():
-        if len(names) > 1:
-            # Multiple tensors are point to the same memory
-            # Only keep the first tensor
-            for name in names[1:]:
-                tensors.pop(name)
+from safetensors.torch import save_file, _remove_duplicate_names, load_file
+from typing import List
 
 
 def convert_file(pt_file: Path, sf_file: Path):
     """
     Convert a pytorch file to a safetensors file
+    This will remove duplicate tensors from the file.
+
+    Unfortunately, this might not respect *transformers* convention.
+    Forcing us to check for potentially different keys during load when looking
+    for specific tensors (making tensor sharing explicit).
     """
-    logger.info(f"Convert {pt_file} to {sf_file}.")
+    loaded = torch.load(pt_file, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+    to_removes = _remove_duplicate_names(loaded)
 
-    pt_state = torch.load(pt_file, map_location="cpu")
-    if "state_dict" in pt_state:
-        pt_state = pt_state["state_dict"]
+    metadata = {"format": "pt"}
+    for kept_name, to_remove_group in to_removes.items():
+        for to_remove in to_remove_group:
+            if to_remove not in metadata:
+                metadata[to_remove] = kept_name
+            del loaded[to_remove]
+    # Force tensors to be contiguous
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
 
-    remove_shared_pointers(pt_state)
-
-    # Tensors need to be contiguous
-    pt_state = {k: v.contiguous() for k, v in pt_state.items()}
-
-    sf_file.parent.mkdir(parents=True, exist_ok=True)
-    save_file(pt_state, str(sf_file), metadata={"format": "pt"})
-
-    # Check that both files are close in size
-    check_file_size(pt_file, sf_file)
-
-    # Load safetensors state
-    for k in pt_state:
-        pt_tensor = pt_state[k]
-        with safe_open(sf_file, framework="pt") as f:
-            sf_tensor = f.get_tensor(k)
-            if not torch.equal(pt_tensor, sf_tensor):
-                raise RuntimeError(f"The output tensors do not match for key {k}")
+    dirname = os.path.dirname(sf_file)
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_file, metadata=metadata)
+    reloaded = load_file(sf_file)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
 def convert_files(pt_files: List[Path], sf_files: List[Path]):