From 918bea23cb563d60ca1b06d50b4a3e99488fec63 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Wed, 11 Dec 2024 02:53:20 -0800 Subject: [PATCH] fix facebook/opt-125m not working issue Signed-off-by: Wang, Yi A --- .../models/custom_modeling/opt_modeling.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index bd440321..a6348b5b 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -99,7 +99,7 @@ class OPTLearnedPositionalEmbedding(nn.Module): self.offset = 2 self.weight = nn.Parameter( weights.get_tensor( - f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight" + f"{prefix if prefix else ''}decoder.embed_positions.weight" ) ) @@ -317,7 +317,7 @@ class OPTDecoderLayer(nn.Module): super().__init__() self.process_group = weights.process_group self.hidden_size = config.hidden_size - prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}" + prefix = f"{prefix if prefix else ''}decoder.layers.{layer_id}" self.self_attn = OPTAttention( config, prefix=f"{prefix}.self_attn", @@ -755,6 +755,8 @@ class OPTModel(OPTPreTrainedModel): class OPTForCausalLM(OPTPreTrainedModel): def __init__(self, prefix, config, weights): super().__init__(config) + if not prefix and any(s.startswith("model") for s in weights.routing.keys()): + prefix = "model" self.model = OPTModel(prefix, config, weights)