text-generation-inference/server/tests/utils/test_convert.py
Nicolas Patry e943a294bc
fix(server): harden the weights choice to save on disk. (#561)
- Look at the `transformers` base class for
  `_keys_to_ignore_on_load_missing` or `_tied_weights_keys`, the standard
  attributes that name the keys NOT to save on disk (since they are
  ignored)

- Modified the safetensors code (the change should be reflected in
  safetensors itself, even though it is an internal function).

- Will not work for trust_remote_code=True repos (like santacoder).

Should help with:
https://github.com/huggingface/text-generation-inference/issues/555,
https://github.com/huggingface/text-generation-inference/pull/501,
https://github.com/huggingface/text-generation-inference/issues/556, and
https://github.com/huggingface/text-generation-inference/issues/482#issuecomment-1623713593
2023-07-07 14:50:12 +02:00


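from text_generation_server.utils.hub import (
    download_weights,
    weight_hub_files,
    weight_files,
)
from text_generation_server.utils.convert import convert_files


def test_convert_files():
    model_id = "bigscience/bloom-560m"
    # Locate the PyTorch (.bin) weights on the Hub and download them locally.
    pt_filenames = weight_hub_files(model_id, extension=".bin")
    local_pt_files = download_weights(pt_filenames, model_id)
    # Derive the target safetensors path next to each .bin file.
    # removeprefix (Python 3.9+) drops the literal "pytorch_" prefix, whereas
    # lstrip would strip any leading run of those characters.
    local_st_files = [
        p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
        for p in local_pt_files
    ]
    convert_files(local_pt_files, local_st_files, discard_names=[])

    # The converted files should now be discoverable as local safetensors weights.
    found_st_files = weight_files(model_id)

    assert all(p in found_st_files for p in local_st_files)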
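
The test derives each safetensors filename by dropping the `pytorch_` prefix from the `.bin` file's stem. As a minimal sketch of that mapping in isolation (the helper name `safetensors_name` is hypothetical and not part of `text_generation_server`), the snippet below uses `str.removeprefix`, available from Python 3.9, and contrasts it with `str.lstrip`, which strips a set of characters rather than a literal prefix:

```python
from pathlib import Path


def safetensors_name(pt_file: Path) -> Path:
    """Hypothetical helper: map pytorch_model*.bin to model*.safetensors."""
    # removeprefix only drops the literal "pytorch_" prefix, if present.
    stem = pt_file.stem.removeprefix("pytorch_")
    return pt_file.parent / f"{stem}.safetensors"


# Single-file and sharded checkpoints both keep the rest of the stem intact.
print(safetensors_name(Path("weights/pytorch_model.bin")).name)
# model.safetensors
print(safetensors_name(Path("weights/pytorch_model-00001-of-00002.bin")).name)
# model-00001-of-00002.safetensors

# lstrip treats its argument as a character set, so it can strip too much
# when the remainder of the stem starts with one of those characters.
print(repr("pytorch_torch".lstrip("pytorch_")))        # ''
print(repr("pytorch_torch".removeprefix("pytorch_")))  # 'torch'
```

For the usual `pytorch_model*.bin` names both approaches happen to agree, because `m` is not among the stripped characters; the difference only shows up for stems whose remainder begins with a character from `pytorch_`.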