Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-19 22:02:06 +00:00
feat(server): decrease convert RAM requirements (#286)
This commit is contained in:
parent 3314a46d36
commit b4aa87db58
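
In short: the post-conversion sanity check used to reload the entire safetensors file with load_file, holding a second full copy of the state dict in RAM next to pt_state; it now reads one tensor at a time through safe_open. The st_file/st_files parameters are renamed to sf_file/sf_files along the way.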
server/text_generation_server/utils/convert.py

@@ -9,6 +9,7 @@ from datetime import timedelta
 from loguru import logger
 from pathlib import Path
 from safetensors.torch import load_file, save_file
+from safetensors import safe_open
 from typing import Dict, List
 
 
@@ -46,11 +47,11 @@ def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
                 tensors.pop(name)
 
 
-def convert_file(pt_file: Path, st_file: Path):
+def convert_file(pt_file: Path, sf_file: Path):
     """
     Convert a pytorch file to a safetensors file
     """
-    logger.info(f"Convert {pt_file} to {st_file}.")
+    logger.info(f"Convert {pt_file} to {sf_file}.")
 
     pt_state = torch.load(pt_file, map_location="cpu")
     if "state_dict" in pt_state:
@@ -61,28 +62,28 @@ def convert_file(pt_file: Path, st_file: Path):
     # Tensors need to be contiguous
     pt_state = {k: v.contiguous() for k, v in pt_state.items()}
 
-    st_file.parent.mkdir(parents=True, exist_ok=True)
-    save_file(pt_state, str(st_file), metadata={"format": "pt"})
+    sf_file.parent.mkdir(parents=True, exist_ok=True)
+    save_file(pt_state, str(sf_file), metadata={"format": "pt"})
 
     # Check that both files are close in size
-    check_file_size(pt_file, st_file)
+    check_file_size(pt_file, sf_file)
 
     # Load safetensors state
-    st_state = load_file(str(st_file))
-    for k in st_state:
+    for k in pt_state:
         pt_tensor = pt_state[k]
-        st_tensor = st_state[k]
-        if not torch.equal(pt_tensor, st_tensor):
-            raise RuntimeError(f"The output tensors do not match for key {k}")
+        with safe_open(sf_file, framework="pt") as f:
+            sf_tensor = f.get_tensor(k)
+            if not torch.equal(pt_tensor, sf_tensor):
+                raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
-def convert_files(pt_files: List[Path], st_files: List[Path]):
-    assert len(pt_files) == len(st_files)
+def convert_files(pt_files: List[Path], sf_files: List[Path]):
+    assert len(pt_files) == len(sf_files)
 
     N = len(pt_files)
     # We do this instead of using tqdm because we want to parse the logs with the launcher
-    for i, (pt_file, sf_file) in enumerate(zip(pt_files, st_files)):
+    for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
         start = datetime.datetime.now()
         convert_file(pt_file, sf_file)
         elapsed = datetime.datetime.now() - start
         logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")
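
For context, a minimal runnable sketch of the pattern the new check relies on; the file path and toy tensors here are hypothetical and only for demonstration. safe_open memory-maps the safetensors file and get_tensor(k) materializes a single tensor on demand, whereas load_file materializes the whole state dict at once.

    import torch
    from pathlib import Path
    from safetensors import safe_open
    from safetensors.torch import save_file

    # Hypothetical demo file and toy state dict, not part of the commit.
    sf_file = Path("/tmp/demo.safetensors")
    pt_state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

    save_file(pt_state, str(sf_file), metadata={"format": "pt"})

    # Verify tensor by tensor: get_tensor(k) loads a single tensor from the
    # memory-mapped file, so peak RAM stays close to one tensor instead of
    # a whole second copy of the state dict.
    with safe_open(str(sf_file), framework="pt") as f:
        for k in pt_state:
            if not torch.equal(pt_state[k], f.get_tensor(k)):
                raise RuntimeError(f"The output tensors do not match for key {k}")

The commit opens the file inside the loop rather than hoisting the context manager as above; either way only one safetensors tensor is resident at a time, which is where the RAM saving over the old load_file path comes from.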