Tied embeddings for Qwen2

This commit is contained in:
matvey-kolbasov-hs 2024-07-24 11:59:56 +03:00
parent 8642250602
commit f73f57ca21

View File

@@ -959,6 +959,10 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
+                aliases={
+                    "lm_head.weight": ["model.word_embeddings.weight"],
+                    "model.word_embeddings.weight": ["lm_head.weight"],
+                },
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))