Fixup opt to reduce the amount of odd if statements. (#2833)

* Fixup opt to reduce the amount of odd if statements.

* Fixing cargo lock
This commit is contained in:
Nicolas Patry 2024-12-12 18:20:13 +01:00 committed by GitHub
parent bf59118a93
commit 3bb3fd19ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 10 deletions

14
Cargo.lock generated
View File

@ -4367,7 +4367,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-backends-trtllm" name = "text-generation-backends-trtllm"
version = "3.0.1-dev0" version = "3.0.2-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
@ -4391,7 +4391,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-benchmark" name = "text-generation-benchmark"
version = "3.0.1-dev0" version = "3.0.2-dev0"
dependencies = [ dependencies = [
"average", "average",
"clap 4.5.21", "clap 4.5.21",
@ -4411,7 +4411,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "3.0.1-dev0" version = "3.0.2-dev0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"base64 0.22.1", "base64 0.22.1",
@ -4429,7 +4429,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "3.0.1-dev0" version = "3.0.2-dev0"
dependencies = [ dependencies = [
"clap 4.5.21", "clap 4.5.21",
"ctrlc", "ctrlc",
@ -4450,7 +4450,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "3.0.1-dev0" version = "3.0.2-dev0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-stream", "async-stream",
@ -4501,7 +4501,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router-v2" name = "text-generation-router-v2"
version = "3.0.1-dev0" version = "3.0.2-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
@ -4550,7 +4550,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router-v3" name = "text-generation-router-v3"
version = "3.0.1-dev0" version = "3.0.2-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",

View File

@ -12,7 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PyTorch OPT model.""" """PyTorch OPT model."""
import random import random
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@ -317,7 +318,6 @@ class OPTDecoderLayer(nn.Module):
super().__init__() super().__init__()
self.process_group = weights.process_group self.process_group = weights.process_group
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
prefix = f"{prefix if prefix else ''}decoder.layers.{layer_id}"
self.self_attn = OPTAttention( self.self_attn = OPTAttention(
config, config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
@ -478,7 +478,12 @@ class OPTDecoder(OPTPreTrainedModel):
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
OPTDecoderLayer(layer_id, prefix, config, weights) OPTDecoderLayer(
layer_id,
prefix=f"{prefix}decoder.layers.{layer_id}",
config=config,
weights=weights,
)
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )