mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fixup opt to reduce the amount of odd if statements.
This commit is contained in:
parent
bf59118a93
commit
8b81f72b0f
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch OPT model."""
|
"""PyTorch OPT model."""
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@ -317,7 +318,6 @@ class OPTDecoderLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.process_group = weights.process_group
|
self.process_group = weights.process_group
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
prefix = f"{prefix if prefix else ''}decoder.layers.{layer_id}"
|
|
||||||
self.self_attn = OPTAttention(
|
self.self_attn = OPTAttention(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
@ -478,7 +478,12 @@ class OPTDecoder(OPTPreTrainedModel):
|
|||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
OPTDecoderLayer(layer_id, prefix, config, weights)
|
OPTDecoderLayer(
|
||||||
|
layer_id,
|
||||||
|
prefix=f"{prefix}decoder.layers.{layer_id}",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
for layer_id in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user