diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 2fc7c53d..195b3883 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -54,7 +54,10 @@ class FlashRWSharded(FlashCausalLM): device, dtype, process_group=self.process_group, - aliases={"lm_head.weight": ["transformer.word_embeddings.weight"]}, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, ) config.quantize = quantize