Fixup opt to reduce the number of odd if statements. (#2833)

* Fixup opt to reduce the number of odd if statements.

* Fixing cargo lock
This commit is contained in:
Nicolas Patry 2024-12-12 18:20:13 +01:00 committed by GitHub
parent bf59118a93
commit 3bb3fd19ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 10 deletions

14
Cargo.lock generated
View File

@ -4367,7 +4367,7 @@ dependencies = [
[[package]]
name = "text-generation-backends-trtllm"
version = "3.0.1-dev0"
version = "3.0.2-dev0"
dependencies = [
"async-stream",
"async-trait",
@ -4391,7 +4391,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
version = "3.0.1-dev0"
version = "3.0.2-dev0"
dependencies = [
"average",
"clap 4.5.21",
@ -4411,7 +4411,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "3.0.1-dev0"
version = "3.0.2-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
@ -4429,7 +4429,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "3.0.1-dev0"
version = "3.0.2-dev0"
dependencies = [
"clap 4.5.21",
"ctrlc",
@ -4450,7 +4450,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "3.0.1-dev0"
version = "3.0.2-dev0"
dependencies = [
"anyhow",
"async-stream",
@ -4501,7 +4501,7 @@ dependencies = [
[[package]]
name = "text-generation-router-v2"
version = "3.0.1-dev0"
version = "3.0.2-dev0"
dependencies = [
"async-stream",
"async-trait",
@ -4550,7 +4550,7 @@ dependencies = [
[[package]]
name = "text-generation-router-v3"
version = "3.0.1-dev0"
version = "3.0.2-dev0"
dependencies = [
"async-stream",
"async-trait",

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OPT model."""
import random
from typing import List, Optional, Tuple, Union
@ -317,7 +318,6 @@ class OPTDecoderLayer(nn.Module):
super().__init__()
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
prefix = f"{prefix if prefix else ''}decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
@ -478,7 +478,12 @@ class OPTDecoder(OPTPreTrainedModel):
self.layers = nn.ModuleList(
[
OPTDecoderLayer(layer_id, prefix, config, weights)
OPTDecoderLayer(
layer_id,
prefix=f"{prefix}decoder.layers.{layer_id}",
config=config,
weights=weights,
)
for layer_id in range(config.num_hidden_layers)
]
)