fix discard_names in safetensors convertion

This commit is contained in:
zhangsibo1129 2023-09-25 10:43:28 +08:00
parent 123749a3c9
commit 649d9754b1

View File

@ -186,7 +186,10 @@ def download_weights(
class_ = getattr(transformers, architecture) class_ = getattr(transformers, architecture)
# Name for this varible depends on transformers version. # Name for this varible depends on transformers version.
discard_names = getattr(class_, "_tied_weights_keys", []) discard_names = []
if getattr(class_, "_tied_weights_keys", []):
discard_names.extend(getattr(class_, "_tied_weights_keys", []))
if getattr(class_, "_keys_to_ignore_on_load_missing", []):
discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", [])) discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
except Exception as e: except Exception as e: