diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 1cd13a2a..a1ccdb12 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -959,6 +959,10 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + aliases={ + "lm_head.weight": ["model.word_embeddings.weight"], + "model.word_embeddings.weight": ["lm_head.weight"], + } ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))