Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-19 13:52:07 +00:00
Fixup opt to reduce the number of odd if statements. (#2833)

* Fixup opt to reduce the number of odd if statements.
* Fixing cargo lock
parent: bf59118a93
commit: 3bb3fd19ae
Cargo.lock (generated): 14 lines changed
@@ -4367,7 +4367,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-backends-trtllm"
-version = "3.0.1-dev0"
+version = "3.0.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4391,7 +4391,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-benchmark"
-version = "3.0.1-dev0"
+version = "3.0.2-dev0"
 dependencies = [
  "average",
  "clap 4.5.21",
@@ -4411,7 +4411,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "3.0.1-dev0"
+version = "3.0.2-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -4429,7 +4429,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "3.0.1-dev0"
+version = "3.0.2-dev0"
 dependencies = [
  "clap 4.5.21",
  "ctrlc",
@@ -4450,7 +4450,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "3.0.1-dev0"
+version = "3.0.2-dev0"
 dependencies = [
  "anyhow",
  "async-stream",
@@ -4501,7 +4501,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router-v2"
-version = "3.0.1-dev0"
+version = "3.0.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4550,7 +4550,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router-v3"
-version = "3.0.1-dev0"
+version = "3.0.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",

@@ -12,7 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch OPT model."""
+"""PyTorch OPT model."""
+
 import random
 from typing import List, Optional, Tuple, Union
@@ -317,7 +318,6 @@ class OPTDecoderLayer(nn.Module):
         super().__init__()
         self.process_group = weights.process_group
         self.hidden_size = config.hidden_size
-        prefix = f"{prefix if prefix else ''}decoder.layers.{layer_id}"
         self.self_attn = OPTAttention(
             config,
             prefix=f"{prefix}.self_attn",
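
The removed line above is the kind of "odd if" the commit title refers to: the layer rebuilt its own weight prefix with an inline conditional to tolerate an empty outer prefix. A rough standalone illustration of the old versus new expression (old_layer_prefix and new_layer_prefix are hypothetical helpers, not code from the repository):

def old_layer_prefix(prefix, layer_id):
    # Removed pattern: the layer itself appends "decoder.layers.<id>",
    # guarding against prefix being None or "".
    return f"{prefix if prefix else ''}decoder.layers.{layer_id}"

def new_layer_prefix(prefix, layer_id):
    # New pattern: the caller always passes a plain string prefix
    # (possibly ""), so no conditional is needed inside the layer.
    return f"{prefix}decoder.layers.{layer_id}"

assert old_layer_prefix(None, 0) == "decoder.layers.0"
assert old_layer_prefix("", 2) == new_layer_prefix("", 2) == "decoder.layers.2"
assert old_layer_prefix("model.", 3) == new_layer_prefix("model.", 3)
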
@@ -478,7 +478,12 @@ class OPTDecoder(OPTPreTrainedModel):
 
         self.layers = nn.ModuleList(
             [
-                OPTDecoderLayer(layer_id, prefix, config, weights)
+                OPTDecoderLayer(
+                    layer_id,
+                    prefix=f"{prefix}decoder.layers.{layer_id}",
+                    config=config,
+                    weights=weights,
+                )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
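
Taken together, the two hunks move prefix construction to the call site: OPTDecoder builds the full "decoder.layers.<id>" string (with the outer prefix prepended) once per layer and passes it as a keyword argument, so OPTDecoderLayer no longer needs any prefix branching. A minimal runnable sketch of the same pattern with toy stand-in classes (ToyLayer and ToyDecoder are hypothetical, not the real TGI classes):

class ToyLayer:
    def __init__(self, layer_id: int, prefix: str):
        # The layer just uses the prefix it is handed, e.g. "decoder.layers.3".
        self.layer_id = layer_id
        self.attn_prefix = f"{prefix}.self_attn"

class ToyDecoder:
    def __init__(self, num_layers: int, prefix: str = ""):
        # Build each layer's prefix exactly once, at the call site,
        # mirroring the new list comprehension in OPTDecoder.__init__.
        self.layers = [
            ToyLayer(layer_id, prefix=f"{prefix}decoder.layers.{layer_id}")
            for layer_id in range(num_layers)
        ]

decoder = ToyDecoder(num_layers=2)
print([layer.attn_prefix for layer in decoder.layers])
# ['decoder.layers.0.self_attn', 'decoder.layers.1.self_attn']

sharded = ToyDecoder(num_layers=1, prefix="model.")
print(sharded.layers[0].attn_prefix)
# model.decoder.layers.0.self_attn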