feat(server): decrease convert RAM requirements (#286)

This commit is contained in:
Nicolas Patry 2023-05-05 17:57:02 +02:00 committed by GitHub
parent 3314a46d36
commit b4aa87db58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -9,6 +9,7 @@ from datetime import timedelta
from loguru import logger from loguru import logger
from pathlib import Path from pathlib import Path
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from safetensors import safe_open
from typing import Dict, List from typing import Dict, List
@ -46,11 +47,11 @@ def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
tensors.pop(name) tensors.pop(name)
def convert_file(pt_file: Path, st_file: Path): def convert_file(pt_file: Path, sf_file: Path):
""" """
Convert a pytorch file to a safetensors file Convert a pytorch file to a safetensors file
""" """
logger.info(f"Convert {pt_file} to {st_file}.") logger.info(f"Convert {pt_file} to {sf_file}.")
pt_state = torch.load(pt_file, map_location="cpu") pt_state = torch.load(pt_file, map_location="cpu")
if "state_dict" in pt_state: if "state_dict" in pt_state:
@ -61,28 +62,28 @@ def convert_file(pt_file: Path, st_file: Path):
# Tensors need to be contiguous # Tensors need to be contiguous
pt_state = {k: v.contiguous() for k, v in pt_state.items()} pt_state = {k: v.contiguous() for k, v in pt_state.items()}
st_file.parent.mkdir(parents=True, exist_ok=True) sf_file.parent.mkdir(parents=True, exist_ok=True)
save_file(pt_state, str(st_file), metadata={"format": "pt"}) save_file(pt_state, str(sf_file), metadata={"format": "pt"})
# Check that both files are close in size # Check that both files are close in size
check_file_size(pt_file, st_file) check_file_size(pt_file, sf_file)
# Load safetensors state # Load safetensors state
st_state = load_file(str(st_file)) for k in pt_state:
for k in st_state:
pt_tensor = pt_state[k] pt_tensor = pt_state[k]
st_tensor = st_state[k] with safe_open(sf_file, framework="pt") as f:
if not torch.equal(pt_tensor, st_tensor): sf_tensor = f.get_tensor(k)
raise RuntimeError(f"The output tensors do not match for key {k}") if not torch.equal(pt_tensor, sf_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")
def convert_files(pt_files: List[Path], st_files: List[Path]): def convert_files(pt_files: List[Path], sf_files: List[Path]):
assert len(pt_files) == len(st_files) assert len(pt_files) == len(sf_files)
N = len(pt_files) N = len(pt_files)
# We do this instead of using tqdm because we want to parse the logs with the launcher # We do this instead of using tqdm because we want to parse the logs with the launcher
for i, (pt_file, sf_file) in enumerate(zip(pt_files, st_files)): for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
start = datetime.datetime.now() start = datetime.datetime.now()
convert_file(pt_file, sf_file) convert_file(pt_file, sf_file)
elapsed = datetime.datetime.now() - start elapsed = datetime.datetime.now() - start