diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py index 1bdedf58b..caf1a764a 100644 --- a/server/text_generation_server/utils/convert.py +++ b/server/text_generation_server/utils/convert.py @@ -9,6 +9,7 @@ from datetime import timedelta from loguru import logger from pathlib import Path from safetensors.torch import load_file, save_file +from safetensors import safe_open from typing import Dict, List @@ -46,11 +47,11 @@ def remove_shared_pointers(tensors: Dict[str, torch.Tensor]): tensors.pop(name) -def convert_file(pt_file: Path, st_file: Path): +def convert_file(pt_file: Path, sf_file: Path): """ Convert a pytorch file to a safetensors file """ - logger.info(f"Convert {pt_file} to {st_file}.") + logger.info(f"Convert {pt_file} to {sf_file}.") pt_state = torch.load(pt_file, map_location="cpu") if "state_dict" in pt_state: @@ -61,28 +62,28 @@ def convert_file(pt_file: Path, st_file: Path): # Tensors need to be contiguous pt_state = {k: v.contiguous() for k, v in pt_state.items()} - st_file.parent.mkdir(parents=True, exist_ok=True) - save_file(pt_state, str(st_file), metadata={"format": "pt"}) + sf_file.parent.mkdir(parents=True, exist_ok=True) + save_file(pt_state, str(sf_file), metadata={"format": "pt"}) # Check that both files are close in size - check_file_size(pt_file, st_file) + check_file_size(pt_file, sf_file) # Load safetensors state - st_state = load_file(str(st_file)) - for k in st_state: + for k in pt_state: pt_tensor = pt_state[k] - st_tensor = st_state[k] - if not torch.equal(pt_tensor, st_tensor): - raise RuntimeError(f"The output tensors do not match for key {k}") + with safe_open(sf_file, framework="pt") as f: + sf_tensor = f.get_tensor(k) + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") -def convert_files(pt_files: List[Path], st_files: List[Path]): - assert len(pt_files) == len(st_files) +def convert_files(pt_files: List[Path], sf_files: List[Path]): + assert len(pt_files) == len(sf_files) N = len(pt_files) # We do this instead of using tqdm because we want to parse the logs with the launcher - - for i, (pt_file, sf_file) in enumerate(zip(pt_files, st_files)): + + for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)): start = datetime.datetime.now() convert_file(pt_file, sf_file) elapsed = datetime.datetime.now() - start