fix aliases

matvey-kolbasov-hs 2024-07-24 16:05:38 +03:00
parent eabcb2967a
commit fbb683fce7


@@ -950,7 +950,6 @@ def get_model(
     if model_type == QWEN2:
         if FLASH_ATTENTION:
-            print('!!! aliases !!!')
             return FlashCausalLM(
                 model_id=model_id,
                 model_class=Qwen2ForCausalLM,
@@ -961,8 +960,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 aliases={
-                    "lm_head.weight": ["model.word_embeddings.weight"],
-                    "model.word_embeddings.weight": ["lm_head.weight"],
+                    "lm_head.weight": ["model.embed_tokens.weight"]
                 }
             )
         elif sharded:
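
For context: an aliases mapping like this typically lets a weight loader fall back to an alternative tensor name when the primary key is missing from the checkpoint, which is presumably why pointing "lm_head.weight" at "model.embed_tokens.weight" covers Qwen2 checkpoints with tied embeddings. Below is a minimal, illustrative sketch of that lookup; resolve_weight_name is a hypothetical helper, not TGI's actual loader API.

# Hypothetical sketch: resolve a tensor name, falling back to any
# registered aliases when the primary key is absent from the checkpoint.
from typing import Dict, List


def resolve_weight_name(
    available: set,
    name: str,
    aliases: Dict[str, List[str]],
) -> str:
    """Return `name` if present, otherwise the first alias that exists."""
    if name in available:
        return name
    for alias in aliases.get(name, []):
        if alias in available:
            return alias
    raise KeyError(f"weight {name} not found (tried aliases: {aliases.get(name, [])})")


# With tied embeddings, the checkpoint may only ship the embedding tensor,
# so "lm_head.weight" resolves to "model.embed_tokens.weight".
tensors = {"model.embed_tokens.weight"}
aliases = {"lm_head.weight": ["model.embed_tokens.weight"]}
print(resolve_weight_name(tensors, "lm_head.weight", aliases))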