mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 14:22:08 +00:00
- Look at `transformers` base class to check for `_key_to_ignore_on_load_missing` or `_tied_weights` which are the standard attributes to select the keys to NOT save on disk (since they are ignored) - Modified safetensors code (to be reflected in safetensors even if it's an internal function). - Will not work for trust_remote_code=True repos (like santacoder). Should help with : https://github.com/huggingface/text-generation-inference/issues/555 and : https://github.com/huggingface/text-generation-inference/pull/501 and https://github.com/huggingface/text-generation-inference/issues/556 and https://github.com/huggingface/text-generation-inference/issues/482#issuecomment-1623713593
22 lines
659 B
Python
from text_generation_server.utils.hub import (
|
|
download_weights,
|
|
weight_hub_files,
|
|
weight_files,
|
|
)
|
|
|
|
from text_generation_server.utils.convert import convert_files
|
|
|
|
|
|
def test_convert_files():
    """End-to-end check that PyTorch ``.bin`` weights convert to safetensors.

    Downloads the ``.bin`` shards for a small model, converts each to a
    ``.safetensors`` file alongside it, and asserts the converted files are
    discovered by ``weight_files``. Requires network access to the HF Hub.
    """
    model_id = "bigscience/bloom-560m"
    pt_filenames = weight_hub_files(model_id, extension=".bin")
    local_pt_files = download_weights(pt_filenames, model_id)
    # Map "pytorch_model*.bin" -> "model*.safetensors".
    # NOTE: str.removeprefix (3.9+) strips the literal prefix; the earlier
    # str.lstrip('pytorch_') stripped any leading chars from that SET, which
    # would corrupt stems like "training_args".
    local_st_files = [
        p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
        for p in local_pt_files
    ]
    convert_files(local_pt_files, local_st_files, discard_names=[])

    found_st_files = weight_files(model_id)

    # Every converted file must be visible to the weight-discovery helper.
    assert all(p in found_st_files for p in local_st_files)