From 3bb3fd19ae115484878e421649dba4b9549aa42a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 12 Dec 2024 18:20:13 +0100 Subject: [PATCH] Fixup opt to reduce the amount of odd if statements. (#2833) * Fixup opt to reduce the amount of odd if statements. * Fixing cargo lock --- Cargo.lock | 14 +++++++------- .../models/custom_modeling/opt_modeling.py | 11 ++++++++--- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9551ae2d9..f0b756f9b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4367,7 +4367,7 @@ dependencies = [ [[package]] name = "text-generation-backends-trtllm" -version = "3.0.1-dev0" +version = "3.0.2-dev0" dependencies = [ "async-stream", "async-trait", @@ -4391,7 +4391,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "3.0.1-dev0" +version = "3.0.2-dev0" dependencies = [ "average", "clap 4.5.21", @@ -4411,7 +4411,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "3.0.1-dev0" +version = "3.0.2-dev0" dependencies = [ "async-trait", "base64 0.22.1", @@ -4429,7 +4429,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "3.0.1-dev0" +version = "3.0.2-dev0" dependencies = [ "clap 4.5.21", "ctrlc", @@ -4450,7 +4450,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "3.0.1-dev0" +version = "3.0.2-dev0" dependencies = [ "anyhow", "async-stream", @@ -4501,7 +4501,7 @@ dependencies = [ [[package]] name = "text-generation-router-v2" -version = "3.0.1-dev0" +version = "3.0.2-dev0" dependencies = [ "async-stream", "async-trait", @@ -4550,7 +4550,7 @@ dependencies = [ [[package]] name = "text-generation-router-v3" -version = "3.0.1-dev0" +version = "3.0.2-dev0" dependencies = [ "async-stream", "async-trait", diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index a6348b5b9..db73ae84e 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -12,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch OPT model.""" +"""PyTorch OPT model.""" + import random from typing import List, Optional, Tuple, Union @@ -317,7 +318,6 @@ class OPTDecoderLayer(nn.Module): super().__init__() self.process_group = weights.process_group self.hidden_size = config.hidden_size - prefix = f"{prefix if prefix else ''}decoder.layers.{layer_id}" self.self_attn = OPTAttention( config, prefix=f"{prefix}.self_attn", @@ -478,7 +478,12 @@ class OPTDecoder(OPTPreTrainedModel): self.layers = nn.ModuleList( [ - OPTDecoderLayer(layer_id, prefix, config, weights) + OPTDecoderLayer( + layer_id, + prefix=f"{prefix}decoder.layers.{layer_id}", + config=config, + weights=weights, + ) for layer_id in range(config.num_hidden_layers) ] )