Fixup opt to reduce the amount of odd if statements.

This commit is contained in:
Nicolas Patry 2024-12-12 15:03:40 +01:00
parent bf59118a93
commit 8b81f72b0f
No known key found for this signature in database
GPG Key ID: D2920555C90F704C

View File

@ -12,7 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OPT model."""
"""PyTorch OPT model."""
import random
from typing import List, Optional, Tuple, Union
@ -317,7 +318,6 @@ class OPTDecoderLayer(nn.Module):
super().__init__()
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
prefix = f"{prefix if prefix else ''}decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
@ -478,7 +478,12 @@ class OPTDecoder(OPTPreTrainedModel):
self.layers = nn.ModuleList(
[
OPTDecoderLayer(layer_id, prefix, config, weights)
OPTDecoderLayer(
layer_id,
prefix=f"{prefix}decoder.layers.{layer_id}",
config=config,
weights=weights,
)
for layer_id in range(config.num_hidden_layers)
]
)