mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
Fixup opt to reduce the amount of odd if statements. (#2833)
* Fixup opt to reduce the amount of odd if statements. * Fixing cargo lock
This commit is contained in:
parent
bf59118a93
commit
3bb3fd19ae
14
Cargo.lock
generated
14
Cargo.lock
generated
@ -4367,7 +4367,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-backends-trtllm"
|
name = "text-generation-backends-trtllm"
|
||||||
version = "3.0.1-dev0"
|
version = "3.0.2-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
@ -4391,7 +4391,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-benchmark"
|
name = "text-generation-benchmark"
|
||||||
version = "3.0.1-dev0"
|
version = "3.0.2-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"average",
|
"average",
|
||||||
"clap 4.5.21",
|
"clap 4.5.21",
|
||||||
@ -4411,7 +4411,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-client"
|
name = "text-generation-client"
|
||||||
version = "3.0.1-dev0"
|
version = "3.0.2-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
@ -4429,7 +4429,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-launcher"
|
name = "text-generation-launcher"
|
||||||
version = "3.0.1-dev0"
|
version = "3.0.2-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"clap 4.5.21",
|
"clap 4.5.21",
|
||||||
"ctrlc",
|
"ctrlc",
|
||||||
@ -4450,7 +4450,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-router"
|
name = "text-generation-router"
|
||||||
version = "3.0.1-dev0"
|
version = "3.0.2-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-stream",
|
"async-stream",
|
||||||
@ -4501,7 +4501,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-router-v2"
|
name = "text-generation-router-v2"
|
||||||
version = "3.0.1-dev0"
|
version = "3.0.2-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
@ -4550,7 +4550,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-router-v3"
|
name = "text-generation-router-v3"
|
||||||
version = "3.0.1-dev0"
|
version = "3.0.2-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
@ -12,7 +12,8 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PyTorch OPT model."""
|
"""PyTorch OPT model."""
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@ -317,7 +318,6 @@ class OPTDecoderLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.process_group = weights.process_group
|
self.process_group = weights.process_group
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
prefix = f"{prefix if prefix else ''}decoder.layers.{layer_id}"
|
|
||||||
self.self_attn = OPTAttention(
|
self.self_attn = OPTAttention(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
@ -478,7 +478,12 @@ class OPTDecoder(OPTPreTrainedModel):
|
|||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
OPTDecoderLayer(layer_id, prefix, config, weights)
|
OPTDecoderLayer(
|
||||||
|
layer_id,
|
||||||
|
prefix=f"{prefix}decoder.layers.{layer_id}",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
for layer_id in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user