Simpler fix.

This commit is contained in:
Nicolas Patry 2023-09-26 13:03:45 +00:00
parent 649d9754b1
commit 9c0f679d1d

View File

@ -186,11 +186,7 @@ def download_weights(
class_ = getattr(transformers, architecture)
# Name for this varible depends on transformers version.
discard_names = []
if getattr(class_, "_tied_weights_keys", []):
discard_names.extend(getattr(class_, "_tied_weights_keys", []))
if getattr(class_, "_keys_to_ignore_on_load_missing", []):
discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
discard_names = getattr(class_, "_tied_weights_keys", [])
except Exception as e:
discard_names = []