From 57433201b2339187c277f2cf514822c102eeba5d Mon Sep 17 00:00:00 2001
From: zhangsibo1129
Date: Tue, 26 Sep 2023 17:57:59 +0800
Subject: [PATCH] Fix shared weights load bug and T5 loading

---
 server/text_generation_server/models/t5.py     | 1 +
 server/text_generation_server/utils/weights.py | 6 +++---
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 161e69ba..58be842d 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -63,6 +63,7 @@ class T5Sharded(Seq2SeqLM):
                 "shared.weight": [
                     "encoder.embed_tokens.weight",
                     "decoder.embed_tokens.weight",
+                    "lm_head.weight",
                 ]
             },
         )
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 266fcccb..20dbc23f 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -45,9 +45,9 @@ class Weights:
     def get_filename(self, tensor_name: str) -> (str, str):
         filename = self.routing.get(tensor_name, None)
         if filename is None:
-            aliases = self.aliases.get(tensor_name, [])
-            for alias in aliases:
-                filename = self.routing.get(alias, None)
+            for alias, tensor_list in self.aliases.items():
+                if tensor_name in tensor_list:
+                    filename = self.routing.get(alias, None)
                 if filename is not None:
                     return str(filename), alias
             raise RuntimeError(f"weight {tensor_name} does not exist")