diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 161e69ba..58be842d 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -63,6 +63,7 @@ class T5Sharded(Seq2SeqLM): "shared.weight": [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", + "lm_head.weight", ] }, ) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 266fcccb..20dbc23f 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -45,9 +45,9 @@ class Weights: def get_filename(self, tensor_name: str) -> (str, str): filename = self.routing.get(tensor_name, None) if filename is None: - aliases = self.aliases.get(tensor_name, []) - for alias in aliases: - filename = self.routing.get(alias, None) + for alias, tensor_list in self.aliases.items(): + if tensor_name in tensor_list: + filename = self.routing.get(alias, None) if filename is not None: return str(filename), alias raise RuntimeError(f"weight {tensor_name} does not exist")