diff --git a/server/tests/utils/test_convert.py b/server/tests/utils/test_convert.py
index 7dfe6a1e..ba6c5702 100644
--- a/server/tests/utils/test_convert.py
+++ b/server/tests/utils/test_convert.py
@@ -14,7 +14,7 @@ def test_convert_files():
     local_st_files = [
         p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
     ]
-    convert_files(local_pt_files, local_st_files)
+    convert_files(local_pt_files, local_st_files, discard_names=[])
 
     found_st_files = weight_files(model_id)
 
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 3463049a..7a55e919 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -160,8 +160,31 @@ def download_weights(
             p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
             for p in local_pt_files
         ]
+        try:
+            from transformers import AutoConfig
+            import transformers
+
+            config = AutoConfig.from_pretrained(
+                model_id,
+                revision=revision,
+            )
+            architecture = config.architectures[0]
+
+            class_ = getattr(transformers, architecture)
+
+            # Name for this variable depends on transformers version.
+            # Copy into a new list (tolerating a None attribute) so the class
+            # attribute itself is never mutated by the extend below.
+            discard_names = list(getattr(class_, "_tied_weights_keys", None) or [])
+            discard_names.extend(
+                getattr(class_, "_keys_to_ignore_on_load_missing", None) or []
+            )
+
+        except Exception:
+            # Best effort: convert without discard hints.
+            discard_names = []
         # Convert pytorch weights to safetensors
-        utils.convert_files(local_pt_files, local_st_files)
+        utils.convert_files(local_pt_files, local_st_files, discard_names)
 
 
 @app.command()
diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py
index 0e4adaba..305263ba 100644
--- a/server/text_generation_server/utils/convert.py
+++ b/server/text_generation_server/utils/convert.py
@@ -4,11 +4,62 @@ import os
 
 from loguru import logger
 from pathlib import Path
-from safetensors.torch import save_file, _remove_duplicate_names, load_file
-from typing import List
+from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
+from typing import List, Dict, Optional
+from collections import defaultdict
 
 
-def convert_file(pt_file: Path, sf_file: Path):
+def _remove_duplicate_names(
+    state_dict: Dict[str, torch.Tensor],
+    *,
+    preferred_names: Optional[List[str]] = None,
+    discard_names: Optional[List[str]] = None,
+) -> Dict[str, List[str]]:
+    """
+    Choose which name to keep for each group of tensors sharing storage.
+
+    Returns a mapping from the kept name to the other names of its group.
+    Names in `discard_names` are avoided as the kept name whenever another
+    complete name exists; names in `preferred_names` take precedence over
+    that choice. Raises RuntimeError if no name in a group is complete.
+    """
+    preferred_names = set(preferred_names or [])
+    discard_names = set(discard_names or [])
+
+    shareds = _find_shared_tensors(state_dict)
+    to_remove = defaultdict(list)
+    for shared in shareds:
+        complete_names = {
+            name for name in shared if _is_complete(state_dict[name])
+        }
+        if not complete_names:
+            raise RuntimeError(
+                f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+            )
+
+        # Default: the alphabetically first complete name.
+        keep_name = min(complete_names)
+
+        # Mechanism to preferentially select keys to keep
+        # coming from the on-disk file to allow
+        # loading models saved with a different choice
+        # of keep_name
+        preferred = complete_names.difference(discard_names)
+        if preferred:
+            keep_name = min(preferred)
+
+        # Explicitly preferred names override the choice above.
+        preferred = preferred_names.intersection(complete_names)
+        if preferred:
+            keep_name = min(preferred)
+
+        for name in sorted(shared):
+            if name != keep_name:
+                to_remove[keep_name].append(name)
+    return to_remove
+
+
+def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
     """
     Convert a pytorch file to a safetensors file
     This will remove duplicate tensors from the file.
@@ -20,7 +71,7 @@ def convert_file(pt_file: Path, sf_file: Path):
     loaded = torch.load(pt_file, map_location="cpu")
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
-    to_removes = _remove_duplicate_names(loaded)
+    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
 
     metadata = {"format": "pt"}
     for kept_name, to_remove_group in to_removes.items():
@@ -42,7 +93,7 @@ def convert_file(pt_file: Path, sf_file: Path):
             raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
-def convert_files(pt_files: List[Path], sf_files: List[Path]):
+def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
     assert len(pt_files) == len(sf_files)
 
     N = len(pt_files)
@@ -50,6 +101,6 @@ def convert_files(pt_files: List[Path], sf_files: List[Path]):
 
     for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
         start = datetime.datetime.now()
-        convert_file(pt_file, sf_file)
+        convert_file(pt_file, sf_file, discard_names)
         elapsed = datetime.datetime.now() - start
         logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")