From 742199aa0d9b4d4dc410f90572f78e9a5d776265 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 4 Jul 2023 11:30:59 +0200 Subject: [PATCH] Modified fix. --- server/text_generation_server/models/flash_rw.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 5f963bfb..33079ace 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -48,7 +48,13 @@ class FlashRWSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]}, + ) config.quantize = quantize