diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 12b862d78..55d555fcc 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]}, + ) config.quantize = quantize