Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-21 16:40:20 +00:00)

Merge branch 'main' into lewtun-patch-1

This commit is contained in: commit c07acd4fea
17  .github/workflows/build.yaml (vendored)

@@ -8,6 +8,15 @@ on:
   tags:
     - 'v*'
   pull_request:
+    paths:
+      - ".github/workflows/build.yaml"
+      - "server/**"
+      - "proto/**"
+      - "router/**"
+      - "launcher/**"
+      - "Cargo.lock"
+      - "rust-toolchain.toml"
+      - "Dockerfile"
   branches:
     - 'main'
 
@@ -15,6 +24,10 @@ jobs:
   build-and-push-image:
     runs-on: ubuntu-latest
     steps:
+      - name: Initialize Docker Buildx
+        uses: docker/setup-buildx-action@v2.0.0
+        with:
+          install: true
       - name: Tailscale
         uses: tailscale/github-action@v1
         with:
@@ -65,5 +78,5 @@ jobs:
           platforms: 'linux/amd64'
           tags: ${{ steps.meta.outputs.tags }}
           labels: ${{ steps.meta.outputs.labels }}
-          cache-from: type=registry,ref=ghcr.io/huggingface/text-generation-inference:latest
-          cache-to: type=inline
+          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
+          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
40  .github/workflows/tests.yaml (vendored)

@@ -3,14 +3,23 @@ name: Server Tests
 on:
   pull_request:
     paths:
+      - ".github/workflows/tests.yaml"
      - "server/**"
      - "proto/**"
      - "router/**"
      - "launcher/**"
+      - "Cargo.lock"
+      - "rust-toolchain.toml"
 
 jobs:
   run_tests:
     runs-on: ubuntu-20.04
+
+    env:
+      SCCACHE_GHA_ENABLED: "on"
+      RUSTC_WRAPPER: /usr/local/bin/sccache
+      SCCACHE: 0.3.3
+
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
@@ -25,19 +34,38 @@ jobs:
           components: rustfmt, clippy
       - name: Install Protoc
         uses: arduino/setup-protoc@v1
-      - name: Loading cache.
-        uses: actions/cache@v2
-        id: model_cache
+      - name: Install sccache
+        run: |
+          curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
+          chmod +x /usr/local/bin/sccache
+      - name: configure sccache
+        uses: actions/github-script@v6
         with:
-          path: ~/.cache/huggingface/
-          key: models
+          script: |
+            core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
+            core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
+            core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}');
+            core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-');
+      - name: cargo registry cache
+        uses: actions/cache@v3
+        with:
+          key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }}
+          restore-keys: |
+            cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-
+            cargo-${{ runner.os }}-
+          path: |
+            ~/.cargo/registry
+            ~/.cargo/git
       - name: Install
         run: |
          make install
       - name: Run server tests
         run: |
          pip install pytest
-         pytest -sv server/tests
+         HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
       - name: Run Rust tests
         run: |
          cargo test
+      - name: sccache stats
+        run: |
+          /usr/local/bin/sccache --show-stats
317  Cargo.lock (generated)

Summary of dependency changes (checksums updated accordingly):

- async-stream, async-stream-impl: 0.3.3 → 0.3.4 (async-stream now also depends on pin-project-lite)
- axum: 0.6.4 → 0.6.9 (now depends on tower-http 0.4.0; axum-tracing-opentelemetry pins tower-http 0.3.5)
- clap: 4.1.4 → 4.1.8; clap_derive: 4.1.0 → 4.1.8; clap_lex: 0.3.1 → 0.3.2
- crossbeam-channel: 0.5.6 → 0.5.7; crossbeam-deque: 0.8.2 → 0.8.3; crossbeam-epoch: 0.9.13 → 0.9.14; crossbeam-utils: 0.8.14 → 0.8.15
- fastrand: 1.8.0 → 1.9.0
- h2: 0.3.15 → 0.3.16
- hashbrown 0.12.3: now depends on ahash
- http: 0.2.8 → 0.2.9
- is-terminal: 0.4.3 → 0.4.4
- memoffset: 0.7.1 → 0.8.0
- mio: 0.8.5 → 0.8.6 (windows-sys 0.42.0 → 0.45.0)
- once_cell: 1.17.0 → 1.17.1
- prost, prost-build, prost-derive, prost-types: 0.11.6 → 0.11.8 (prost-types drops its direct bytes dependency)
- signal-hook-registry: 1.4.0 → 1.4.1
- slab: 0.4.7 → 0.4.8
- socket2: 0.4.7 → 0.4.8
- syn: 1.0.107 → 1.0.109
- tempfile: 3.3.0 → 3.4.0 (dependencies change from libc/remove_dir_all/winapi to rustix/windows-sys 0.42.0)
- thread_local: 1.1.6 → 1.1.7
- time: 0.1.45 → 0.1.43 (drops wasi 0.10.0)
- tokio: 1.25.0 → 1.26.0 (windows-sys 0.42.0 → 0.45.0)
- tokio-stream: 0.1.11 → 0.1.12; tokio-util: 0.7.6 → 0.7.7
- tower-http 0.3.5: drops its direct tower dependency
- utoipa, utoipa-gen: 3.0.1 → 3.0.3
- wasi: 0.10.0+wasi-snapshot-preview1 → 0.10.2+wasi-snapshot-preview1

New packages:
- ahash 0.7.6, mach 0.3.2, portable-atomic 0.3.19, quanta 0.10.1, raw-cpuid 10.7.0, sketches-ddsketch 0.2.0, tower-http 0.4.0
- metrics 0.20.1, metrics-exporter-prometheus 0.11.0, metrics-macros 0.6.0, metrics-util 0.14.0

Removed packages:
- remove_dir_all 0.5.3

Workspace crates:
- grpc-metadata: 0.1.0 → 0.4.0
- text-generation-client, text-generation-launcher, text-generation-router: 0.2.1 → 0.4.0
- the router gains metrics, metrics-exporter-prometheus, reqwest and tower-http 0.3.5 as dependencies; the launcher and router move to clap 4.1.8
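The new metrics and metrics-exporter-prometheus dependencies line up with the "Prometheus metrics" bullet added to the README below. Purely as an illustration (not part of this commit), here is a minimal sketch of scraping a Prometheus text-format endpoint with requests; the port and the /metrics path are assumptions, not taken from this diff.

```python
# Minimal sketch: scrape and print Prometheus text-format metrics.
# Assumptions (not from this diff): the router listens on 127.0.0.1:8080
# and exposes text-format metrics at /metrics; adjust for your deployment.
import requests

resp = requests.get("http://127.0.0.1:8080/metrics", timeout=5)
resp.raise_for_status()

# Prometheus text format: "# HELP"/"# TYPE" comment lines, then one
# "metric_name{labels} value" sample per line.
for line in resp.text.splitlines():
    if line.startswith("#") or not line.strip():
        continue
    name, _, value = line.rpartition(" ")
    print(f"{name} = {value}")
```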
37  Dockerfile

@@ -1,4 +1,15 @@
-FROM rust:1.67 as router-builder
+FROM lukemathwalker/cargo-chef:latest-rust-1.67 AS chef
+WORKDIR /usr/src
+
+FROM chef as planner
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY router router
+COPY launcher launcher
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
 
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
@@ -6,26 +17,15 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
     rm -f $PROTOC_ZIP
 
-WORKDIR /usr/src
+COPY --from=planner /usr/src/recipe.json recipe.json
+RUN cargo chef cook --release --recipe-path recipe.json
+
+COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY router router
-
-WORKDIR /usr/src/router
-RUN cargo install --path .
-
-FROM rust:1.67 as launcher-builder
-
-WORKDIR /usr/src
-
-COPY rust-toolchain.toml rust-toolchain.toml
 COPY launcher launcher
-WORKDIR /usr/src/launcher
-RUN cargo install --path .
+RUN cargo build --release
 
 FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
 
@@ -33,6 +33,7 @@ ENV LANG=C.UTF-8 \
     LC_ALL=C.UTF-8 \
     DEBIAN_FRONTEND=noninteractive \
     HUGGINGFACE_HUB_CACHE=/data \
+    HF_HUB_ENABLE_HF_TRANSFER=1 \
     MODEL_ID=bigscience/bloom-560m \
     QUANTIZE=false \
     NUM_SHARD=1 \
@@ -68,9 +69,9 @@ RUN cd server && \
     /opt/miniconda/envs/text-generation/bin/pip install ".[bnb]" --no-cache-dir
 
 # Install router
-COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
 
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
14  Makefile

@@ -13,25 +13,25 @@ server-dev:
 	cd server && make run-dev
 
 router-dev:
-	cd router && cargo run
+	cd router && cargo run -- --port 8080
 
 integration-tests: install-router install-launcher
 	cargo test
 
 python-tests:
-	cd server && pytest tests
+	cd server && HF_HUB_ENABLE_HF_TRANSFER=1 pytest tests
 
 run-bloom-560m:
-	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --port 8080
 
 run-bloom-560m-quantize:
-	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize --port 8080
 
 download-bloom:
-	text-generation-server download-weights bigscience/bloom
+	HF_HUB_ENABLE_HF_TRANSFER=1 text-generation-server download-weights bigscience/bloom
 
 run-bloom:
-	text-generation-launcher --model-id bigscience/bloom --num-shard 8
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --port 8080
 
 run-bloom-quantize:
-	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
39  README.md

@@ -39,27 +39,30 @@ to power LLMs api-inference widgets.
 
 ## Features
 
+- Serve the most popular Large Language Models with a simple launcher
+- Tensor Parallelism for faster inference on multiple GPUs
 - Token streaming using Server-Sent Events (SSE)
 - [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
-- 45ms per token generation for BLOOM with 8xA100 80GB
+- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 - Logits warpers (temperature scaling, topk, repetition penalty ...)
 - Stop sequences
 - Log probabilities
-- Distributed tracing with Open Telemetry
+- Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
 
-## Officially supported models
+## Officially supported architectures
 
 - [BLOOM](https://huggingface.co/bigscience/bloom)
 - [BLOOMZ](https://huggingface.co/bigscience/bloomz)
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
-- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
+- [Galactica](https://huggingface.co/facebook/galactica-120b)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
 - [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
 - [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
+- [FLAN-UL2](https://huggingface.co/google/flan-ul2)
 
-Other models are supported on a best effort basis using:
+Other architectures are supported on a best effort basis using:
 
 `AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
 
@@ -80,24 +83,42 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 
 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
 ```
+**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
 
 You can then query the model using either the `/generate` or `/generate_stream` routes:
 
 ```shell
 curl 127.0.0.1:8080/generate \
     -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
     -H 'Content-Type: application/json'
 ```
 
 ```shell
 curl 127.0.0.1:8080/generate_stream \
     -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
     -H 'Content-Type: application/json'
 ```
 
-**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
+or from Python:
+
+```shell
+pip install text-generation
+```
+
+```python
+from text_generation import Client
+
+client = Client("http://127.0.0.1:8080")
+print(client.generate("What is Deep Learning?", max_new_tokens=17).generated_text)
+
+text = ""
+for response in client.generate_stream("What is Deep Learning?", max_new_tokens=17):
+    if not response.token.special:
+        text += response.token.text
+print(text)
+```
 
 ### API documentation
 
@@ -191,7 +212,7 @@ Be aware that the official Docker image has them enabled by default.
 
 ### Download
 
-First you need to download the weights:
+It is advised to download the weights ahead of time with the following command:
 
 ```shell
 make download-bloom
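The updated quick-start pairs the curl calls with the new `text-generation` client. Purely as an illustration (not part of this commit), here is a rough sketch of consuming `/generate_stream` as Server-Sent Events with plain requests; the `data: {json}` framing and the token field names (`token.text`, `token.special`) are assumptions carried over from the Python client README added further below.

```python
# Rough sketch: read /generate_stream as Server-Sent Events with plain requests.
# Assumptions: SSE "data: {json}" framing and the StreamResponse field names
# documented in the new clients/python/README.md.
import json
import requests

payload = {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 17}}

with requests.post(
    "http://127.0.0.1:8080/generate_stream",
    json=payload,
    stream=True,
    timeout=60,
) as resp:
    resp.raise_for_status()
    text = ""
    for raw in resp.iter_lines():
        if not raw:
            continue
        line = raw.decode("utf-8")
        if not line.startswith("data:"):
            continue
        event = json.loads(line[len("data:"):].strip())
        token = event["token"]
        if not token["special"]:
            text += token["text"]
    print(text)
```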
158  clients/python/.gitignore (vendored, new file)

@@ -0,0 +1,158 @@
# Byte-compiled / optimized / DLL files
__pycache__/
text_generation/__pycache__/
text_generation/pb/__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

transformers
safetensors
6  clients/python/Makefile (new file)

@@ -0,0 +1,6 @@
unit-tests:
	python -m pytest --cov=text_generation tests

install:
	pip install pip --upgrade
	pip install -e .
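For context only, not part of this commit: a hypothetical minimal test that the `unit-tests` target above could pick up, checking the `FinishReason` values documented in the client README below. The `text_generation.types` module path is an assumption.

```python
# tests/test_finish_reason_sketch.py -- hypothetical example, not from this commit.
# Assumption: the enum lives at text_generation.types.FinishReason; the values
# ("length", "eos_token", "stop_sequence") are documented in the client README.
from text_generation.types import FinishReason


def test_finish_reason_values():
    assert FinishReason.Length.value == "length"
    assert FinishReason.EndOfSequenceToken.value == "eos_token"
    assert FinishReason.StopSequence.value == "stop_sequence"
```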
196  clients/python/README.md  Normal file
@@ -0,0 +1,196 @@
# Text Generation

The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
`text-generation-inference` instance running on
[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub.

## Get Started

### Install

```shell
pip install text-generation
```

### Inference API Usage

```python
from text_generation import InferenceAPIClient

client = InferenceAPIClient("bigscience/bloomz")
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
```

or with the asynchronous client:

```python
from text_generation import InferenceAPIAsyncClient

client = InferenceAPIAsyncClient("bigscience/bloomz")
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
```

### Hugging Face Inference Endpoint usage

```python
from text_generation import Client

endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"

client = Client(endpoint_url)
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
```

or with the asynchronous client:

```python
from text_generation import AsyncClient

endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"

client = AsyncClient(endpoint_url)
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
```

### Types

```python
# Prompt tokens
class PrefillToken:
    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    # Optional since the logprob of the first token cannot be computed
    logprob: Optional[float]


# Generated tokens
class Token:
    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    logprob: float
    # Is the token a special token
    # Can be used to ignore tokens when concatenating
    special: bool


# Generation finish reason
class FinishReason(Enum):
    # number of generated tokens == `max_new_tokens`
    Length = "length"
    # the model generated its end of sequence token
    EndOfSequenceToken = "eos_token"
    # the model generated a text included in `stop_sequences`
    StopSequence = "stop_sequence"


# Additional sequences when using the `best_of` parameter
class BestOfSequence:
    # Generated text
    generated_text: str
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int]
    # Prompt tokens
    prefill: List[PrefillToken]
    # Generated tokens
    tokens: List[Token]


# `generate` details
class Details:
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int]
    # Prompt tokens
    prefill: List[PrefillToken]
    # Generated tokens
    tokens: List[Token]
    # Additional sequences when using the `best_of` parameter
    best_of_sequences: Optional[List[BestOfSequence]]


# `generate` return value
class Response:
    # Generated text
    generated_text: str
    # Generation details
    details: Details


# `generate_stream` details
class StreamDetails:
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int]


# `generate_stream` return value
class StreamResponse:
    # Generated token
    token: Token
    # Complete generated text
    # Only available when the generation is finished
    generated_text: Optional[str]
    # Generation details
    # Only available when the generation is finished
    details: Optional[StreamDetails]
```
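The README examples above leave out authentication and the `details` payload. Below is a minimal sketch that combines both; it assumes a valid Hugging Face API token in the `HF_TOKEN` environment variable and that `bigscience/bloomz` is still served through the Inference API.

```python
import os

from text_generation import InferenceAPIClient

# The token is forwarded by the client as an HTTP bearer header.
client = InferenceAPIClient("bigscience/bloomz", token=os.environ.get("HF_TOKEN"))

response = client.generate("Why is the sky blue?", max_new_tokens=10)

# `generate` always requests details, so finish reason, token count and
# per-token logprobs are available on the response.
print(response.details.finish_reason, response.details.generated_tokens)
for token in response.details.tokens:
    print(token.id, repr(token.text), token.logprob)
```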
1038  clients/python/poetry.lock  generated  Normal file
File diff suppressed because it is too large
26  clients/python/pyproject.toml  Normal file
@@ -0,0 +1,26 @@
[tool.poetry]
name = "text-generation"
version = "0.3.1"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
maintainers = ["Olivier Dehaene <olivier@huggingface.co>"]
readme = "README.md"
homepage = "https://github.com/huggingface/text-generation-inference"
repository = "https://github.com/huggingface/text-generation-inference"


[tool.poetry.dependencies]
python = "^3.7"
pydantic = "^1.10"
aiohttp = "^3.8"
huggingface-hub = ">= 0.12, < 1.0"

[tool.poetry.dev-dependencies]
pytest = "^6.2.5"
pytest-asyncio = "^0.17.2"
pytest-cov = "^3.0.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
51  clients/python/tests/conftest.py  Normal file
@@ -0,0 +1,51 @@
import pytest

from text_generation import __version__
from huggingface_hub.utils import build_hf_headers


@pytest.fixture
def flan_t5_xxl():
    return "google/flan-t5-xxl"


@pytest.fixture
def fake_model():
    return "fake/model"


@pytest.fixture
def unsupported_model():
    return "gpt2"


@pytest.fixture
def base_url():
    return "https://api-inference.huggingface.co/models"


@pytest.fixture
def bloom_url(base_url, bloom_model):
    return f"{base_url}/{bloom_model}"


@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
    return f"{base_url}/{flan_t5_xxl}"


@pytest.fixture
def fake_url(base_url, fake_model):
    return f"{base_url}/{fake_model}"


@pytest.fixture
def unsupported_url(base_url, unsupported_model):
    return f"{base_url}/{unsupported_model}"


@pytest.fixture(scope="session")
def hf_headers():
    return build_hf_headers(
        library_name="text-generation-tests", library_version=__version__
    )
133  clients/python/tests/test_client.py  Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from text_generation import Client, AsyncClient
|
||||||
|
from text_generation.errors import NotFoundError, ValidationError
|
||||||
|
from text_generation.types import FinishReason, PrefillToken, Token
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
|
response = client.generate("test", max_new_tokens=1)
|
||||||
|
|
||||||
|
assert response.generated_text == ""
|
||||||
|
assert response.details.finish_reason == FinishReason.Length
|
||||||
|
assert response.details.generated_tokens == 1
|
||||||
|
assert response.details.seed is None
|
||||||
|
assert len(response.details.prefill) == 1
|
||||||
|
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
|
||||||
|
assert len(response.details.tokens) == 1
|
||||||
|
assert response.details.tokens[0] == Token(
|
||||||
|
id=3, text=" ", logprob=-1.984375, special=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
|
response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
|
||||||
|
|
||||||
|
assert response.details.seed is not None
|
||||||
|
assert response.details.best_of_sequences is not None
|
||||||
|
assert len(response.details.best_of_sequences) == 1
|
||||||
|
assert response.details.best_of_sequences[0].seed is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_not_found(fake_url, hf_headers):
|
||||||
|
client = Client(fake_url, hf_headers)
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
client.generate("test")
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
client.generate("test", max_new_tokens=10_000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_stream(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
|
responses = [
|
||||||
|
response for response in client.generate_stream("test", max_new_tokens=1)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(responses) == 1
|
||||||
|
response = responses[0]
|
||||||
|
|
||||||
|
assert response.generated_text == ""
|
||||||
|
assert response.details.finish_reason == FinishReason.Length
|
||||||
|
assert response.details.generated_tokens == 1
|
||||||
|
assert response.details.seed is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_stream_not_found(fake_url, hf_headers):
|
||||||
|
client = Client(fake_url, hf_headers)
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
list(client.generate_stream("test"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = Client(flan_t5_xxl_url, hf_headers)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
list(client.generate_stream("test", max_new_tokens=10_000))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_async(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||||
|
response = await client.generate("test", max_new_tokens=1)
|
||||||
|
|
||||||
|
assert response.generated_text == ""
|
||||||
|
assert response.details.finish_reason == FinishReason.Length
|
||||||
|
assert response.details.generated_tokens == 1
|
||||||
|
assert response.details.seed is None
|
||||||
|
assert len(response.details.prefill) == 1
|
||||||
|
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
|
||||||
|
assert len(response.details.tokens) == 1
|
||||||
|
assert response.details.tokens[0] == Token(
|
||||||
|
id=3, text=" ", logprob=-1.984375, special=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_async_not_found(fake_url, hf_headers):
|
||||||
|
client = AsyncClient(fake_url, hf_headers)
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await client.generate("test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
await client.generate("test", max_new_tokens=10_000)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||||
|
responses = [
|
||||||
|
response async for response in client.generate_stream("test", max_new_tokens=1)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(responses) == 1
|
||||||
|
response = responses[0]
|
||||||
|
|
||||||
|
assert response.generated_text == ""
|
||||||
|
assert response.details.finish_reason == FinishReason.Length
|
||||||
|
assert response.details.generated_tokens == 1
|
||||||
|
assert response.details.seed is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_stream_async_not_found(fake_url, hf_headers):
|
||||||
|
client = AsyncClient(fake_url, hf_headers)
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
async for _ in client.generate_stream("test"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
|
||||||
|
client = AsyncClient(flan_t5_xxl_url, hf_headers)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
async for _ in client.generate_stream("test", max_new_tokens=10_000):
|
||||||
|
pass
|
64  clients/python/tests/test_errors.py  Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
from text_generation.errors import (
|
||||||
|
parse_error,
|
||||||
|
GenerationError,
|
||||||
|
IncompleteGenerationError,
|
||||||
|
OverloadedError,
|
||||||
|
ValidationError,
|
||||||
|
BadRequestError,
|
||||||
|
ShardNotReadyError,
|
||||||
|
ShardTimeoutError,
|
||||||
|
NotFoundError,
|
||||||
|
RateLimitExceededError,
|
||||||
|
UnknownError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generation_error():
|
||||||
|
payload = {"error_type": "generation", "error": "test"}
|
||||||
|
assert isinstance(parse_error(400, payload), GenerationError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_incomplete_generation_error():
|
||||||
|
payload = {"error_type": "incomplete_generation", "error": "test"}
|
||||||
|
assert isinstance(parse_error(400, payload), IncompleteGenerationError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_overloaded_error():
|
||||||
|
payload = {"error_type": "overloaded", "error": "test"}
|
||||||
|
assert isinstance(parse_error(400, payload), OverloadedError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation_error():
|
||||||
|
payload = {"error_type": "validation", "error": "test"}
|
||||||
|
assert isinstance(parse_error(400, payload), ValidationError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bad_request_error():
|
||||||
|
payload = {"error": "test"}
|
||||||
|
assert isinstance(parse_error(400, payload), BadRequestError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shard_not_ready_error():
|
||||||
|
payload = {"error": "test"}
|
||||||
|
assert isinstance(parse_error(403, payload), ShardNotReadyError)
|
||||||
|
assert isinstance(parse_error(424, payload), ShardNotReadyError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shard_timeout_error():
|
||||||
|
payload = {"error": "test"}
|
||||||
|
assert isinstance(parse_error(504, payload), ShardTimeoutError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_not_found_error():
|
||||||
|
payload = {"error": "test"}
|
||||||
|
assert isinstance(parse_error(404, payload), NotFoundError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_exceeded_error():
|
||||||
|
payload = {"error": "test"}
|
||||||
|
assert isinstance(parse_error(429, payload), RateLimitExceededError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_error():
|
||||||
|
payload = {"error": "test"}
|
||||||
|
assert isinstance(parse_error(500, payload), UnknownError)
|
34  clients/python/tests/test_inference_api.py  Normal file
@@ -0,0 +1,34 @@
import pytest

from text_generation import (
    InferenceAPIClient,
    InferenceAPIAsyncClient,
    Client,
    AsyncClient,
)
from text_generation.errors import NotSupportedError
from text_generation.inference_api import get_supported_models


def test_get_supported_models():
    assert isinstance(get_supported_models(), list)


def test_client(flan_t5_xxl):
    client = InferenceAPIClient(flan_t5_xxl)
    assert isinstance(client, Client)


def test_client_unsupported_model(unsupported_model):
    with pytest.raises(NotSupportedError):
        InferenceAPIClient(unsupported_model)


def test_async_client(flan_t5_xxl):
    client = InferenceAPIAsyncClient(flan_t5_xxl)
    assert isinstance(client, AsyncClient)


def test_async_client_unsupported_model(unsupported_model):
    with pytest.raises(NotSupportedError):
        InferenceAPIAsyncClient(unsupported_model)
82  clients/python/tests/test_types.py  Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from text_generation.types import Parameters, Request
|
||||||
|
from text_generation.errors import ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
def test_parameters_validation():
|
||||||
|
# Test best_of
|
||||||
|
Parameters(best_of=1)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(best_of=0)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(best_of=-1)
|
||||||
|
Parameters(best_of=2, do_sample=True)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(best_of=2)
|
||||||
|
|
||||||
|
# Test repetition_penalty
|
||||||
|
Parameters(repetition_penalty=1)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(repetition_penalty=0)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(repetition_penalty=-1)
|
||||||
|
|
||||||
|
# Test seed
|
||||||
|
Parameters(seed=1)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(seed=-1)
|
||||||
|
|
||||||
|
# Test temperature
|
||||||
|
Parameters(temperature=1)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(temperature=0)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(temperature=-1)
|
||||||
|
|
||||||
|
# Test top_k
|
||||||
|
Parameters(top_k=1)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(top_k=0)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(top_k=-1)
|
||||||
|
|
||||||
|
# Test top_p
|
||||||
|
Parameters(top_p=0.5)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(top_p=0)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(top_p=-1)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(top_p=1)
|
||||||
|
|
||||||
|
# Test truncate
|
||||||
|
Parameters(truncate=1)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(truncate=0)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(truncate=-1)
|
||||||
|
|
||||||
|
# Test typical_p
|
||||||
|
Parameters(typical_p=0.5)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(typical_p=0)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(typical_p=-1)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Parameters(typical_p=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_validation():
|
||||||
|
Request(inputs="test")
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Request(inputs="")
|
||||||
|
|
||||||
|
Request(inputs="test", stream=True)
|
||||||
|
Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Request(
|
||||||
|
inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
|
||||||
|
)
|
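Because `Client.generate` builds a `Parameters` object before issuing any request, the same validation rules exercised above surface directly to callers. A short sketch (the base URL is a placeholder; nothing is actually sent because validation fails first):

```python
from text_generation import Client
from text_generation.errors import ValidationError

client = Client("http://localhost:8080")  # placeholder URL, never contacted below

try:
    # Parameters are validated client-side, so this raises immediately.
    client.generate("test", temperature=0)
except ValidationError as exc:
    print(exc)  # `temperature` must be strictly positive
```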
18  clients/python/text_generation/__init__.py  Normal file
@@ -0,0 +1,18 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.3.0"

from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
487  clients/python/text_generation/client.py  Normal file
@@ -0,0 +1,487 @@
|
|||||||
|
import json
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from aiohttp import ClientSession, ClientTimeout
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from typing import Dict, Optional, List, AsyncIterator, Iterator
|
||||||
|
|
||||||
|
from text_generation.types import (
|
||||||
|
StreamResponse,
|
||||||
|
Response,
|
||||||
|
Request,
|
||||||
|
Parameters,
|
||||||
|
)
|
||||||
|
from text_generation.errors import parse_error
|
||||||
|
|
||||||
|
|
||||||
|
class Client:
|
||||||
|
"""Client to make calls to a text-generation-inference instance
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from text_generation import Client
|
||||||
|
|
||||||
|
>>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
|
||||||
|
>>> client.generate("Why is the sky blue?").generated_text
|
||||||
|
' Rayleigh scattering'
|
||||||
|
|
||||||
|
>>> result = ""
|
||||||
|
>>> for response in client.generate_stream("Why is the sky blue?"):
|
||||||
|
>>> if not response.token.special:
|
||||||
|
>>> result += response.token.text
|
||||||
|
>>> result
|
||||||
|
' Rayleigh scattering'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
cookies: Optional[Dict[str, str]] = None,
|
||||||
|
timeout: int = 10,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
base_url (`str`):
|
||||||
|
text-generation-inference instance base url
|
||||||
|
headers (`Optional[Dict[str, str]]`):
|
||||||
|
Additional headers
|
||||||
|
cookies (`Optional[Dict[str, str]]`):
|
||||||
|
Cookies to include in the requests
|
||||||
|
timeout (`int`):
|
||||||
|
Timeout in seconds
|
||||||
|
"""
|
||||||
|
self.base_url = base_url
|
||||||
|
self.headers = headers
|
||||||
|
self.cookies = cookies
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
do_sample: bool = False,
|
||||||
|
max_new_tokens: int = 20,
|
||||||
|
best_of: Optional[int] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
truncate: Optional[int] = None,
|
||||||
|
typical_p: Optional[float] = None,
|
||||||
|
watermark: bool = False,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
Given a prompt, generate the following text
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str`):
|
||||||
|
Input text
|
||||||
|
do_sample (`bool`):
|
||||||
|
Activate logits sampling
|
||||||
|
max_new_tokens (`int`):
|
||||||
|
Maximum number of generated tokens
|
||||||
|
best_of (`int`):
|
||||||
|
Generate best_of sequences and return the one with the highest token logprobs
|
||||||
|
repetition_penalty (`float`):
|
||||||
|
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||||
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
return_full_text (`bool`):
|
||||||
|
Whether to prepend the prompt to the generated text
|
||||||
|
seed (`int`):
|
||||||
|
Random sampling seed
|
||||||
|
stop_sequences (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop_sequences` is generated
|
||||||
|
temperature (`float`):
|
||||||
|
The value used to modulate the logits distribution.
|
||||||
|
top_k (`int`):
|
||||||
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
top_p (`float`):
|
||||||
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
higher are kept for generation.
|
||||||
|
truncate (`int`):
|
||||||
|
Truncate input tokens to the given size
|
||||||
|
typical_p (`float`):
|
||||||
|
Typical Decoding mass
|
||||||
|
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
||||||
|
watermark (`bool`):
|
||||||
|
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response: generated response
|
||||||
|
"""
|
||||||
|
# Validate parameters
|
||||||
|
parameters = Parameters(
|
||||||
|
best_of=best_of,
|
||||||
|
details=True,
|
||||||
|
do_sample=do_sample,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
return_full_text=return_full_text,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop_sequences if stop_sequences is not None else [],
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
truncate=truncate,
|
||||||
|
typical_p=typical_p,
|
||||||
|
watermark=watermark,
|
||||||
|
)
|
||||||
|
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||||
|
|
||||||
|
resp = requests.post(
|
||||||
|
self.base_url,
|
||||||
|
json=request.dict(),
|
||||||
|
headers=self.headers,
|
||||||
|
cookies=self.cookies,
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
payload = resp.json()
|
||||||
|
if resp.status_code != 200:
|
||||||
|
raise parse_error(resp.status_code, payload)
|
||||||
|
return Response(**payload[0])
|
||||||
|
|
||||||
|
def generate_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
do_sample: bool = False,
|
||||||
|
max_new_tokens: int = 20,
|
||||||
|
best_of: Optional[int] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
truncate: Optional[int] = None,
|
||||||
|
typical_p: Optional[float] = None,
|
||||||
|
watermark: bool = False,
|
||||||
|
) -> Iterator[StreamResponse]:
|
||||||
|
"""
|
||||||
|
Given a prompt, generate the following stream of tokens
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str`):
|
||||||
|
Input text
|
||||||
|
do_sample (`bool`):
|
||||||
|
Activate logits sampling
|
||||||
|
max_new_tokens (`int`):
|
||||||
|
Maximum number of generated tokens
|
||||||
|
best_of (`int`):
|
||||||
|
Generate best_of sequences and return the one with the highest token logprobs
|
||||||
|
repetition_penalty (`float`):
|
||||||
|
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||||
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
return_full_text (`bool`):
|
||||||
|
Whether to prepend the prompt to the generated text
|
||||||
|
seed (`int`):
|
||||||
|
Random sampling seed
|
||||||
|
stop_sequences (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop_sequences` is generated
|
||||||
|
temperature (`float`):
|
||||||
|
The value used to modulate the logits distribution.
|
||||||
|
top_k (`int`):
|
||||||
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
top_p (`float`):
|
||||||
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
higher are kept for generation.
|
||||||
|
truncate (`int`):
|
||||||
|
Truncate input tokens to the given size
|
||||||
|
typical_p (`float`):
|
||||||
|
Typical Decoding mass
|
||||||
|
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
||||||
|
watermark (`bool`):
|
||||||
|
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterator[StreamResponse]: stream of generated tokens
|
||||||
|
"""
|
||||||
|
# Validate parameters
|
||||||
|
parameters = Parameters(
|
||||||
|
best_of=best_of,
|
||||||
|
details=True,
|
||||||
|
do_sample=do_sample,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
return_full_text=return_full_text,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop_sequences if stop_sequences is not None else [],
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
truncate=truncate,
|
||||||
|
typical_p=typical_p,
|
||||||
|
watermark=watermark,
|
||||||
|
)
|
||||||
|
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||||
|
|
||||||
|
resp = requests.post(
|
||||||
|
self.base_url,
|
||||||
|
json=request.dict(),
|
||||||
|
headers=self.headers,
|
||||||
|
cookies=self.cookies,
|
||||||
|
timeout=self.timeout,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code != 200:
|
||||||
|
raise parse_error(resp.status_code, resp.json())
|
||||||
|
|
||||||
|
# Parse ServerSentEvents
|
||||||
|
for byte_payload in resp.iter_lines():
|
||||||
|
# Skip line
|
||||||
|
if byte_payload == b"\n":
|
||||||
|
continue
|
||||||
|
|
||||||
|
payload = byte_payload.decode("utf-8")
|
||||||
|
|
||||||
|
# Event data
|
||||||
|
if payload.startswith("data:"):
|
||||||
|
# Decode payload
|
||||||
|
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||||
|
# Parse payload
|
||||||
|
try:
|
||||||
|
response = StreamResponse(**json_payload)
|
||||||
|
except ValidationError:
|
||||||
|
# If we failed to parse the payload, then it is an error payload
|
||||||
|
raise parse_error(resp.status_code, json_payload)
|
||||||
|
yield response
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncClient:
|
||||||
|
"""Asynchronous Client to make calls to a text-generation-inference instance
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from text_generation import AsyncClient
|
||||||
|
|
||||||
|
>>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
|
||||||
|
>>> response = await client.generate("Why is the sky blue?")
|
||||||
|
>>> response.generated_text
|
||||||
|
' Rayleigh scattering'
|
||||||
|
|
||||||
|
>>> result = ""
|
||||||
|
>>> async for response in client.generate_stream("Why is the sky blue?"):
|
||||||
|
>>> if not response.token.special:
|
||||||
|
>>> result += response.token.text
|
||||||
|
>>> result
|
||||||
|
' Rayleigh scattering'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
cookies: Optional[Dict[str, str]] = None,
|
||||||
|
timeout: int = 10,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
base_url (`str`):
|
||||||
|
text-generation-inference instance base url
|
||||||
|
headers (`Optional[Dict[str, str]]`):
|
||||||
|
Additional headers
|
||||||
|
cookies (`Optional[Dict[str, str]]`):
|
||||||
|
Cookies to include in the requests
|
||||||
|
timeout (`int`):
|
||||||
|
Timeout in seconds
|
||||||
|
"""
|
||||||
|
self.base_url = base_url
|
||||||
|
self.headers = headers
|
||||||
|
self.cookies = cookies
|
||||||
|
self.timeout = ClientTimeout(timeout * 60)
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
do_sample: bool = False,
|
||||||
|
max_new_tokens: int = 20,
|
||||||
|
best_of: Optional[int] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
truncate: Optional[int] = None,
|
||||||
|
typical_p: Optional[float] = None,
|
||||||
|
watermark: bool = False,
|
||||||
|
) -> Response:
|
||||||
|
"""
|
||||||
|
Given a prompt, generate the following text asynchronously
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str`):
|
||||||
|
Input text
|
||||||
|
do_sample (`bool`):
|
||||||
|
Activate logits sampling
|
||||||
|
max_new_tokens (`int`):
|
||||||
|
Maximum number of generated tokens
|
||||||
|
best_of (`int`):
|
||||||
|
Generate best_of sequences and return the one with the highest token logprobs
|
||||||
|
repetition_penalty (`float`):
|
||||||
|
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||||
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
return_full_text (`bool`):
|
||||||
|
Whether to prepend the prompt to the generated text
|
||||||
|
seed (`int`):
|
||||||
|
Random sampling seed
|
||||||
|
stop_sequences (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop_sequences` is generated
|
||||||
|
temperature (`float`):
|
||||||
|
The value used to modulate the logits distribution.
|
||||||
|
top_k (`int`):
|
||||||
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
top_p (`float`):
|
||||||
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
higher are kept for generation.
|
||||||
|
truncate (`int`):
|
||||||
|
Truncate input tokens to the given size
|
||||||
|
typical_p (`float`):
|
||||||
|
Typical Decoding mass
|
||||||
|
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
||||||
|
watermark (`bool`):
|
||||||
|
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response: generated response
|
||||||
|
"""
|
||||||
|
# Validate parameters
|
||||||
|
parameters = Parameters(
|
||||||
|
best_of=best_of,
|
||||||
|
details=True,
|
||||||
|
do_sample=do_sample,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
return_full_text=return_full_text,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop_sequences if stop_sequences is not None else [],
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
truncate=truncate,
|
||||||
|
typical_p=typical_p,
|
||||||
|
watermark=watermark,
|
||||||
|
)
|
||||||
|
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||||
|
|
||||||
|
async with ClientSession(
|
||||||
|
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||||
|
) as session:
|
||||||
|
async with session.post(self.base_url, json=request.dict()) as resp:
|
||||||
|
payload = await resp.json()
|
||||||
|
|
||||||
|
if resp.status != 200:
|
||||||
|
raise parse_error(resp.status, payload)
|
||||||
|
return Response(**payload[0])
|
||||||
|
|
||||||
|
async def generate_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
do_sample: bool = False,
|
||||||
|
max_new_tokens: int = 20,
|
||||||
|
best_of: Optional[int] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
return_full_text: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
truncate: Optional[int] = None,
|
||||||
|
typical_p: Optional[float] = None,
|
||||||
|
watermark: bool = False,
|
||||||
|
) -> AsyncIterator[StreamResponse]:
|
||||||
|
"""
|
||||||
|
Given a prompt, generate the following stream of tokens asynchronously
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str`):
|
||||||
|
Input text
|
||||||
|
do_sample (`bool`):
|
||||||
|
Activate logits sampling
|
||||||
|
max_new_tokens (`int`):
|
||||||
|
Maximum number of generated tokens
|
||||||
|
best_of (`int`):
|
||||||
|
Generate best_of sequences and return the one with the highest token logprobs
|
||||||
|
repetition_penalty (`float`):
|
||||||
|
The parameter for repetition penalty. 1.0 means no penalty. See [this
|
||||||
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
return_full_text (`bool`):
|
||||||
|
Whether to prepend the prompt to the generated text
|
||||||
|
seed (`int`):
|
||||||
|
Random sampling seed
|
||||||
|
stop_sequences (`List[str]`):
|
||||||
|
Stop generating tokens if a member of `stop_sequences` is generated
|
||||||
|
temperature (`float`):
|
||||||
|
The value used to modulate the logits distribution.
|
||||||
|
top_k (`int`):
|
||||||
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
top_p (`float`):
|
||||||
|
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
higher are kept for generation.
|
||||||
|
truncate (`int`):
|
||||||
|
Truncate input tokens to the given size
|
||||||
|
typical_p (`float`):
|
||||||
|
Typical Decoding mass
|
||||||
|
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
||||||
|
watermark (`bool`):
|
||||||
|
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncIterator[StreamResponse]: stream of generated tokens
|
||||||
|
"""
|
||||||
|
# Validate parameters
|
||||||
|
parameters = Parameters(
|
||||||
|
best_of=best_of,
|
||||||
|
details=True,
|
||||||
|
do_sample=do_sample,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
return_full_text=return_full_text,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop_sequences if stop_sequences is not None else [],
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
truncate=truncate,
|
||||||
|
typical_p=typical_p,
|
||||||
|
watermark=watermark,
|
||||||
|
)
|
||||||
|
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||||
|
|
||||||
|
async with ClientSession(
|
||||||
|
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||||
|
) as session:
|
||||||
|
async with session.post(self.base_url, json=request.dict()) as resp:
|
||||||
|
|
||||||
|
if resp.status != 200:
|
||||||
|
raise parse_error(resp.status, await resp.json())
|
||||||
|
|
||||||
|
# Parse ServerSentEvents
|
||||||
|
async for byte_payload in resp.content:
|
||||||
|
# Skip line
|
||||||
|
if byte_payload == b"\n":
|
||||||
|
continue
|
||||||
|
|
||||||
|
payload = byte_payload.decode("utf-8")
|
||||||
|
|
||||||
|
# Event data
|
||||||
|
if payload.startswith("data:"):
|
||||||
|
# Decode payload
|
||||||
|
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||||
|
# Parse payload
|
||||||
|
try:
|
||||||
|
response = StreamResponse(**json_payload)
|
||||||
|
except ValidationError:
|
||||||
|
# If we failed to parse the payload, then it is an error payload
|
||||||
|
raise parse_error(resp.status, json_payload)
|
||||||
|
yield response
|
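A brief usage sketch of the streaming path defined above; the endpoint URL is a placeholder for any running `text-generation-inference` instance, and the stop sequence is purely illustrative.

```python
from text_generation import Client

client = Client("http://localhost:8080", timeout=60)  # placeholder URL

last = None
for last in client.generate_stream(
    "Why is the sky blue?", max_new_tokens=64, stop_sequences=["\n\n"]
):
    if not last.token.special:
        print(last.token.text, end="", flush=True)

# `generated_text` and `details` are only populated on the final event.
if last is not None and last.details is not None:
    print(f"\nfinish reason: {last.details.finish_reason}")
```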
106  clients/python/text_generation/errors.py  Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
# Text Generation Inference Errors
|
||||||
|
class ValidationError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class OverloadedError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class IncompleteGenerationError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
# API Inference Errors
|
||||||
|
class BadRequestError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ShardNotReadyError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ShardTimeoutError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class NotFoundError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitExceededError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class NotSupportedError(Exception):
|
||||||
|
def __init__(self, model_id: str):
|
||||||
|
message = (
|
||||||
|
f"Model `{model_id}` is not available for inference with this client. \n"
|
||||||
|
"Use `huggingface_hub.inference_api.InferenceApi` instead."
|
||||||
|
)
|
||||||
|
super(NotSupportedError, self).__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
# Unknown error
|
||||||
|
class UnknownError(Exception):
|
||||||
|
def __init__(self, message: str):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
|
||||||
|
"""
|
||||||
|
Parse error given an HTTP status code and a json payload
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status_code (`int`):
|
||||||
|
HTTP status code
|
||||||
|
payload (`Dict[str, str]`):
|
||||||
|
Json payload
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exception: parsed exception
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Try to parse a Text Generation Inference error
|
||||||
|
message = payload["error"]
|
||||||
|
if "error_type" in payload:
|
||||||
|
error_type = payload["error_type"]
|
||||||
|
if error_type == "generation":
|
||||||
|
return GenerationError(message)
|
||||||
|
if error_type == "incomplete_generation":
|
||||||
|
return IncompleteGenerationError(message)
|
||||||
|
if error_type == "overloaded":
|
||||||
|
return OverloadedError(message)
|
||||||
|
if error_type == "validation":
|
||||||
|
return ValidationError(message)
|
||||||
|
|
||||||
|
# Try to parse an API Inference error
|
||||||
|
if status_code == 400:
|
||||||
|
return BadRequestError(message)
|
||||||
|
if status_code == 403 or status_code == 424:
|
||||||
|
return ShardNotReadyError(message)
|
||||||
|
if status_code == 504:
|
||||||
|
return ShardTimeoutError(message)
|
||||||
|
if status_code == 404:
|
||||||
|
return NotFoundError(message)
|
||||||
|
if status_code == 429:
|
||||||
|
return RateLimitExceededError(message)
|
||||||
|
|
||||||
|
# Fallback to an unknown error
|
||||||
|
return UnknownError(message)
|
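`parse_error` is what `Client` and `AsyncClient` call internally, but it can also be used around a raw HTTP request. A minimal sketch (the URL and retry handling are illustrative only, and the error body is assumed to carry the usual `"error"` field):

```python
import requests

from text_generation.errors import OverloadedError, RateLimitExceededError, parse_error

resp = requests.post(
    "https://api-inference.huggingface.co/models/bigscience/bloomz",  # illustrative endpoint
    json={"inputs": "Why is the sky blue?"},
    timeout=10,
)
if resp.status_code != 200:
    error = parse_error(resp.status_code, resp.json())
    if isinstance(error, (OverloadedError, RateLimitExceededError)):
        pass  # transient: back off and retry later
    else:
        raise error
```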
154  clients/python/text_generation/inference_api.py  Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from huggingface_hub.utils import build_hf_headers
|
||||||
|
|
||||||
|
from text_generation import Client, AsyncClient, __version__
|
||||||
|
from text_generation.errors import NotSupportedError
|
||||||
|
|
||||||
|
INFERENCE_ENDPOINT = os.environ.get(
|
||||||
|
"HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
|
||||||
|
)
|
||||||
|
|
||||||
|
SUPPORTED_MODELS = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_supported_models() -> Optional[List[str]]:
|
||||||
|
"""
|
||||||
|
Get the list of supported text-generation models from GitHub
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[List[str]]: supported models list or None if unable to get the list from GitHub
|
||||||
|
"""
|
||||||
|
global SUPPORTED_MODELS
|
||||||
|
if SUPPORTED_MODELS is not None:
|
||||||
|
return SUPPORTED_MODELS
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
"https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
file_content = response.json()["content"]
|
||||||
|
SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
|
||||||
|
return SUPPORTED_MODELS
|
||||||
|
|
||||||
|
warnings.warn("Could not retrieve list of supported models.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceAPIClient(Client):
|
||||||
|
"""Client to make calls to the HuggingFace Inference API.
|
||||||
|
|
||||||
|
Only supports a subset of the available text-generation or text2text-generation models that are served using
|
||||||
|
text-generation-inference
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from text_generation import InferenceAPIClient
|
||||||
|
|
||||||
|
>>> client = InferenceAPIClient("bigscience/bloomz")
|
||||||
|
>>> client.generate("Why is the sky blue?").generated_text
|
||||||
|
' Rayleigh scattering'
|
||||||
|
|
||||||
|
>>> result = ""
|
||||||
|
>>> for response in client.generate_stream("Why is the sky blue?"):
|
||||||
|
>>> if not response.token.special:
|
||||||
|
>>> result += response.token.text
|
||||||
|
>>> result
|
||||||
|
' Rayleigh scattering'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
|
||||||
|
"""
|
||||||
|
Init headers and API information
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id (`str`):
|
||||||
|
Id of repository (e.g. `bigscience/bloom`).
|
||||||
|
token (`str`, `optional`):
|
||||||
|
The API token to use as HTTP bearer authorization. This is not
|
||||||
|
the authentication token. You can find the token in
|
||||||
|
https://huggingface.co/settings/token. Alternatively, you can
|
||||||
|
find both your organizations and personal API tokens using
|
||||||
|
`HfApi().whoami(token)`.
|
||||||
|
timeout (`int`):
|
||||||
|
Timeout in seconds
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Text Generation Inference client only supports a subset of the available hub models
|
||||||
|
supported_models = get_supported_models()
|
||||||
|
if supported_models is not None and repo_id not in supported_models:
|
||||||
|
raise NotSupportedError(repo_id)
|
||||||
|
|
||||||
|
headers = build_hf_headers(
|
||||||
|
token=token, library_name="text-generation", library_version=__version__
|
||||||
|
)
|
||||||
|
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
||||||
|
|
||||||
|
super(InferenceAPIClient, self).__init__(
|
||||||
|
base_url, headers=headers, timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceAPIAsyncClient(AsyncClient):
|
||||||
|
"""Aynschronous Client to make calls to the HuggingFace Inference API.
|
||||||
|
|
||||||
|
Only supports a subset of the available text-generation or text2text-generation models that are served using
|
||||||
|
text-generation-inference
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from text_generation import InferenceAPIAsyncClient
|
||||||
|
|
||||||
|
>>> client = InferenceAPIAsyncClient("bigscience/bloomz")
|
||||||
|
>>> response = await client.generate("Why is the sky blue?")
|
||||||
|
>>> response.generated_text
|
||||||
|
' Rayleigh scattering'
|
||||||
|
|
||||||
|
>>> result = ""
|
||||||
|
>>> async for response in client.generate_stream("Why is the sky blue?"):
|
||||||
|
>>> if not response.token.special:
|
||||||
|
>>> result += response.token.text
|
||||||
|
>>> result
|
||||||
|
' Rayleigh scattering'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
|
||||||
|
"""
|
||||||
|
Init headers and API information
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id (`str`):
|
||||||
|
Id of repository (e.g. `bigscience/bloom`).
|
||||||
|
token (`str`, `optional`):
|
||||||
|
The API token to use as HTTP bearer authorization. This is not
|
||||||
|
the authentication token. You can find the token in
|
||||||
|
https://huggingface.co/settings/token. Alternatively, you can
|
||||||
|
find both your organizations and personal API tokens using
|
||||||
|
`HfApi().whoami(token)`.
|
||||||
|
timeout (`int`):
|
||||||
|
Timeout in seconds
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Text Generation Inference client only supports a subset of the available hub models
|
||||||
|
supported_models = get_supported_models()
|
||||||
|
if supported_models is not None and repo_id not in supported_models:
|
||||||
|
raise NotSupportedError(repo_id)
|
||||||
|
|
||||||
|
headers = build_hf_headers(
|
||||||
|
token=token, library_name="text-generation", library_version=__version__
|
||||||
|
)
|
||||||
|
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
|
||||||
|
|
||||||
|
super(InferenceAPIAsyncClient, self).__init__(
|
||||||
|
base_url, headers=headers, timeout=timeout
|
||||||
|
)
|
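`INFERENCE_ENDPOINT` is read from the environment at import time, so a different API deployment can be targeted by setting `HF_INFERENCE_ENDPOINT` before the module is imported. A small sketch (the override URL is a placeholder; `get_supported_models` still queries GitHub for the model list):

```python
import os

# Must be set before `text_generation.inference_api` is imported.
os.environ["HF_INFERENCE_ENDPOINT"] = "https://api-inference.example.internal"  # placeholder

from text_generation.inference_api import get_supported_models

models = get_supported_models()
if models is not None:
    print(f"{len(models)} models are served with text-generation-inference")
```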
223  clients/python/text_generation/types.py  Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel, validator
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from text_generation.errors import ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class Parameters(BaseModel):
|
||||||
|
# Activate logits sampling
|
||||||
|
do_sample: bool = False
|
||||||
|
# Maximum number of generated tokens
|
||||||
|
max_new_tokens: int = 20
|
||||||
|
# The parameter for repetition penalty. 1.0 means no penalty.
|
||||||
|
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||||
|
repetition_penalty: Optional[float] = None
|
||||||
|
# Whether to prepend the prompt to the generated text
|
||||||
|
return_full_text: bool = False
|
||||||
|
# Stop generating tokens if a member of `stop_sequences` is generated
|
||||||
|
stop: List[str] = []
|
||||||
|
# Random sampling seed
|
||||||
|
seed: Optional[int]
|
||||||
|
# The value used to modulate the logits distribution.
|
||||||
|
temperature: Optional[float]
|
||||||
|
# The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
top_k: Optional[int]
|
||||||
|
# If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||||
|
# higher are kept for generation.
|
||||||
|
top_p: Optional[float]
|
||||||
|
# Truncate input tokens to the given size
|
||||||
|
truncate: Optional[int]
|
||||||
|
# Typical Decoding mass
|
||||||
|
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
||||||
|
typical_p: Optional[float]
|
||||||
|
# Generate best_of sequences and return the one with the highest token logprobs
|
||||||
|
best_of: Optional[int]
|
||||||
|
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||||
|
watermark: bool = False
|
||||||
|
# Get generation details
|
||||||
|
details: bool = False
|
||||||
|
|
||||||
|
@validator("best_of")
|
||||||
|
def valid_best_of(cls, field_value, values):
|
||||||
|
if field_value is not None:
|
||||||
|
if field_value <= 0:
|
||||||
|
raise ValidationError("`best_of` must be strictly positive")
|
||||||
|
sampling = (
|
||||||
|
values["do_sample"]
|
||||||
|
| (values["temperature"] is not None)
|
||||||
|
| (values["top_k"] is not None)
|
||||||
|
| (values["top_p"] is not None)
|
||||||
|
| (values["typical_p"] is not None)
|
||||||
|
)
|
||||||
|
if field_value > 1 and not sampling:
|
||||||
|
raise ValidationError("you must use sampling when `best_of` is > 1")
|
||||||
|
|
||||||
|
return field_value
|
||||||
|
|
||||||
|
@validator("repetition_penalty")
|
||||||
|
def valid_repetition_penalty(cls, v):
|
||||||
|
if v is not None and v <= 0:
|
||||||
|
raise ValidationError("`repetition_penalty` must be strictly positive")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("seed")
|
||||||
|
def valid_seed(cls, v):
|
||||||
|
if v is not None and v < 0:
|
||||||
|
raise ValidationError("`seed` must be positive")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("temperature")
|
||||||
|
def valid_temp(cls, v):
|
||||||
|
if v is not None and v <= 0:
|
||||||
|
raise ValidationError("`temperature` must be strictly positive")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("top_k")
|
||||||
|
def valid_top_k(cls, v):
|
||||||
|
if v is not None and v <= 0:
|
||||||
|
raise ValidationError("`top_k` must be strictly positive")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("top_p")
|
||||||
|
def valid_top_p(cls, v):
|
||||||
|
if v is not None and (v <= 0 or v >= 1.0):
|
||||||
|
raise ValidationError("`top_p` must be > 0.0 and < 1.0")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("truncate")
|
||||||
|
def valid_truncate(cls, v):
|
||||||
|
if v is not None and v <= 0:
|
||||||
|
raise ValidationError("`truncate` must be strictly positive")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("typical_p")
|
||||||
|
def valid_typical_p(cls, v):
|
||||||
|
if v is not None and (v <= 0 or v >= 1.0):
|
||||||
|
+            raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
+        return v
+
+
+class Request(BaseModel):
+    # Prompt
+    inputs: str
+    # Generation parameters
+    parameters: Optional[Parameters]
+    # Whether to stream output tokens
+    stream: bool = False
+
+    @validator("inputs")
+    def valid_input(cls, v):
+        if not v:
+            raise ValidationError("`inputs` cannot be empty")
+        return v
+
+    @validator("stream")
+    def valid_best_of_stream(cls, field_value, values):
+        parameters = values["parameters"]
+        if (
+            parameters is not None
+            and parameters.best_of is not None
+            and parameters.best_of > 1
+            and field_value
+        ):
+            raise ValidationError(
+                "`best_of` != 1 is not supported when `stream` == True"
+            )
+        return field_value
+
+
+# Prompt tokens
+class PrefillToken(BaseModel):
+    # Token ID from the model tokenizer
+    id: int
+    # Token text
+    text: str
+    # Logprob
+    # Optional since the logprob of the first token cannot be computed
+    logprob: Optional[float]
+
+
+# Generated tokens
+class Token(BaseModel):
+    # Token ID from the model tokenizer
+    id: int
+    # Token text
+    text: str
+    # Logprob
+    logprob: float
+    # Is the token a special token
+    # Can be used to ignore tokens when concatenating
+    special: bool
+
+
+# Generation finish reason
+class FinishReason(Enum):
+    # number of generated tokens == `max_new_tokens`
+    Length = "length"
+    # the model generated its end of sequence token
+    EndOfSequenceToken = "eos_token"
+    # the model generated a text included in `stop_sequences`
+    StopSequence = "stop_sequence"
+
+
+# Additional sequences when using the `best_of` parameter
+class BestOfSequence(BaseModel):
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
+# `generate` details
+class Details(BaseModel):
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]
+
+
+# `generate` return value
+class Response(BaseModel):
+    # Generated text
+    generated_text: str
+    # Generation details
+    details: Details
+
+
+# `generate_stream` details
+class StreamDetails(BaseModel):
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+
+
+# `generate_stream` return value
+class StreamResponse(BaseModel):
+    # Generated token
+    token: Token
+    # Complete generated text
+    # Only available when the generation is finished
+    generated_text: Optional[str]
+    # Generation details
+    # Only available when the generation is finished
+    details: Optional[StreamDetails]
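The added client models above mirror the router's JSON response schema. As a minimal sketch of how they compose (not part of the diff; the payload values are invented and the surrounding `Parameters` model is assumed from context):

# Minimal sketch (assumed usage): parsing a /generate response into the
# pydantic models added above. The JSON payload is illustrative only.
payload = {
    "generated_text": "test output",
    "details": {
        "finish_reason": "length",
        "generated_tokens": 2,
        "seed": None,
        "prefill": [{"id": 0, "text": "Test", "logprob": None}],
        "tokens": [
            {"id": 1, "text": " out", "logprob": -0.1, "special": False},
            {"id": 2, "text": "put", "logprob": -0.2, "special": False},
        ],
    },
}
response = Response(**payload)
assert response.details.tokens[1].special is False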
@@ -11,7 +11,7 @@
 "name": "Apache 2.0",
 "url": "https://www.apache.org/licenses/LICENSE-2.0"
 },
-"version": "0.2.1"
+"version": "0.4.0"
 },
 "paths": {
 "/generate": {
@@ -38,23 +38,17 @@
 "content": {
 "application/json": {
 "schema": {
-"type": "array",
-"items": {
 "$ref": "#/components/schemas/GenerateResponse"
 }
 }
 }
-}
 },
 "422": {
 "description": "Input validation error",
 "content": {
 "application/json": {
 "schema": {
-"type": "array",
-"items": {
 "$ref": "#/components/schemas/ErrorResponse"
-}
 },
 "example": {
 "error": "Input validation error"
@@ -67,10 +61,7 @@
 "content": {
 "application/json": {
 "schema": {
-"type": "array",
-"items": {
 "$ref": "#/components/schemas/ErrorResponse"
-}
 },
 "example": {
 "error": "Request failed during generation"
@@ -83,10 +74,7 @@
 "content": {
 "application/json": {
 "schema": {
-"type": "array",
-"items": {
 "$ref": "#/components/schemas/ErrorResponse"
-}
 },
 "example": {
 "error": "Model is overloaded"
@@ -99,10 +87,7 @@
 "content": {
 "application/json": {
 "schema": {
-"type": "array",
-"items": {
 "$ref": "#/components/schemas/ErrorResponse"
-}
 },
 "example": {
 "error": "Incomplete generation"
@@ -136,25 +121,19 @@
 "200": {
 "description": "Generated Text",
 "content": {
-"text/event-stream ": {
+"text/event-stream": {
 "schema": {
-"type": "array",
-"items": {
 "$ref": "#/components/schemas/StreamResponse"
 }
 }
 }
-}
 },
 "422": {
 "description": "Input validation error",
 "content": {
-"text/event-stream ": {
+"text/event-stream": {
 "schema": {
-"type": "array",
-"items": {
 "$ref": "#/components/schemas/ErrorResponse"
-}
 },
 "example": {
 "error": "Input validation error"
@@ -165,12 +144,9 @@
 "424": {
 "description": "Generation Error",
 "content": {
-"text/event-stream ": {
+"text/event-stream": {
 "schema": {
-"type": "array",
-"items": {
 "$ref": "#/components/schemas/ErrorResponse"
-}
 },
 "example": {
 "error": "Request failed during generation"
@@ -181,12 +157,9 @@
 "429": {
 "description": "Model is overloaded",
 "content": {
-"text/event-stream ": {
+"text/event-stream": {
 "schema": {
-"type": "array",
-"items": {
 "$ref": "#/components/schemas/ErrorResponse"
-}
 },
 "example": {
 "error": "Model is overloaded"
@@ -197,12 +170,9 @@
 "500": {
 "description": "Incomplete generation",
 "content": {
-"text/event-stream ": {
+"text/event-stream": {
 "schema": {
-"type": "array",
-"items": {
 "$ref": "#/components/schemas/ErrorResponse"
-}
 },
 "example": {
 "error": "Incomplete generation"
@@ -213,17 +183,90 @@
 },
 "deprecated": false
 }
+},
+"/metrics": {
+"get": {
+"tags": [ "Text Generation Inference" ],
+"summary": "Prometheus metrics scrape endpoint",
+"description": "Prometheus metrics scrape endpoint",
+"operationId": "metrics",
+"responses": {
+"200": {
+"description": "Prometheus Metrics",
+"content": { "text/plain": { "schema": { "type": "string" } } }
+}
+},
+"deprecated": false
+}
 }
 },
 "components": {
 "schemas": {
+"BestOfSequence": {
+"type": "object",
+"required": [ "generated_text", "finish_reason", "generated_tokens", "prefill", "tokens" ],
+"properties": {
+"finish_reason": { "$ref": "#/components/schemas/FinishReason" },
+"generated_text": { "type": "string", "example": "test" },
+"generated_tokens": { "type": "integer", "format": "int32", "example": 1 },
+"prefill": { "type": "array", "items": { "$ref": "#/components/schemas/PrefillToken" } },
+"seed": { "type": "integer", "format": "int64", "example": 42, "nullable": true },
+"tokens": { "type": "array", "items": { "$ref": "#/components/schemas/Token" } }
+}
+},
 "Details": {
 "type": "object",
 "required": [
 "finish_reason",
-"generated_tokens"
+"generated_tokens",
+"prefill",
+"tokens"
 ],
 "properties": {
+"best_of_sequences": { "type": "array", "items": { "$ref": "#/components/schemas/BestOfSequence" } },
 "finish_reason": {
 "$ref": "#/components/schemas/FinishReason"
 },
@@ -235,13 +278,14 @@
 "prefill": {
 "type": "array",
 "items": {
-"$ref": "#/components/schemas/Token"
+"$ref": "#/components/schemas/PrefillToken"
 }
 },
 "seed": {
 "type": "integer",
 "format": "int64",
-"example": 42
+"example": 42,
+"nullable": true
 },
 "tokens": {
 "type": "array",
@@ -254,11 +298,15 @@
 "ErrorResponse": {
 "type": "object",
 "required": [
-"error"
+"error",
+"error_type"
 ],
 "properties": {
 "error": {
 "type": "string"
+},
+"error_type": {
+"type": "string"
 }
 }
 },
@@ -273,6 +321,13 @@
 "GenerateParameters": {
 "type": "object",
 "properties": {
+"best_of": { "type": "integer", "default": "null", "example": 1, "nullable": true, "exclusiveMinimum": 0.0 },
 "details": {
 "type": "boolean",
 "default": "true"
@@ -297,9 +352,19 @@
 "nullable": true,
 "exclusiveMinimum": 0.0
 },
+"return_full_text": { "type": "boolean", "default": "null", "example": false, "nullable": true },
 "seed": {
 "type": "integer",
-"format": "int64"
+"format": "int64",
+"default": "null",
+"example": "null",
+"nullable": true,
+"exclusiveMinimum": 0.0
 },
 "stop": {
 "type": "array",
@@ -335,6 +400,26 @@
 "nullable": true,
 "maximum": 1.0,
 "exclusiveMinimum": 0.0
+},
+"truncate": { "type": "integer", "default": "null", "example": "null", "nullable": true },
+"typical_p": { "type": "number", "format": "float", "default": "null", "example": 0.95, "nullable": true, "maximum": 1.0, "exclusiveMinimum": 0.0 },
+"watermark": { "type": "boolean", "default": "false", "example": true }
 }
 }
 },
@@ -368,6 +453,31 @@
 }
 }
 },
+"PrefillToken": {
+"type": "object",
+"required": [ "id", "text", "logprob" ],
+"properties": {
+"id": { "type": "integer", "format": "int32", "example": 0 },
+"logprob": { "type": "number", "format": "float", "example": -0.34, "nullable": true },
+"text": { "type": "string", "example": "test" }
+}
+},
 "StreamDetails": {
 "type": "object",
 "required": [
@@ -386,7 +496,8 @@
 "seed": {
 "type": "integer",
 "format": "int64",
-"example": 42
+"example": 42,
+"nullable": true
 }
 }
 },
@@ -415,7 +526,8 @@
 "required": [
 "id",
 "text",
-"logprob"
+"logprob",
+"special"
 ],
 "properties": {
 "id": {
@@ -429,6 +541,10 @@
 "example": -0.34,
 "nullable": true
 },
+"special": {
+"type": "boolean",
+"example": "false"
+},
 "text": {
 "type": "string",
 "example": "test"
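The OpenAPI changes above add `best_of`, `return_full_text`, `truncate`, `typical_p` and `watermark` to `GenerateParameters` and expose a `/metrics` endpoint. A hedged sketch of calling the updated API (the host, port and parameter values are assumptions for illustration, not taken from the diff):

import requests  # any HTTP client works; requests is assumed here

base = "http://localhost:3000"  # assumed router address
resp = requests.post(
    f"{base}/generate",
    json={
        "inputs": "Test request",
        "parameters": {
            "best_of": 2,            # new in this version
            "typical_p": 0.95,       # new in this version
            "watermark": True,       # new in this version
            "return_full_text": False,
            "max_new_tokens": 20,
        },
    },
)
print(resp.json()["generated_text"])
# Prometheus metrics are now scrapable as plain text:
print(requests.get(f"{base}/metrics").text[:200])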
@@ -1,6 +1,6 @@
 [package]
 name = "text-generation-launcher"
-version = "0.2.1"
+version = "0.4.0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 description = "Text Generation Launcher"
@@ -1,6 +1,7 @@
 use clap::Parser;
 use serde_json::Value;
 use std::env;
+use std::ffi::OsString;
 use std::io::{BufRead, BufReader, Read};
 use std::path::Path;
 use std::process::ExitCode;
@@ -12,7 +13,7 @@ use std::thread;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{fs, io};
-use subprocess::{Popen, PopenConfig, PopenError, Redirection};
+use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -23,13 +24,21 @@ struct Args {
 #[clap(long, env)]
 revision: Option<String>,
 #[clap(long, env)]
+sharded: Option<bool>,
+#[clap(long, env)]
 num_shard: Option<usize>,
 #[clap(long, env)]
 quantize: bool,
 #[clap(default_value = "128", long, env)]
 max_concurrent_requests: usize,
+#[clap(default_value = "2", long, env)]
+max_best_of: usize,
+#[clap(default_value = "4", long, env)]
+max_stop_sequences: usize,
 #[clap(default_value = "1000", long, env)]
 max_input_length: usize,
+#[clap(default_value = "1512", long, env)]
+max_total_tokens: usize,
 #[clap(default_value = "32", long, env)]
 max_batch_size: usize,
 #[clap(default_value = "20", long, env)]
@@ -43,38 +52,112 @@ struct Args {
 #[clap(default_value = "29500", long, env)]
 master_port: usize,
 #[clap(long, env)]
+huggingface_hub_cache: Option<String>,
+#[clap(long, env)]
+weights_cache_override: Option<String>,
+#[clap(long, env)]
+disable_custom_kernels: bool,
+#[clap(long, env)]
 json_output: bool,
 #[clap(long, env)]
 otlp_endpoint: Option<String>,
+#[clap(long, env)]
+cors_allow_origin: Vec<String>,
+#[clap(long, env)]
+watermark_gamma: Option<f32>,
+#[clap(long, env)]
+watermark_delta: Option<f32>,
 }
 
 fn main() -> ExitCode {
 // Pattern match configuration
+let args = Args::parse();
+
+if args.json_output {
+    tracing_subscriber::fmt().json().init();
+} else {
+    tracing_subscriber::fmt().compact().init();
+}
+
+tracing::info!("{:?}", args);
+
 let Args {
 model_id,
 revision,
+sharded,
 num_shard,
 quantize,
 max_concurrent_requests,
+max_best_of,
+max_stop_sequences,
 max_input_length,
+max_total_tokens,
 max_batch_size,
 max_waiting_tokens,
 port,
 shard_uds_path,
 master_addr,
 master_port,
+huggingface_hub_cache,
+weights_cache_override,
+disable_custom_kernels,
 json_output,
 otlp_endpoint,
-} = Args::parse();
+cors_allow_origin,
+watermark_gamma,
+watermark_delta,
+} = args;
 
-if json_output {
-    tracing_subscriber::fmt().json().init();
+// get the number of shards given `sharded` and `num_shard`
+let num_shard = if let Some(sharded) = sharded {
+    // sharded is set
+    match sharded {
+        // sharded is set and true
+        true => {
+            match num_shard {
+                None => {
+                    // try to default to the number of available GPUs
+                    tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
+                    let n_devices = num_cuda_devices()
+                        .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
+                    if n_devices <= 1 {
+                        panic!("`sharded` is true but only found {n_devices} CUDA devices");
+                    }
+                    n_devices
+                }
+                Some(num_shard) => {
+                    // we can't have only one shard while sharded
+                    if num_shard <= 1 {
+                        panic!("`sharded` is true but `num_shard` <= 1");
+                    }
+                    num_shard
+                }
+            }
+        }
+        // sharded is set and false
+        false => {
+            let num_shard = num_shard.unwrap_or(1);
+            // we can't have more than one shard while not sharded
+            if num_shard != 1 {
+                panic!("`sharded` is false but `num_shard` != 1");
+            }
+            num_shard
+        }
+    }
 } else {
-    tracing_subscriber::fmt().compact().init();
+    match num_shard {
+        // get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard
+        None => num_cuda_devices().unwrap_or(1),
+        Some(num_shard) => num_shard,
+    }
+};
+if num_shard < 1 {
+    panic!("`num_shard` cannot be < 1");
 }
 
-// By default we only have one master shard
-let num_shard = num_shard.unwrap_or(1);
+if num_shard > 1 {
+    tracing::info!("Sharding model on {num_shard} processes");
+}
 
 // Signal handler
 let running = Arc::new(AtomicBool::new(true));
@@ -84,6 +167,121 @@ fn main() -> ExitCode {
 })
 .expect("Error setting Ctrl-C handler");
 
+// Check if model_id is a local model
+let local_path = Path::new(&model_id);
+let is_local_model = local_path.exists() && local_path.is_dir();
+
+// Download weights for sharded models
+if !is_local_model && weights_cache_override.is_none() && num_shard > 1 {
+    let mut download_argv = vec![
+        "text-generation-server".to_string(),
+        "download-weights".to_string(),
+        model_id.clone(),
+        "--extension".to_string(),
+        ".safetensors".to_string(),
+        "--logger-level".to_string(),
+        "INFO".to_string(),
+        "--json-output".to_string(),
+    ];
+
+    // Model optional revision
+    if let Some(ref revision) = revision {
+        download_argv.push("--revision".to_string());
+        download_argv.push(revision.to_string())
+    }
+
+    // Copy current process env
+    let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
+
+    // If huggingface_hub_cache is set, pass it to the shard
+    // Useful when running inside a docker container
+    if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
+        env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
+    };
+
+    // Enable hf transfer for insane download speeds
+    env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
+
+    // Start process
+    tracing::info!("Starting download process.");
+    let mut download_process = match Popen::create(
+        &download_argv,
+        PopenConfig {
+            stdout: Redirection::Pipe,
+            stderr: Redirection::Pipe,
+            // Needed for the shutdown procedure
+            setpgid: true,
+            env: Some(env),
+            ..Default::default()
+        },
+    ) {
+        Ok(p) => p,
+        Err(err) => {
+            if let PopenError::IoError(ref err) = err {
+                if err.kind() == io::ErrorKind::NotFound {
+                    tracing::error!("text-generation-server not found in PATH");
+                    tracing::error!("Please install it with `make install-server`")
+                }
+            }
+            return ExitCode::FAILURE;
+        }
+    };
+
+    // Redirect STDOUT to the console
+    let download_stdout = download_process.stdout.take().unwrap();
+    thread::spawn(move || {
+        // Enter download tracing span
+        let stdout = BufReader::new(download_stdout);
+        let _span = tracing::span!(tracing::Level::INFO, "download").entered();
+        for line in stdout.lines() {
+            // Parse loguru logs
+            if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
+                if let Some(text) = value.get("text") {
+                    // Format escaped newlines
+                    tracing::info!("{}", text.to_string().replace("\\n", ""));
+                }
+            }
+        }
+    });
+
+    loop {
+        if let Some(status) = download_process.poll() {
+            match status {
+                ExitStatus::Exited(exit_code) => {
+                    if exit_code == 0 {
+                        tracing::info!("Successfully downloaded weights.");
+                        break;
+                    } else {
+                        let mut err = String::new();
+                        download_process
+                            .stderr
+                            .take()
+                            .unwrap()
+                            .read_to_string(&mut err)
+                            .unwrap();
+                        tracing::error!("Download encountered an error: {err}");
+                        return ExitCode::FAILURE;
+                    }
+                }
+                _ => {
+                    tracing::error!("Download process exited with an unknown status.");
+                    return ExitCode::FAILURE;
+                }
+            }
+        }
+        if !running.load(Ordering::SeqCst) {
+            download_process.terminate().unwrap();
+            tracing::info!("Waiting for download process to gracefully shutdown");
+            download_process
+                .wait_timeout(Duration::from_secs(90))
+                .unwrap();
+            tracing::info!("Download process terminated");
+            return ExitCode::SUCCESS;
+        }
+        sleep(Duration::from_millis(100));
+    }
+}
+
 // Shared shutdown bool
 let shutdown = Arc::new(Mutex::new(false));
 // Shared shutdown channel
@@ -99,6 +297,8 @@ fn main() -> ExitCode {
 let revision = revision.clone();
 let uds_path = shard_uds_path.clone();
 let master_addr = master_addr.clone();
+let huggingface_hub_cache = huggingface_hub_cache.clone();
+let weights_cache_override = weights_cache_override.clone();
 let status_sender = status_sender.clone();
 let shutdown = shutdown.clone();
 let shutdown_sender = shutdown_sender.clone();
@@ -113,6 +313,11 @@ fn main() -> ExitCode {
 num_shard,
 master_addr,
 master_port,
+huggingface_hub_cache,
+weights_cache_override,
+disable_custom_kernels,
+watermark_gamma,
+watermark_delta,
 otlp_endpoint,
 status_sender,
 shutdown,
@@ -161,8 +366,14 @@ fn main() -> ExitCode {
 "text-generation-router".to_string(),
 "--max-concurrent-requests".to_string(),
 max_concurrent_requests.to_string(),
+"--max-best-of".to_string(),
+max_best_of.to_string(),
+"--max-stop-sequences".to_string(),
+max_stop_sequences.to_string(),
 "--max-input-length".to_string(),
 max_input_length.to_string(),
+"--max-total-tokens".to_string(),
+max_total_tokens.to_string(),
 "--max-batch-size".to_string(),
 max_batch_size.to_string(),
 "--max-waiting-tokens".to_string(),
@@ -185,6 +396,12 @@ fn main() -> ExitCode {
 argv.push(otlp_endpoint);
 }
 
+// CORS origins
+for origin in cors_allow_origin.into_iter() {
+    argv.push("--cors-allow-origin".to_string());
+    argv.push(origin);
+}
+
 let mut webserver = match Popen::create(
 &argv,
 PopenConfig {
@@ -232,7 +449,7 @@ fn main() -> ExitCode {
 
 while running.load(Ordering::SeqCst) {
 if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-tracing::error!("Shard {} failed:\n{}", rank, err);
+tracing::error!("Shard {rank} failed:\n{err}");
 exit_code = ExitCode::FAILURE;
 break;
 };
@@ -275,6 +492,11 @@ fn shard_manager(
 world_size: usize,
 master_addr: String,
 master_port: usize,
+huggingface_hub_cache: Option<String>,
+weights_cache_override: Option<String>,
+disable_custom_kernels: bool,
+watermark_gamma: Option<f32>,
+watermark_delta: Option<f32>,
 otlp_endpoint: Option<String>,
 status_sender: mpsc::Sender<ShardStatus>,
 shutdown: Arc<Mutex<bool>>,
@@ -319,43 +541,54 @@ fn shard_manager(
 shard_argv.push(otlp_endpoint);
 }
 
-let mut env = vec![
-("RANK".into(), rank.to_string().into()),
-("WORLD_SIZE".into(), world_size.to_string().into()),
-("MASTER_ADDR".into(), master_addr.into()),
-("MASTER_PORT".into(), master_port.to_string().into()),
-("SAFETENSORS_FAST_GPU".into(), "1".into()),
-("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()),
-];
+// Copy current process env
+let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
 
-// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
+// Torch Distributed Env vars
+env.push(("RANK".into(), rank.to_string().into()));
+env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
+env.push(("MASTER_ADDR".into(), master_addr.into()));
+env.push(("MASTER_PORT".into(), master_port.to_string().into()));
+env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
+
+// Safetensors load fast
+env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
+
+// Enable hf transfer for insane download speeds
+env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
+
+// If huggingface_hub_cache is some, pass it to the shard
 // Useful when running inside a docker container
-if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
+if let Some(huggingface_hub_cache) = huggingface_hub_cache {
 env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
 };
 
-// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
+// If weights_cache_override is some, pass it to the shard
 // Useful when running inside a HuggingFace Inference Endpoint
-if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
+if let Some(weights_cache_override) = weights_cache_override {
 env.push((
 "WEIGHTS_CACHE_OVERRIDE".into(),
 weights_cache_override.into(),
 ));
 };
 
-// If the NCCL_SHM_DISABLE env var is set, pass it to the shard
-// needed when running NCCL inside a docker container and when you can't increase shm size
-if let Ok(nccl_shm_disalbe) = env::var("NCCL_SHM_DISABLE") {
-env.push(("NCCL_SHM_DISABLE".into(), nccl_shm_disalbe.into()));
-};
+// If disable_custom_kernels is true, pass it to the shard as an env var
+if disable_custom_kernels {
+env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
+}
 
-// If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
-if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
-env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));
-};
+// Watermark Gamma
+if let Some(watermark_gamma) = watermark_gamma {
+env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
+}
+
+// Watermark Delta
+if let Some(watermark_delta) = watermark_delta {
+env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
+}
 
 // Start process
-tracing::info!("Starting shard {}", rank);
+tracing::info!("Starting shard {rank}");
 let mut p = match Popen::create(
 &shard_argv,
 PopenConfig {
@@ -419,17 +652,17 @@ fn shard_manager(
 if *shutdown.lock().unwrap() {
 p.terminate().unwrap();
 let _ = p.wait_timeout(Duration::from_secs(90));
-tracing::info!("Shard {} terminated", rank);
+tracing::info!("Shard {rank} terminated");
 return;
 }
 
 // Shard is ready
 if uds.exists() && !ready {
-tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
+tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
 status_sender.send(ShardStatus::Ready).unwrap();
 ready = true;
 } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
-tracing::info!("Waiting for shard {} to be ready...", rank);
+tracing::info!("Waiting for shard {rank} to be ready...");
 wait_time = Instant::now();
 }
 sleep(Duration::from_millis(100));
@@ -449,3 +682,11 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
 // This will block till all shutdown_sender are dropped
 let _ = shutdown_receiver.recv();
 }
+
+fn num_cuda_devices() -> Option<usize> {
+    if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
+        let n_devices = cuda_visible_devices.split(',').count();
+        return Some(n_devices);
+    }
+    None
+}
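The launcher now resolves the shard count from `--sharded`, `--num-shard` and `CUDA_VISIBLE_DEVICES`. The same decision logic, restated as a small Python sketch for clarity (the function name and error messages are illustrative, not the launcher's API):

import os

def resolve_num_shard(sharded=None, num_shard=None):
    # Mirrors the launcher's logic: CUDA_VISIBLE_DEVICES is only a fallback.
    def num_cuda_devices():
        devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        return len(devices.split(",")) if devices else None

    if sharded is True:
        n = num_shard if num_shard is not None else num_cuda_devices()
        if n is None:
            raise RuntimeError("--num-shard and CUDA_VISIBLE_DEVICES are not set")
        if n <= 1:
            raise RuntimeError("`sharded` is true but only one shard is available")
        return n
    if sharded is False:
        n = num_shard or 1
        if n != 1:
            raise RuntimeError("`sharded` is false but `num_shard` != 1")
        return n
    # `sharded` not set: use --num-shard, else CUDA_VISIBLE_DEVICES, else 1
    return num_shard if num_shard is not None else (num_cuda_devices() or 1)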
@@ -1,122 +1,142 @@
 {
+"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
 "details": {
 "finish_reason": "length",
 "generated_tokens": 20,
+"seed": null,
 "prefill": [
-{ "id": 10264, "logprob": null, "text": "Test" },
+{ "id": 10264, "text": "Test", "logprob": null },
-{ "id": 8821, "logprob": -11.894989, "text": " request" }
+{ "id": 8821, "text": " request", "logprob": -11.894989 }
 ],
-"seed": null,
 "tokens": [
-{ "id": 17, "logprob": -1.8267672, "text": "." },
+{ "id": 17, "text": ".", "logprob": -1.8267672, "special": false },
-{ "id": 1587, "logprob": -2.4674969, "text": "get" },
+{ "id": 1587, "text": "get", "logprob": -2.4674969, "special": false },
-{ "id": 11, "logprob": -1.906001, "text": "(" },
+{ "id": 11, "text": "(", "logprob": -1.906001, "special": false },
-{ "id": 5, "logprob": -1.2279545, "text": "\"" },
+{ "id": 5, "text": "\"", "logprob": -1.2279545, "special": false },
-{ "id": 4899, "logprob": -4.170299, "text": "action" },
+{ "id": 4899, "text": "action", "logprob": -4.170299, "special": false },
-{ "id": 5, "logprob": -0.32478866, "text": "\"" },
+{ "id": 5, "text": "\"", "logprob": -0.32478866, "special": false },
-{ "id": 12, "logprob": -1.0773665, "text": ")" },
+{ "id": 12, "text": ")", "logprob": -1.0773665, "special": false },
-{ "id": 30, "logprob": -0.27640742, "text": ";" },
+{ "id": 30, "text": ";", "logprob": -0.27640742, "special": false },
-{ "id": 837, "logprob": -1.6970354, "text": "\n " },
+{ "id": 837, "text": "\n ", "logprob": -1.6970354, "special": false },
-{ "id": 1320, "logprob": -1.4495516, "text": " if" },
+{ "id": 1320, "text": " if", "logprob": -1.4495516, "special": false },
-{ "id": 375, "logprob": -0.23609057, "text": " (" },
+{ "id": 375, "text": " (", "logprob": -0.23609057, "special": false },
-{ "id": 4899, "logprob": -1.1916996, "text": "action" },
+{ "id": 4899, "text": "action", "logprob": -1.1916996, "special": false },
-{ "id": 3535, "logprob": -0.8918753, "text": " ==" },
+{ "id": 3535, "text": " ==", "logprob": -0.8918753, "special": false },
-{ "id": 5109, "logprob": -0.3933342, "text": " null" },
+{ "id": 5109, "text": " null", "logprob": -0.3933342, "special": false },
-{ "id": 12, "logprob": -0.43212673, "text": ")" },
+{ "id": 12, "text": ")", "logprob": -0.43212673, "special": false },
-{ "id": 731, "logprob": -0.17702064, "text": " {" },
+{ "id": 731, "text": " {", "logprob": -0.17702064, "special": false },
-{ "id": 1260, "logprob": -0.07027565, "text": "\n " },
+{ "id": 1260, "text": "\n ", "logprob": -0.07027565, "special": false },
-{ "id": 10519, "logprob": -1.3915029, "text": " throw" },
+{ "id": 10519, "text": " throw", "logprob": -1.3915029, "special": false },
-{ "id": 2084, "logprob": -0.04201372, "text": " new" },
+{ "id": 2084, "text": " new", "logprob": -0.04201372, "special": false },
-{ "id": 150858, "logprob": -1.7329919, "text": " RuntimeException" }
+{ "id": 150858, "text": " RuntimeException", "logprob": -1.7329919, "special": false }
 ]
-},
-"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+}
 }
@@ -14,6 +14,7 @@ pub struct Token {
 id: u32,
 text: String,
 logprob: Option<f32>,
+special: bool,
 }
 
 #[derive(Deserialize)]
@@ -136,6 +137,7 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
 {
 assert_eq!(token.id, expected_token.id);
 assert_eq!(token.text, expected_token.text);
+assert_eq!(token.special, expected_token.special);
 if let Some(logprob) = token.logprob {
 let expected_logprob = expected_token.logprob.unwrap();
 assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
@@ -1,117 +1,137 @@
 {
+"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
 "details": {
 "finish_reason": "length",
 "generated_tokens": 20,
+"seed": null,
 "prefill": [
-{ "id": 0, "logprob": null, "text": "<pad>" }
+{ "id": 0, "text": "<pad>", "logprob": null }
 ],
-"seed": null,
 "tokens": [
-{ "id": 259, "logprob": -1.3656927, "text": "" },
+{ "id": 259, "text": " ", "logprob": -1.3656927, "special": false },
-{ "id": 215100, "logprob": -2.6551573, "text": "\"\"\"" },
+{ "id": 215100, "text": "\"\"\"", "logprob": -2.6551573, "special": false },
-{ "id": 46138, "logprob": -1.8059857, "text": "Test" },
+{ "id": 46138, "text": "Test", "logprob": -1.8059857, "special": false },
-{ "id": 287, "logprob": -1.2102449, "text": "the" },
+{ "id": 287, "text": " the", "logprob": -1.2102449, "special": false },
-{ "id": 259, "logprob": -1.6057279, "text": "" },
+{ "id": 259, "text": " ", "logprob": -1.6057279, "special": false },
-{ "id": 49076, "logprob": -3.6060903, "text": "contents" },
+{ "id": 49076, "text": "contents", "logprob": -3.6060903, "special": false },
-{ "id": 304, "logprob": -0.5270343, "text": "of" },
+{ "id": 304, "text": " of", "logprob": -0.5270343, "special": false },
-{ "id": 287, "logprob": -0.62522805, "text": "the" },
+{ "id": 287, "text": " the", "logprob": -0.62522805, "special": false },
-{ "id": 259, "logprob": -1.4069618, "text": "" },
+{ "id": 259, "text": " ", "logprob": -1.4069618, "special": false },
-{ "id": 49076, "logprob": -2.621994, "text": "contents" },
+{ "id": 49076, "text": "contents", "logprob": -2.621994, "special": false },
-{ "id": 304, "logprob": -1.3172221, "text": "of" },
+{ "id": 304, "text": " of", "logprob": -1.3172221, "special": false },
-{ "id": 287, "logprob": -0.3501925, "text": "the" },
+{ "id": 287, "text": " the", "logprob": -0.3501925, "special": false },
-{ "id": 259, "logprob": -0.7219573, "text": "" },
+{ "id": 259, "text": " ", "logprob": -0.7219573, "special": false },
-{ "id": 49076, "logprob": -1.0494149, "text": "contents" },
+{ "id": 49076, "text": "contents", "logprob": -1.0494149, "special": false },
-{ "id": 260, "logprob": -1.0803378, "text": "." },
+{ "id": 260, "text": ".", "logprob": -1.0803378, "special": false },
-{ "id": 259, "logprob": -0.32933083, "text": "" },
+{ "id": 259, "text": " ", "logprob": -0.32933083, "special": false },
-{ "id": 215100, "logprob": -0.11268901, "text": "\"\"\"" },
+{ "id": 215100, "text": "\"\"\"", "logprob": -0.11268901, "special": false },
-{ "id": 2978, "logprob": -1.5846587, "text": "test" },
+{ "id": 2978, "text": " test", "logprob": -1.5846587, "special": false },
-{ "id": 290, "logprob": -0.49796978, "text": "_" },
+{ "id": 290, "text": "_", "logprob": -0.49796978, "special": false },
-{ "id": 4125, "logprob": -2.0026445, "text": "test" }
+{ "id": 4125, "text": "test", "logprob": -2.0026445, "special": false }
 ]
-},
-"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
+}
 }
@@ -34,12 +34,16 @@ message NextTokenChooserParameters {
 uint32 top_k = 2;
 /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
 float top_p = 3;
+/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
+float typical_p = 4;
 /// apply sampling on the logits
-bool do_sample = 4;
+bool do_sample = 5;
 /// random seed for sampling
-uint64 seed = 5;
+uint64 seed = 6;
 /// repetition penalty
-float repetition_penalty = 6;
+float repetition_penalty = 7;
+/// token watermarking using "A Watermark for Large Language Models"
+bool watermark = 8;
 }
 
 message StoppingCriteriaParameters {
@@ -54,12 +58,10 @@ message Request {
 uint64 id = 1;
 /// The generation context
 string inputs = 2;
-/// The number of tokens inside inputs
-uint32 input_length = 3;
 /// Next Token Chooser Parameters
-NextTokenChooserParameters parameters = 4;
+NextTokenChooserParameters parameters = 3;
 /// Stopping Criteria Parameters
-StoppingCriteriaParameters stopping_parameters = 5;
+StoppingCriteriaParameters stopping_parameters = 4;
 }
 
 message Batch {
@@ -108,8 +110,10 @@ message Generation {
 float token_logprob = 4;
 /// Text
 string token_text = 5;
+/// Is it a special token
+bool token_is_special = 6;
 /// Complete generated text
-GeneratedText generated_text = 6;
+GeneratedText generated_text = 7;
 }
 
 message PrefillRequest {
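With the renumbered proto fields, a prefill `Request` now carries `typical_p` and `watermark` and no longer includes `input_length`. A hedged sketch of the equivalent payload, written as a plain Python dict (field names follow the proto above; the values and the fields not shown in this diff, such as temperature and the stopping criteria, are assumptions):

request = {
    "id": 0,
    "inputs": "Test request",
    "parameters": {                  # NextTokenChooserParameters
        "top_k": 0,
        "top_p": 1.0,
        "typical_p": 1.0,            # new field 4
        "do_sample": False,          # now field 5
        "seed": 0,                   # now field 6
        "repetition_penalty": 1.0,   # now field 7
        "watermark": False,          # new field 8
    },
    # StoppingCriteriaParameters; exact fields assumed, not shown in this diff
    "stopping_parameters": {"max_new_tokens": 20, "stop_sequences": []},
}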
@@ -1,6 +1,6 @@
 [package]
 name = "text-generation-router"
-version = "0.2.1"
+version = "0.4.0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 description = "Text Generation Webserver"
@@ -19,17 +19,21 @@ axum-tracing-opentelemetry = "0.9.0"
 text-generation-client = { path = "client" }
 clap = { version = "4.1.4", features = ["derive", "env"] }
 futures = "0.3.26"
+metrics = "0.20.1"
+metrics-exporter-prometheus = { version = "0.11.0", features = [] }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
 parking_lot = "0.12.1"
 rand = "0.8.5"
+reqwest = { version = "0.11.14", features = [] }
 serde = "1.0.152"
 serde_json = "1.0.93"
 thiserror = "1.0.38"
 tokenizers = "0.13.2"
 tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.11"
+tower-http = { version = "0.3.5", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.18.0"
 tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
@@ -1,6 +1,6 @@
 [package]
 name = "text-generation-client"
-version = "0.2.1"
+version = "0.4.0"
 edition = "2021"
 
 [dependencies]
@@ -1,6 +1,6 @@
 [package]
 name = "grpc-metadata"
-version = "0.1.0"
+version = "0.4.0"
 edition = "2021"
 
 [dependencies]
@ -1,9 +1,9 @@
|
|||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
use crate::validation::{Validation, ValidationError};
|
use crate::validation::{Validation, ValidationError};
|
||||||
use crate::GenerateRequest;
|
|
||||||
use crate::{Entry, Queue, Token};
|
use crate::{Entry, Queue, Token};
|
||||||
|
use crate::{GenerateRequest, PrefillToken};
|
||||||
|
use futures::future::try_join_all;
|
||||||
use nohash_hasher::IntMap;
|
use nohash_hasher::IntMap;
|
||||||
use std::future::Future;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_client::{
|
use text_generation_client::{
|
||||||
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
||||||
@ -81,6 +81,7 @@ impl Infer {
|
|||||||
.limit_concurrent_requests
|
.limit_concurrent_requests
|
||||||
.try_acquire_owned()
|
.try_acquire_owned()
|
||||||
.map_err(|err| {
|
.map_err(|err| {
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
err
|
err
|
||||||
})?;
|
})?;
|
||||||
@ -138,7 +139,7 @@ impl Infer {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(tokens.logprobs.into_iter())
|
.zip(tokens.logprobs.into_iter())
|
||||||
.zip(tokens.texts.into_iter())
|
.zip(tokens.texts.into_iter())
|
||||||
.map(|((id, logprob), text)| Token { id, text, logprob })
|
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||||
.collect();
|
.collect();
|
||||||
}
|
}
|
||||||
// Push last token
|
// Push last token
|
||||||
@ -172,10 +173,48 @@ impl Infer {
|
|||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
let err = InferError::IncompleteGeneration;
|
let err = InferError::IncompleteGeneration;
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
Err(err)
|
Err(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/// Add best_of new requests to the queue and return a InferResponse of the sequence with
|
||||||
|
/// the highest log probability per token
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub(crate) async fn generate_best_of(
|
||||||
|
&self,
|
||||||
|
request: GenerateRequest,
|
||||||
|
best_of: usize,
|
||||||
|
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
|
||||||
|
// validate best_of parameter separately
|
||||||
|
let best_of = self.validation.validate_best_of(best_of)?;
|
||||||
|
|
||||||
|
// create multiple generate requests
|
||||||
|
let mut infer_responses: Vec<InferResponse> =
|
||||||
|
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
|
||||||
|
|
||||||
|
// get the sequence with the highest log probability per token
|
||||||
|
let mut max_index = 0;
|
||||||
|
let mut max_logprob: f32 = f32::MIN;
|
||||||
|
|
||||||
|
for (i, response) in infer_responses.iter().enumerate() {
|
||||||
|
// mean logprobs of the generated tokens
|
||||||
|
let sequence_logprob = response
|
||||||
|
.tokens
|
||||||
|
.iter()
|
||||||
|
.map(|token| token.logprob)
|
||||||
|
.sum::<f32>()
|
||||||
|
/ response.tokens.len() as f32;
|
||||||
|
|
||||||
|
// set best sequence
|
||||||
|
if sequence_logprob > max_logprob {
|
||||||
|
max_index = i;
|
||||||
|
max_logprob = sequence_logprob;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let best_response = infer_responses.remove(max_index);
|
||||||
|
Ok((best_response, infer_responses))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Batching logic
|
/// Batching logic
|
||||||
@ -190,7 +229,11 @@ async fn batching_task(
|
|||||||
shared: Arc<Shared>,
|
shared: Arc<Shared>,
|
||||||
) {
|
) {
|
||||||
// Minimum batch size after which we try to add more requests
|
// Minimum batch size after which we try to add more requests
|
||||||
let limit_min_batch_size = (max_batch_size / 2) as u32;
|
let limit_min_batch_size = if max_batch_size > 1 {
|
||||||
|
(max_batch_size / 2) as u32
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
// Infinite loop
|
// Infinite loop
|
||||||
loop {
|
loop {
|
||||||
@ -201,7 +244,7 @@ async fn batching_task(
|
|||||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||||
// waiting in the queue
|
// waiting in the queue
|
||||||
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
|
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
|
||||||
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
|
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||||
.instrument(span)
|
.instrument(span)
|
||||||
.await;
|
.await;
|
||||||
let mut waiting_tokens = 1;
|
let mut waiting_tokens = 1;
|
||||||
@ -212,6 +255,7 @@ async fn batching_task(
|
|||||||
// Get current batch info
|
// Get current batch info
|
||||||
let batch_size = batch.size;
|
let batch_size = batch.size;
|
||||||
let mut batches = vec![batch];
|
let mut batches = vec![batch];
|
||||||
|
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
||||||
|
|
||||||
// If the current batch is too small, we try to add more requests to it
|
// If the current batch is too small, we try to add more requests to it
|
||||||
if batch_size <= limit_min_batch_size {
|
if batch_size <= limit_min_batch_size {
|
||||||
@ -234,15 +278,15 @@ async fn batching_task(
|
|||||||
// because a new batch is being computed
|
// because a new batch is being computed
|
||||||
let entry_waiting_span =
|
let entry_waiting_span =
|
||||||
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
|
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
|
||||||
// Add relationship
|
// Add relationships
|
||||||
|
span.follows_from(&entry_waiting_span);
|
||||||
entry_waiting_span.follows_from(&span);
|
entry_waiting_span.follows_from(&span);
|
||||||
// Update entry
|
// Update entry
|
||||||
entry.temp_span = Some(entry_waiting_span);
|
entry.temp_span = Some(entry_waiting_span);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Generate one token for this new batch to have the attention past in cache
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
let new_cached_batch =
|
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||||
wrap_future(client.prefill(new_batch), &mut new_entries)
|
|
||||||
.instrument(span)
|
.instrument(span)
|
||||||
.await;
|
.await;
|
||||||
// Reset waiting counter
|
// Reset waiting counter
|
||||||
@ -262,35 +306,66 @@ async fn batching_task(
|
|||||||
// Create a new span to link the batch back to this entry
|
// Create a new span to link the batch back to this entry
|
||||||
let entry_batch_span =
|
let entry_batch_span =
|
||||||
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
|
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
|
||||||
// Add relationship
|
// Add relationships
|
||||||
|
next_batch_span.follows_from(&entry_batch_span);
|
||||||
entry_batch_span.follows_from(&next_batch_span);
|
entry_batch_span.follows_from(&next_batch_span);
|
||||||
// Update entry
|
// Update entry
|
||||||
entry.temp_span = Some(entry_batch_span);
|
entry.temp_span = Some(entry_batch_span);
|
||||||
});
|
});
|
||||||
|
|
||||||
cached_batch = wrap_future(client.decode(batches), &mut entries)
|
cached_batch = decode(&mut client, batches, &mut entries)
|
||||||
.instrument(next_batch_span)
|
.instrument(next_batch_span)
|
||||||
.await;
|
.await;
|
||||||
waiting_tokens += 1;
|
waiting_tokens += 1;
|
||||||
}
|
}
|
||||||
|
metrics::gauge!("tgi_batch_current_size", 0.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
|
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
async fn wrap_future(
|
async fn prefill(
|
||||||
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
|
client: &mut ShardedClient,
|
||||||
|
batch: Batch,
|
||||||
entries: &mut IntMap<u64, Entry>,
|
entries: &mut IntMap<u64, Entry>,
|
||||||
) -> Option<Batch> {
|
) -> Option<Batch> {
|
||||||
match future.await {
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
match client.prefill(batch).await {
|
||||||
Ok((generations, next_batch)) => {
|
Ok((generations, next_batch)) => {
|
||||||
send_generations(generations, entries);
|
send_generations(generations, entries);
|
||||||
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill");
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
||||||
next_batch
|
next_batch
|
||||||
}
|
}
|
||||||
// If we have an error, we discard the whole batch
|
// If we have an error, we discard the whole batch
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
send_errors(err, entries);
|
send_errors(err, entries);
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn decode(
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
batches: Vec<Batch>,
|
||||||
|
entries: &mut IntMap<u64, Entry>,
|
||||||
|
) -> Option<Batch> {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
match client.decode(batches).await {
|
||||||
|
Ok((generations, next_batch)) => {
|
||||||
|
send_generations(generations, entries);
|
||||||
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "decode");
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
||||||
|
next_batch
|
||||||
|
}
|
||||||
|
// If we have an error, we discard the whole batch
|
||||||
|
Err(err) => {
|
||||||
|
send_errors(err, entries);
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
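The split of `wrap_future` into `prefill` and `decode` lets each phase carry its own "method" label on the duration histogram and the success/failure counters. A minimal sketch of that shared shape, assuming the `metrics` macro signatures used elsewhere in this diff; `run_batch_step` and its `step` closure are hypothetical stand-ins for the async gRPC calls, not code from the repo:

    use std::time::Instant;

    // Sketch only: `step` stands in for client.prefill / client.decode.
    fn run_batch_step<T, E: std::fmt::Display>(
        method: &'static str,
        step: impl FnOnce() -> Result<T, E>,
    ) -> Option<T> {
        let start_time = Instant::now();
        match step() {
            Ok(value) => {
                // Same metric names and label as the router code above.
                metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => method);
                metrics::increment_counter!("tgi_batch_inference_success", "method" => method);
                Some(value)
            }
            Err(err) => {
                // On error the whole batch is discarded, mirroring the handlers above.
                metrics::increment_counter!("tgi_batch_inference_failure", "method" => method);
                eprintln!("batch {method} failed: {err}");
                None
            }
        }
    }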
@ -303,6 +378,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|||||||
// Create and enter a span to link this function back to the entry
|
// Create and enter a span to link this function back to the entry
|
||||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||||
let err = InferError::GenerationError(error.to_string());
|
let err = InferError::GenerationError(error.to_string());
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
|
|
||||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
@ -340,6 +416,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
|
|||||||
id: generation.token_id,
|
id: generation.token_id,
|
||||||
text: generation.token_text,
|
text: generation.token_text,
|
||||||
logprob: generation.token_logprob,
|
logprob: generation.token_logprob,
|
||||||
|
special: generation.token_is_special,
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(generated_text) = generation.generated_text {
|
if let Some(generated_text) = generation.generated_text {
|
||||||
@ -388,7 +465,7 @@ pub(crate) enum InferStreamResponse {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct InferResponse {
|
pub(crate) struct InferResponse {
|
||||||
pub(crate) prefill: Vec<Token>,
|
pub(crate) prefill: Vec<PrefillToken>,
|
||||||
pub(crate) tokens: Vec<Token>,
|
pub(crate) tokens: Vec<Token>,
|
||||||
pub(crate) generated_text: GeneratedText,
|
pub(crate) generated_text: GeneratedText,
|
||||||
pub(crate) queued: Instant,
|
pub(crate) queued: Instant,
|
||||||
@ -406,3 +483,14 @@ pub enum InferError {
|
|||||||
#[error("Incomplete generation")]
|
#[error("Incomplete generation")]
|
||||||
IncompleteGeneration,
|
IncompleteGeneration,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl InferError {
|
||||||
|
pub(crate) fn error_type(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
InferError::GenerationError(_) => "generation",
|
||||||
|
InferError::Overloaded(_) => "overloaded",
|
||||||
|
InferError::ValidationError(_) => "validation",
|
||||||
|
InferError::IncompleteGeneration => "incomplete_generation",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
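`error_type` gives every `InferError` variant a stable machine-readable tag that the HTTP layer can copy into `ErrorResponse.error_type`. A hedged sketch of the same pattern with a local stand-in enum (the real `InferError` lives in `infer.rs` and has more variants):

    use thiserror::Error;

    // Local stand-in used only to illustrate the variant-to-tag mapping above.
    #[derive(Debug, Error)]
    enum SketchError {
        #[error("Request failed during generation: {0}")]
        Generation(String),
        #[error("Model is overloaded")]
        Overloaded,
    }

    impl SketchError {
        fn error_type(&self) -> &'static str {
            match self {
                SketchError::Generation(_) => "generation",
                SketchError::Overloaded => "overloaded",
            }
        }
    }

    fn main() {
        let errors = [
            SketchError::Generation("shard disconnected".to_string()),
            SketchError::Overloaded,
        ];
        for err in errors {
            // A human message for logs plus a machine tag for clients and metrics.
            println!("error = {err}, error_type = {}", err.error_type());
        }
    }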
|
@ -12,6 +12,9 @@ use validation::Validation;
|
|||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
pub(crate) struct GenerateParameters {
|
pub(crate) struct GenerateParameters {
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
|
||||||
|
pub best_of: Option<usize>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(
|
#[schema(
|
||||||
exclusive_minimum = 0.0,
|
exclusive_minimum = 0.0,
|
||||||
@ -40,39 +43,64 @@ pub(crate) struct GenerateParameters {
|
|||||||
example = 0.95
|
example = 0.95
|
||||||
)]
|
)]
|
||||||
pub top_p: Option<f32>,
|
pub top_p: Option<f32>,
|
||||||
#[serde(default = "default_do_sample")]
|
#[serde(default)]
|
||||||
|
#[schema(
|
||||||
|
exclusive_minimum = 0.0,
|
||||||
|
maximum = 1.0,
|
||||||
|
nullable = true,
|
||||||
|
default = "null",
|
||||||
|
example = 0.95
|
||||||
|
)]
|
||||||
|
pub typical_p: Option<f32>,
|
||||||
|
#[serde(default)]
|
||||||
#[schema(default = "false", example = true)]
|
#[schema(default = "false", example = true)]
|
||||||
pub do_sample: bool,
|
pub do_sample: bool,
|
||||||
#[serde(default = "default_max_new_tokens")]
|
#[serde(default = "default_max_new_tokens")]
|
||||||
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
|
||||||
pub max_new_tokens: u32,
|
pub max_new_tokens: u32,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(inline, max_items = 4, example = json!(["photographer"]))]
|
#[schema(nullable = true, default = "null", example = false)]
|
||||||
|
pub return_full_text: Option<bool>,
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
|
||||||
pub stop: Vec<String>,
|
pub stop: Vec<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
|
pub truncate: Option<usize>,
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(default = "false", example = true)]
|
||||||
|
pub watermark: bool,
|
||||||
|
#[serde(default)]
|
||||||
#[schema(default = "true")]
|
#[schema(default = "true")]
|
||||||
pub details: bool,
|
pub details: bool,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
#[schema(
|
||||||
|
exclusive_minimum = 0,
|
||||||
|
nullable = true,
|
||||||
|
default = "null",
|
||||||
|
example = "null"
|
||||||
|
)]
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_do_sample() -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_max_new_tokens() -> u32 {
|
fn default_max_new_tokens() -> u32 {
|
||||||
20
|
20
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_parameters() -> GenerateParameters {
|
fn default_parameters() -> GenerateParameters {
|
||||||
GenerateParameters {
|
GenerateParameters {
|
||||||
|
best_of: None,
|
||||||
temperature: None,
|
temperature: None,
|
||||||
repetition_penalty: None,
|
repetition_penalty: None,
|
||||||
top_k: None,
|
top_k: None,
|
||||||
top_p: None,
|
top_p: None,
|
||||||
do_sample: default_do_sample(),
|
typical_p: None,
|
||||||
|
do_sample: false,
|
||||||
max_new_tokens: default_max_new_tokens(),
|
max_new_tokens: default_max_new_tokens(),
|
||||||
stop: vec![],
|
return_full_text: None,
|
||||||
|
stop: Vec::new(),
|
||||||
|
truncate: None,
|
||||||
|
watermark: false,
|
||||||
details: false,
|
details: false,
|
||||||
seed: None,
|
seed: None,
|
||||||
}
|
}
|
||||||
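With every new field on `GenerateParameters` marked `#[serde(default)]` or given a default function, existing clients keep working unchanged and callers only send what they want to override. A hedged sketch of a request body built with `serde_json`, using the field names from the struct above; the surrounding request shape (`inputs` plus `parameters`, as in `GenerateRequest`) is assumed from this diff:

    use serde_json::json;

    fn main() {
        // Only non-default fields need to be sent; everything else falls back to
        // the defaults declared on GenerateParameters above.
        let body = json!({
            "inputs": "What is Deep Learning?",
            "parameters": {
                "best_of": 2,          // best_of > 1 requires sampling to be enabled
                "do_sample": true,
                "typical_p": 0.95,
                "truncate": 1000,      // keep only the last 1000 input tokens
                "watermark": false,
                "max_new_tokens": 20
            }
        });
        println!("{}", serde_json::to_string_pretty(&body).unwrap());
    }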
@ -86,14 +114,46 @@ pub(crate) struct GenerateRequest {
|
|||||||
pub parameters: GenerateParameters,
|
pub parameters: GenerateParameters,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
|
pub(crate) struct CompatGenerateRequest {
|
||||||
|
#[schema(example = "My name is Olivier and I")]
|
||||||
|
pub inputs: String,
|
||||||
|
#[serde(default = "default_parameters")]
|
||||||
|
pub parameters: GenerateParameters,
|
||||||
|
#[serde(default)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub stream: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<CompatGenerateRequest> for GenerateRequest {
|
||||||
|
fn from(req: CompatGenerateRequest) -> Self {
|
||||||
|
Self {
|
||||||
|
inputs: req.inputs,
|
||||||
|
parameters: req.parameters,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
|
pub struct PrefillToken {
|
||||||
|
#[schema(example = 0)]
|
||||||
|
id: u32,
|
||||||
|
#[schema(example = "test")]
|
||||||
|
text: String,
|
||||||
|
#[schema(nullable = true, example = - 0.34)]
|
||||||
|
logprob: f32,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, ToSchema)]
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
pub struct Token {
|
pub struct Token {
|
||||||
#[schema(example = 0)]
|
#[schema(example = 0)]
|
||||||
id: u32,
|
id: u32,
|
||||||
#[schema(example = "test")]
|
#[schema(example = "test")]
|
||||||
text: String,
|
text: String,
|
||||||
#[schema(nullable = true, example = -0.34)]
|
#[schema(nullable = true, example = - 0.34)]
|
||||||
logprob: f32,
|
logprob: f32,
|
||||||
|
#[schema(example = "false")]
|
||||||
|
special: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
@ -108,16 +168,32 @@ pub(crate) enum FinishReason {
|
|||||||
StopSequence,
|
StopSequence,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub(crate) struct BestOfSequence {
|
||||||
|
#[schema(example = "test")]
|
||||||
|
pub generated_text: String,
|
||||||
|
#[schema(example = "length")]
|
||||||
|
pub finish_reason: FinishReason,
|
||||||
|
#[schema(example = 1)]
|
||||||
|
pub generated_tokens: u32,
|
||||||
|
#[schema(nullable = true, example = 42)]
|
||||||
|
pub seed: Option<u64>,
|
||||||
|
pub prefill: Vec<PrefillToken>,
|
||||||
|
pub tokens: Vec<Token>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct Details {
|
pub(crate) struct Details {
|
||||||
#[schema(example = "length")]
|
#[schema(example = "length")]
|
||||||
pub finish_reason: FinishReason,
|
pub finish_reason: FinishReason,
|
||||||
#[schema(example = 1)]
|
#[schema(example = 1)]
|
||||||
pub generated_tokens: u32,
|
pub generated_tokens: u32,
|
||||||
#[schema(example = 42)]
|
#[schema(nullable = true, example = 42)]
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
pub prefill: Option<Vec<Token>>,
|
pub prefill: Vec<PrefillToken>,
|
||||||
pub tokens: Option<Vec<Token>>,
|
pub tokens: Vec<Token>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub best_of_sequences: Option<Vec<BestOfSequence>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
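Because `best_of_sequences` is annotated with `skip_serializing_if = "Option::is_none"`, ordinary single-sequence responses keep exactly the same JSON shape as before; the key only appears when `best_of` candidates exist. A minimal sketch of that serde behaviour with a trimmed-down struct (not the real `Details`, which carries more fields):

    use serde::Serialize;

    // Trimmed-down stand-in keeping only the field relevant to serialization.
    #[derive(Serialize)]
    struct DetailsSketch {
        generated_tokens: u32,
        #[serde(skip_serializing_if = "Option::is_none")]
        best_of_sequences: Option<Vec<String>>,
    }

    fn main() {
        let plain = DetailsSketch { generated_tokens: 3, best_of_sequences: None };
        let best_of = DetailsSketch {
            generated_tokens: 3,
            best_of_sequences: Some(vec!["candidate a".into(), "candidate b".into()]),
        };
        // None: the key is omitted entirely. Some: the extra sequences appear.
        println!("{}", serde_json::to_string(&plain).unwrap());
        println!("{}", serde_json::to_string(&best_of).unwrap());
    }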
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
@ -134,7 +210,7 @@ pub(crate) struct StreamDetails {
|
|||||||
pub finish_reason: FinishReason,
|
pub finish_reason: FinishReason,
|
||||||
#[schema(example = 1)]
|
#[schema(example = 1)]
|
||||||
pub generated_tokens: u32,
|
pub generated_tokens: u32,
|
||||||
#[schema(example = 42)]
|
#[schema(nullable = true, example = 42)]
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,6 +225,6 @@ pub(crate) struct StreamResponse {
|
|||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub(crate) struct ErrorResponse {
|
pub(crate) struct ErrorResponse {
|
||||||
#[schema(inline)]
|
|
||||||
pub error: String,
|
pub error: String,
|
||||||
|
pub error_type: String,
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
/// Text Generation Inference webserver entrypoint
|
/// Text Generation Inference webserver entrypoint
|
||||||
|
use axum::http::HeaderValue;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||||
use opentelemetry::sdk::trace;
|
use opentelemetry::sdk::trace;
|
||||||
@ -7,9 +8,11 @@ use opentelemetry::sdk::Resource;
|
|||||||
use opentelemetry::{global, KeyValue};
|
use opentelemetry::{global, KeyValue};
|
||||||
use opentelemetry_otlp::WithExportConfig;
|
use opentelemetry_otlp::WithExportConfig;
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
|
use std::path::Path;
|
||||||
use text_generation_client::ShardedClient;
|
use text_generation_client::ShardedClient;
|
||||||
use text_generation_router::server;
|
use text_generation_router::server;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
use tower_http::cors::AllowOrigin;
|
||||||
use tracing_subscriber::layer::SubscriberExt;
|
use tracing_subscriber::layer::SubscriberExt;
|
||||||
use tracing_subscriber::util::SubscriberInitExt;
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
use tracing_subscriber::{EnvFilter, Layer};
|
use tracing_subscriber::{EnvFilter, Layer};
|
||||||
@ -20,8 +23,14 @@ use tracing_subscriber::{EnvFilter, Layer};
|
|||||||
struct Args {
|
struct Args {
|
||||||
#[clap(default_value = "128", long, env)]
|
#[clap(default_value = "128", long, env)]
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
|
#[clap(default_value = "2", long, env)]
|
||||||
|
max_best_of: usize,
|
||||||
|
#[clap(default_value = "4", long, env)]
|
||||||
|
max_stop_sequences: usize,
|
||||||
#[clap(default_value = "1000", long, env)]
|
#[clap(default_value = "1000", long, env)]
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
|
#[clap(default_value = "1512", long, env)]
|
||||||
|
max_total_tokens: usize,
|
||||||
#[clap(default_value = "32", long, env)]
|
#[clap(default_value = "32", long, env)]
|
||||||
max_batch_size: usize,
|
max_batch_size: usize,
|
||||||
#[clap(default_value = "20", long, env)]
|
#[clap(default_value = "20", long, env)]
|
||||||
@ -38,6 +47,8 @@ struct Args {
|
|||||||
json_output: bool,
|
json_output: bool,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), std::io::Error> {
|
fn main() -> Result<(), std::io::Error> {
|
||||||
@ -46,7 +57,10 @@ fn main() -> Result<(), std::io::Error> {
|
|||||||
// Pattern match configuration
|
// Pattern match configuration
|
||||||
let Args {
|
let Args {
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
port,
|
port,
|
||||||
@ -55,17 +69,37 @@ fn main() -> Result<(), std::io::Error> {
|
|||||||
validation_workers,
|
validation_workers,
|
||||||
json_output,
|
json_output,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
|
cors_allow_origin,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
if validation_workers == 0 {
|
if validation_workers == 0 {
|
||||||
panic!("validation_workers must be > 0");
|
panic!("validation_workers must be > 0");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Download and instantiate tokenizer
|
// CORS allowed origins
|
||||||
|
// map to go inside the option and then map to parse from String to HeaderValue
|
||||||
|
// Finally, convert to AllowOrigin
|
||||||
|
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
|
||||||
|
AllowOrigin::list(
|
||||||
|
cors_allow_origin
|
||||||
|
.iter()
|
||||||
|
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Tokenizer instance
|
||||||
// This will only be used to validate payloads
|
// This will only be used to validate payloads
|
||||||
//
|
let local_path = Path::new(&tokenizer_name);
|
||||||
|
let tokenizer =
|
||||||
|
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
|
||||||
|
{
|
||||||
|
// Load local tokenizer
|
||||||
|
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
|
||||||
|
} else {
|
||||||
|
// Download and instantiate tokenizer
|
||||||
// We need to download it outside of the Tokio runtime
|
// We need to download it outside of the Tokio runtime
|
||||||
let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
|
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
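The router now prefers a local `tokenizer.json` when `tokenizer_name` points at a directory and only falls back to downloading from the Hub. A hedged sketch of that decision with the `tokenizers` crate, assuming `from_pretrained` is available (it requires the crate's `http` feature) and keeping the same `unwrap`-style error handling as the surrounding code:

    use std::path::Path;
    use tokenizers::Tokenizer;

    // Sketch of the selection logic above; `load_tokenizer` is an illustrative
    // helper, not a function from the repo.
    fn load_tokenizer(tokenizer_name: &str) -> Tokenizer {
        let local_path = Path::new(tokenizer_name);
        if local_path.is_dir() && local_path.join("tokenizer.json").exists() {
            // Offline / air-gapped case: read the file already on disk.
            Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
        } else {
            // Otherwise resolve the name against the Hugging Face Hub.
            Tokenizer::from_pretrained(tokenizer_name, None).unwrap()
        }
    }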
// Launch Tokio runtime
|
// Launch Tokio runtime
|
||||||
tokio::runtime::Builder::new_multi_thread()
|
tokio::runtime::Builder::new_multi_thread()
|
||||||
@ -75,6 +109,27 @@ fn main() -> Result<(), std::io::Error> {
|
|||||||
.block_on(async {
|
.block_on(async {
|
||||||
init_logging(otlp_endpoint, json_output);
|
init_logging(otlp_endpoint, json_output);
|
||||||
|
|
||||||
|
// Get pipeline tag
|
||||||
|
let model_info = reqwest::get(format!(
|
||||||
|
"https://huggingface.co/api/models/{tokenizer_name}"
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.expect("Could not connect to hf.co")
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.expect("error when retrieving model info from hf.co");
|
||||||
|
let model_info: serde_json::Value =
|
||||||
|
serde_json::from_str(&model_info).expect("unable to parse model info");
|
||||||
|
|
||||||
|
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||||
|
let compat_return_full_text = match model_info.get("pipeline_tag") {
|
||||||
|
None => {
|
||||||
|
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"),
|
||||||
|
};
|
||||||
|
|
||||||
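The compat default for `return_full_text` is derived from the model's `pipeline_tag` on the Hub: only `text-generation` models echo the prompt back by default. A standalone, hedged sketch of that lookup using `reqwest` (blocking client here for brevity, which assumes reqwest's `blocking` feature; the router performs the same request asynchronously inside the Tokio runtime):

    // Illustrative helper, not code from the repo.
    fn compat_return_full_text(model_id: &str) -> bool {
        let url = format!("https://huggingface.co/api/models/{model_id}");
        let body = reqwest::blocking::get(url)
            .expect("could not connect to hf.co")
            .text()
            .expect("error when retrieving model info from hf.co");
        let model_info: serde_json::Value =
            serde_json::from_str(&body).expect("unable to parse model info");
        match model_info.get("pipeline_tag") {
            // No tag: be conservative and do not echo the prompt back.
            None => false,
            Some(tag) => tag.as_str() == Some("text-generation"),
        }
    }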
// Instantiate sharded client from the master unix socket
|
// Instantiate sharded client from the master unix socket
|
||||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||||
.await
|
.await
|
||||||
@ -91,14 +146,19 @@ fn main() -> Result<(), std::io::Error> {
|
|||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
server::run(
|
server::run(
|
||||||
|
compat_return_full_text,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
sharded_client,
|
sharded_client,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
addr,
|
addr,
|
||||||
|
cors_allow_origin,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -132,6 +132,7 @@ impl State {
|
|||||||
// Push entry in the queue
|
// Push entry in the queue
|
||||||
self.entries.push((self.next_id, entry));
|
self.entries.push((self.next_id, entry));
|
||||||
self.next_id += 1;
|
self.next_id += 1;
|
||||||
|
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the next batch
|
// Get the next batch
|
||||||
@ -164,7 +165,8 @@ impl State {
|
|||||||
// Create a new span to link the batch back to this entry
|
// Create a new span to link the batch back to this entry
|
||||||
let entry_batch_span =
|
let entry_batch_span =
|
||||||
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
|
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
|
||||||
// Add relationship
|
// Add relationships
|
||||||
|
next_batch_span.follows_from(&entry_batch_span);
|
||||||
entry_batch_span.follows_from(&next_batch_span);
|
entry_batch_span.follows_from(&next_batch_span);
|
||||||
// Update entry
|
// Update entry
|
||||||
entry.temp_span = Some(entry_batch_span);
|
entry.temp_span = Some(entry_batch_span);
|
||||||
@ -172,7 +174,6 @@ impl State {
|
|||||||
batch_requests.push(Request {
|
batch_requests.push(Request {
|
||||||
id,
|
id,
|
||||||
inputs: entry.request.inputs.clone(),
|
inputs: entry.request.inputs.clone(),
|
||||||
input_length: entry.request.input_length,
|
|
||||||
parameters: Some(entry.request.parameters.clone()),
|
parameters: Some(entry.request.parameters.clone()),
|
||||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||||
});
|
});
|
||||||
@ -190,6 +191,8 @@ impl State {
|
|||||||
// Increment batch id
|
// Increment batch id
|
||||||
self.next_batch_id += 1;
|
self.next_batch_id += 1;
|
||||||
|
|
||||||
|
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
|
||||||
|
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
|
||||||
Some((batch_entries, batch, next_batch_span))
|
Some((batch_entries, batch, next_batch_span))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -223,14 +226,15 @@ mod tests {
|
|||||||
Entry {
|
Entry {
|
||||||
request: ValidGenerateRequest {
|
request: ValidGenerateRequest {
|
||||||
inputs: "".to_string(),
|
inputs: "".to_string(),
|
||||||
input_length: 0,
|
|
||||||
parameters: NextTokenChooserParameters {
|
parameters: NextTokenChooserParameters {
|
||||||
temperature: 0.0,
|
temperature: 0.0,
|
||||||
top_k: 0,
|
top_k: 0,
|
||||||
top_p: 0.0,
|
top_p: 0.0,
|
||||||
|
typical_p: 0.0,
|
||||||
do_sample: false,
|
do_sample: false,
|
||||||
seed: 0,
|
seed: 0,
|
||||||
repetition_penalty: 0.0,
|
repetition_penalty: 0.0,
|
||||||
|
watermark: false,
|
||||||
},
|
},
|
||||||
stopping_parameters: StoppingCriteriaParameters {
|
stopping_parameters: StoppingCriteriaParameters {
|
||||||
max_new_tokens: 0,
|
max_new_tokens: 0,
|
||||||
|
@ -1,17 +1,20 @@
|
|||||||
/// HTTP Server logic
|
/// HTTP Server logic
|
||||||
use crate::infer::{InferError, InferStreamResponse};
|
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||||
|
use crate::validation::ValidationError;
|
||||||
use crate::{
|
use crate::{
|
||||||
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
|
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
|
||||||
Infer, StreamDetails, StreamResponse, Token, Validation,
|
GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
|
||||||
|
StreamResponse, Token, Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
use axum::response::IntoResponse;
|
use axum::response::{IntoResponse, Response};
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{http, Json, Router};
|
||||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
|
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use text_generation_client::ShardedClient;
|
use text_generation_client::ShardedClient;
|
||||||
@ -19,29 +22,61 @@ use tokenizers::Tokenizer;
|
|||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
|
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||||
use tracing::{info_span, instrument, Instrument};
|
use tracing::{info_span, instrument, Instrument};
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
|
/// Compatibility route with api-inference and AzureML
|
||||||
|
#[instrument(skip(infer))]
|
||||||
|
async fn compat_generate(
|
||||||
|
default_return_full_text: Extension<bool>,
|
||||||
|
infer: Extension<Infer>,
|
||||||
|
req: Json<CompatGenerateRequest>,
|
||||||
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let mut req = req.0;
|
||||||
|
|
||||||
|
// default return_full_text given the pipeline_tag
|
||||||
|
if req.parameters.return_full_text.is_none() {
|
||||||
|
req.parameters.return_full_text = Some(default_return_full_text.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// switch on stream
|
||||||
|
if req.stream {
|
||||||
|
Ok(generate_stream(infer, Json(req.into()))
|
||||||
|
.await
|
||||||
|
.into_response())
|
||||||
|
} else {
|
||||||
|
let (headers, generation) = generate(infer, Json(req.into())).await?;
|
||||||
|
// wrap generation inside a Vec to match api-inference
|
||||||
|
Ok((headers, Json(vec![generation.0])).into_response())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
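The compat handler above does two things before delegating: it fills in `return_full_text` from the pipeline-tag default when the caller left it unset, then switches on `stream` between the SSE path and the JSON path (wrapping the single generation in a Vec to match api-inference). A simplified, hedged sketch of that branch with hypothetical local types; the real handler returns axum responses:

    struct CompatRequest {
        stream: bool,
        return_full_text: Option<bool>,
    }

    fn route_compat(mut req: CompatRequest, default_return_full_text: bool) -> (bool, &'static str) {
        // Only apply the pipeline-tag based default when the caller did not set it.
        if req.return_full_text.is_none() {
            req.return_full_text = Some(default_return_full_text);
        }
        let echo_prompt = req.return_full_text.unwrap_or(false);
        let route = if req.stream {
            "generate_stream (Server-Sent Events)"
        } else {
            "generate (JSON array with a single element)"
        };
        (echo_prompt, route)
    }

    fn main() {
        let req = CompatRequest { stream: false, return_full_text: None };
        let (echo_prompt, route) = route_compat(req, true);
        println!("echo_prompt={echo_prompt}, route={route}");
    }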
/// Health check method
|
/// Health check method
|
||||||
#[instrument(skip(infer))]
|
#[instrument(skip(infer))]
|
||||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||||
// be a bit too slow for a health check.
|
// be a bit too slow for a health check.
|
||||||
// What we should do instead if check if the gRPC channels are still healthy.
|
// What we should do instead is check if the gRPC channels are still healthy.
|
||||||
|
|
||||||
// Send a small inference request
|
// Send a small inference request
|
||||||
infer
|
infer
|
||||||
.generate(GenerateRequest {
|
.generate(GenerateRequest {
|
||||||
inputs: "liveness".to_string(),
|
inputs: "liveness".to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
|
best_of: None,
|
||||||
temperature: None,
|
temperature: None,
|
||||||
repetition_penalty: None,
|
repetition_penalty: None,
|
||||||
top_k: None,
|
top_k: None,
|
||||||
top_p: None,
|
top_p: None,
|
||||||
|
typical_p: None,
|
||||||
do_sample: false,
|
do_sample: false,
|
||||||
max_new_tokens: 1,
|
max_new_tokens: 1,
|
||||||
|
return_full_text: None,
|
||||||
stop: Vec::new(),
|
stop: Vec::new(),
|
||||||
|
truncate: None,
|
||||||
|
watermark: false,
|
||||||
details: false,
|
details: false,
|
||||||
seed: None,
|
seed: None,
|
||||||
},
|
},
|
||||||
@ -57,15 +92,15 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||||||
path = "/generate",
|
path = "/generate",
|
||||||
request_body = GenerateRequest,
|
request_body = GenerateRequest,
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Generated Text", body = [GenerateResponse]),
|
(status = 200, description = "Generated Text", body = GenerateResponse),
|
||||||
(status = 424, description = "Generation Error", body = [ErrorResponse],
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
example = json!({"error": "Request failed during generation"})),
|
example = json ! ({"error": "Request failed during generation"})),
|
||||||
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
example = json!({"error": "Model is overloaded"})),
|
example = json ! ({"error": "Model is overloaded"})),
|
||||||
(status = 422, description = "Input validation error", body = [ErrorResponse],
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
example = json!({"error": "Input validation error"})),
|
example = json ! ({"error": "Input validation error"})),
|
||||||
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
example = json!({"error": "Incomplete generation"})),
|
example = json ! ({"error": "Incomplete generation"})),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[instrument(
|
#[instrument(
|
||||||
@ -82,23 +117,64 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||||||
async fn generate(
|
async fn generate(
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
// Inference
|
let compute_characters = req.0.inputs.chars().count();
|
||||||
|
let mut add_prompt = None;
|
||||||
|
if req.0.parameters.return_full_text.unwrap_or(false) {
|
||||||
|
add_prompt = Some(req.0.inputs.clone());
|
||||||
|
}
|
||||||
|
|
||||||
let details = req.0.parameters.details;
|
let details = req.0.parameters.details;
|
||||||
let response = infer.generate(req.0).await?;
|
|
||||||
|
// Inference
|
||||||
|
let (response, best_of_responses) = match req.0.parameters.best_of {
|
||||||
|
Some(best_of) if best_of > 1 => {
|
||||||
|
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
|
||||||
|
(response, Some(best_of_responses))
|
||||||
|
}
|
||||||
|
_ => (infer.generate(req.0).await?, None),
|
||||||
|
};
|
||||||
|
|
||||||
// Token details
|
// Token details
|
||||||
let details = match details {
|
let details = match details {
|
||||||
true => Some(Details {
|
true => {
|
||||||
|
// convert best_of_responses
|
||||||
|
let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
|
||||||
|
responses
|
||||||
|
.into_iter()
|
||||||
|
.map(|response: InferResponse| {
|
||||||
|
// Add prompt if return_full_text
|
||||||
|
let mut output_text = response.generated_text.text;
|
||||||
|
if let Some(prompt) = &add_prompt {
|
||||||
|
output_text = prompt.clone() + &output_text;
|
||||||
|
}
|
||||||
|
|
||||||
|
BestOfSequence {
|
||||||
|
generated_text: output_text,
|
||||||
|
finish_reason: FinishReason::from(
|
||||||
|
response.generated_text.finish_reason,
|
||||||
|
),
|
||||||
|
generated_tokens: response.generated_text.generated_tokens,
|
||||||
|
prefill: response.prefill,
|
||||||
|
tokens: response.tokens,
|
||||||
|
seed: response.generated_text.seed,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
});
|
||||||
|
|
||||||
|
Some(Details {
|
||||||
finish_reason: FinishReason::from(response.generated_text.finish_reason),
|
finish_reason: FinishReason::from(response.generated_text.finish_reason),
|
||||||
generated_tokens: response.generated_text.generated_tokens,
|
generated_tokens: response.generated_text.generated_tokens,
|
||||||
prefill: Some(response.prefill),
|
prefill: response.prefill,
|
||||||
tokens: Some(response.tokens),
|
tokens: response.tokens,
|
||||||
seed: response.generated_text.seed,
|
seed: response.generated_text.seed,
|
||||||
}),
|
best_of_sequences,
|
||||||
|
})
|
||||||
|
}
|
||||||
false => None,
|
false => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -111,6 +187,15 @@ async fn generate(
|
|||||||
|
|
||||||
// Headers
|
// Headers
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
|
||||||
|
headers.insert(
|
||||||
|
"x-compute-time",
|
||||||
|
total_time.as_millis().to_string().parse().unwrap(),
|
||||||
|
);
|
||||||
|
headers.insert(
|
||||||
|
"x-compute-characters",
|
||||||
|
compute_characters.to_string().parse().unwrap(),
|
||||||
|
);
|
||||||
headers.insert(
|
headers.insert(
|
||||||
"x-total-time",
|
"x-total-time",
|
||||||
total_time.as_millis().to_string().parse().unwrap(),
|
total_time.as_millis().to_string().parse().unwrap(),
|
||||||
@ -141,9 +226,26 @@ async fn generate(
|
|||||||
span.record("seed", format!("{:?}", response.generated_text.seed));
|
span.record("seed", format!("{:?}", response.generated_text.seed));
|
||||||
tracing::info!("Output: {}", response.generated_text.text);
|
tracing::info!("Output: {}", response.generated_text.text);
|
||||||
|
|
||||||
|
// Metrics
|
||||||
|
metrics::increment_counter!("tgi_request_success");
|
||||||
|
metrics::histogram!("tgi_request_duration", total_time);
|
||||||
|
metrics::histogram!("tgi_request_validation_duration", validation_time);
|
||||||
|
metrics::histogram!("tgi_request_queue_duration", queue_time);
|
||||||
|
metrics::histogram!("tgi_request_inference_duration", inference_time);
|
||||||
|
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
|
||||||
|
metrics::histogram!(
|
||||||
|
"tgi_request_generated_tokens",
|
||||||
|
response.generated_text.generated_tokens as f64
|
||||||
|
);
|
||||||
|
|
||||||
// Send response
|
// Send response
|
||||||
|
let mut output_text = response.generated_text.text;
|
||||||
|
if let Some(prompt) = add_prompt {
|
||||||
|
output_text = prompt + &output_text;
|
||||||
|
}
|
||||||
|
|
||||||
let response = GenerateResponse {
|
let response = GenerateResponse {
|
||||||
generated_text: response.generated_text.text,
|
generated_text: output_text,
|
||||||
details,
|
details,
|
||||||
};
|
};
|
||||||
Ok((headers, Json(response)))
|
Ok((headers, Json(response)))
|
||||||
@ -156,20 +258,20 @@ async fn generate(
|
|||||||
path = "/generate_stream",
|
path = "/generate_stream",
|
||||||
request_body = GenerateRequest,
|
request_body = GenerateRequest,
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Generated Text", body = [StreamResponse],
|
(status = 200, description = "Generated Text", body = StreamResponse,
|
||||||
content_type="text/event-stream "),
|
content_type = "text/event-stream"),
|
||||||
(status = 424, description = "Generation Error", body = [ErrorResponse],
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
example = json!({"error": "Request failed during generation"}),
|
example = json ! ({"error": "Request failed during generation"}),
|
||||||
content_type="text/event-stream "),
|
content_type = "text/event-stream"),
|
||||||
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
example = json!({"error": "Model is overloaded"}),
|
example = json ! ({"error": "Model is overloaded"}),
|
||||||
content_type="text/event-stream "),
|
content_type = "text/event-stream"),
|
||||||
(status = 422, description = "Input validation error", body = [ErrorResponse],
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
example = json!({"error": "Input validation error"}),
|
example = json ! ({"error": "Input validation error"}),
|
||||||
content_type="text/event-stream "),
|
content_type = "text/event-stream"),
|
||||||
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
example = json!({"error": "Incomplete generation"}),
|
example = json ! ({"error": "Incomplete generation"}),
|
||||||
content_type="text/event-stream "),
|
content_type = "text/event-stream"),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[instrument(
|
#[instrument(
|
||||||
@ -186,16 +288,35 @@ async fn generate(
|
|||||||
async fn generate_stream(
|
async fn generate_stream(
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
) -> (
|
||||||
|
HeaderMap,
|
||||||
|
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||||
|
) {
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
let compute_characters = req.0.inputs.chars().count();
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
|
||||||
|
headers.insert(
|
||||||
|
"x-compute-characters",
|
||||||
|
compute_characters.to_string().parse().unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
let stream = async_stream::stream! {
|
let stream = async_stream::stream! {
|
||||||
// Inference
|
// Inference
|
||||||
let mut end_reached = false;
|
let mut end_reached = false;
|
||||||
let mut error = false;
|
let mut error = false;
|
||||||
|
|
||||||
|
let mut add_prompt = None;
|
||||||
|
if req.0.parameters.return_full_text.unwrap_or(false) {
|
||||||
|
add_prompt = Some(req.0.inputs.clone());
|
||||||
|
}
|
||||||
let details = req.0.parameters.details;
|
let details = req.0.parameters.details;
|
||||||
|
|
||||||
|
let best_of = req.0.parameters.best_of.unwrap_or(1);
|
||||||
|
if best_of == 1 {
|
||||||
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
|
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||||
Ok(mut response_stream) => {
|
Ok(mut response_stream) => {
|
||||||
// Server-Sent Event stream
|
// Server-Sent Event stream
|
||||||
@ -241,30 +362,47 @@ async fn generate_stream(
|
|||||||
let time_per_token = inference_time / generated_text.generated_tokens;
|
let time_per_token = inference_time / generated_text.generated_tokens;
|
||||||
|
|
||||||
// Tracing metadata
|
// Tracing metadata
|
||||||
span.record("total_time", format!("{:?}", total_time));
|
span.record("total_time", format!("{total_time:?}"));
|
||||||
span.record("validation_time", format!("{:?}", validation_time));
|
span.record("validation_time", format!("{validation_time:?}"));
|
||||||
span.record("queue_time", format!("{:?}", queue_time));
|
span.record("queue_time", format!("{queue_time:?}"));
|
||||||
span.record("inference_time", format!("{:?}", inference_time));
|
span.record("inference_time", format!("{inference_time:?}"));
|
||||||
span.record("time_per_token", format!("{:?}", time_per_token));
|
span.record("time_per_token", format!("{time_per_token:?}"));
|
||||||
span.record("seed", format!("{:?}", generated_text.seed));
|
span.record("seed", format!("{:?}", generated_text.seed));
|
||||||
tracing::info!(parent: &span, "Output: {}", generated_text.text);
|
tracing::info!(parent: &span, "Output: {}", generated_text.text);
|
||||||
|
|
||||||
|
// Metrics
|
||||||
|
metrics::increment_counter!("tgi_request_success");
|
||||||
|
metrics::histogram!("tgi_request_duration", total_time);
|
||||||
|
metrics::histogram!("tgi_request_validation_duration", validation_time);
|
||||||
|
metrics::histogram!("tgi_request_queue_duration", queue_time);
|
||||||
|
metrics::histogram!("tgi_request_inference_duration", inference_time);
|
||||||
|
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
|
||||||
|
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
|
||||||
|
|
||||||
// StreamResponse
|
// StreamResponse
|
||||||
end_reached = true;
|
end_reached = true;
|
||||||
|
|
||||||
|
let mut output_text = generated_text.text;
|
||||||
|
if let Some(prompt) = add_prompt {
|
||||||
|
output_text = prompt + &output_text;
|
||||||
|
}
|
||||||
|
|
||||||
let stream_token = StreamResponse {
|
let stream_token = StreamResponse {
|
||||||
token,
|
token,
|
||||||
generated_text: Some(generated_text.text),
|
generated_text: Some(output_text),
|
||||||
details
|
details
|
||||||
};
|
};
|
||||||
|
|
||||||
yield Ok(Event::default().json_data(stream_token).unwrap())
|
yield Ok(Event::default().json_data(stream_token).unwrap());
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// yield error
|
// yield error
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error = true;
|
error = true;
|
||||||
yield Ok(Event::from(err))
|
yield Ok(Event::from(err));
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -272,32 +410,55 @@ async fn generate_stream(
|
|||||||
// yield error
|
// yield error
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error = true;
|
error = true;
|
||||||
yield Ok(Event::from(err))
|
yield Ok(Event::from(err));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Check if generation reached the end
|
// Check if generation reached the end
|
||||||
// Skip if we already sent an error
|
// Skip if we already sent an error
|
||||||
if !end_reached && !error {
|
if !end_reached && !error {
|
||||||
let err = InferError::IncompleteGeneration;
|
let err = InferError::IncompleteGeneration;
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
yield Ok(Event::from(err))
|
yield Ok(Event::from(err));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let err = InferError::from(ValidationError::BestOfStream);
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
|
tracing::error!("{err}");
|
||||||
|
yield Ok(Event::from(err));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
(headers, Sse::new(stream).keep_alive(KeepAlive::default()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prometheus metrics scrape endpoint
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/metrics",
|
||||||
|
responses((status = 200, description = "Prometheus Metrics", body = String))
|
||||||
|
)]
|
||||||
|
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
||||||
|
prom_handle.render()
|
||||||
}
|
}
|
||||||
|
|
||||||
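The new `/metrics` route simply renders the `PrometheusHandle` installed at startup. A hedged sketch wiring the same pieces with `metrics-exporter-prometheus` and axum, assuming the crate versions pinned by this repo (the builder API differs in newer releases):

    use axum::{routing::get, Extension, Router};
    use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};

    async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
        // Render every counter/gauge/histogram in the Prometheus text format.
        prom_handle.render()
    }

    fn router() -> Router {
        // install_recorder() registers the recorder globally and returns a handle
        // that can be cloned into the axum extension layer.
        let prom_handle = PrometheusBuilder::new()
            .install_recorder()
            .expect("failed to install metrics recorder");
        Router::new()
            .route("/metrics", get(metrics))
            .layer(Extension(prom_handle))
    }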
/// Serving method
|
/// Serving method
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
|
compat_return_full_text: bool,
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
|
max_best_of: usize,
|
||||||
|
max_stop_sequences: usize,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
max_batch_size: usize,
|
max_batch_size: usize,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
client: ShardedClient,
|
client: ShardedClient,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
|
allow_origin: Option<AllowOrigin>,
|
||||||
) {
|
) {
|
||||||
// OpenAPI documentation
|
// OpenAPI documentation
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
@ -305,13 +466,16 @@ pub async fn run(
|
|||||||
paths(
|
paths(
|
||||||
generate,
|
generate,
|
||||||
generate_stream,
|
generate_stream,
|
||||||
|
metrics,
|
||||||
),
|
),
|
||||||
components(
|
components(
|
||||||
schemas(
|
schemas(
|
||||||
GenerateRequest,
|
GenerateRequest,
|
||||||
GenerateParameters,
|
GenerateParameters,
|
||||||
|
PrefillToken,
|
||||||
Token,
|
Token,
|
||||||
GenerateResponse,
|
GenerateResponse,
|
||||||
|
BestOfSequence,
|
||||||
Details,
|
Details,
|
||||||
FinishReason,
|
FinishReason,
|
||||||
StreamResponse,
|
StreamResponse,
|
||||||
@ -333,7 +497,14 @@ pub async fn run(
|
|||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
||||||
// Create state
|
// Create state
|
||||||
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
let validation = Validation::new(
|
||||||
|
validation_workers,
|
||||||
|
tokenizer,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
);
|
||||||
let infer = Infer::new(
|
let infer = Infer::new(
|
||||||
client,
|
client,
|
||||||
validation,
|
validation,
|
||||||
@ -342,16 +513,33 @@ pub async fn run(
|
|||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Prometheus handler
|
||||||
|
let builder = PrometheusBuilder::new();
|
||||||
|
let prom_handle = builder
|
||||||
|
.install_recorder()
|
||||||
|
.expect("failed to install metrics recorder");
|
||||||
|
|
||||||
|
// CORS layer
|
||||||
|
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
|
||||||
|
let cors_layer = CorsLayer::new()
|
||||||
|
.allow_methods([Method::GET, Method::POST])
|
||||||
|
.allow_headers([http::header::CONTENT_TYPE])
|
||||||
|
.allow_origin(allow_origin);
|
||||||
|
|
||||||
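The CORS policy falls back to `AllowOrigin::any()` when no `--cors-allow-origin` values were supplied, and otherwise uses the list parsed in `main.rs`. A hedged sketch of building the same tower-http layer end to end (assuming tower-http's `cors` feature; error handling on the origin parse is left as an `expect`, unlike the `unwrap` in the diff):

    use axum::http::{header, HeaderValue, Method};
    use tower_http::cors::{AllowOrigin, CorsLayer};

    // Illustrative helper combining the main.rs parsing and the server.rs layer.
    fn cors_layer(origins: Option<Vec<String>>) -> CorsLayer {
        let allow_origin = origins
            .map(|origins| {
                AllowOrigin::list(
                    origins
                        .iter()
                        .map(|origin| origin.parse::<HeaderValue>().expect("invalid origin")),
                )
            })
            .unwrap_or(AllowOrigin::any());
        CorsLayer::new()
            .allow_methods([Method::GET, Method::POST])
            .allow_headers([header::CONTENT_TYPE])
            .allow_origin(allow_origin)
    }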
// Create router
|
// Create router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||||
.route("/", post(generate))
|
.route("/", post(compat_generate))
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.route("/generate_stream", post(generate_stream))
|
.route("/generate_stream", post(generate_stream))
|
||||||
.route("/", get(health))
|
.route("/", get(health))
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
|
.route("/metrics", get(metrics))
|
||||||
|
.layer(Extension(compat_return_full_text))
|
||||||
.layer(Extension(infer))
|
.layer(Extension(infer))
|
||||||
.layer(opentelemetry_tracing_layer());
|
.layer(Extension(prom_handle))
|
||||||
|
.layer(opentelemetry_tracing_layer())
|
||||||
|
.layer(cors_layer);
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
axum::Server::bind(&addr)
|
axum::Server::bind(&addr)
|
||||||
@ -415,6 +603,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
|||||||
status_code,
|
status_code,
|
||||||
Json(ErrorResponse {
|
Json(ErrorResponse {
|
||||||
error: err.to_string(),
|
error: err.to_string(),
|
||||||
|
error_type: err.error_type().to_string(),
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -425,6 +614,7 @@ impl From<InferError> for Event {
|
|||||||
Event::default()
|
Event::default()
|
||||||
.json_data(ErrorResponse {
|
.json_data(ErrorResponse {
|
||||||
error: err.to_string(),
|
error: err.to_string(),
|
||||||
|
error_type: err.error_type().to_string(),
|
||||||
})
|
})
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
use crate::{GenerateParameters, GenerateRequest};
|
use crate::{GenerateParameters, GenerateRequest};
|
||||||
use rand::rngs::ThreadRng;
|
use rand::rngs::ThreadRng;
|
||||||
@ -5,33 +6,44 @@ use rand::Rng;
|
|||||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
use tokenizers::TruncationDirection;
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
use tracing::{instrument, Span};
|
use tracing::{instrument, Span};
|
||||||
|
|
||||||
const MAX_MAX_NEW_TOKENS: u32 = 512;
|
|
||||||
const MAX_STOP_SEQUENCES: usize = 4;
|
|
||||||
|
|
||||||
/// Validation
|
/// Validation
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Validation {
|
pub struct Validation {
|
||||||
|
/// maximum value for the best_of parameter
|
||||||
|
#[allow(dead_code)]
|
||||||
|
max_best_of: usize,
|
||||||
/// Channel to communicate with the background validation task
|
/// Channel to communicate with the background validation task
|
||||||
sender: mpsc::Sender<ValidationRequest>,
|
sender: mpsc::UnboundedSender<ValidationRequest>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Validation {
|
impl Validation {
|
||||||
pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
|
pub(crate) fn new(
|
||||||
|
workers: usize,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
max_best_of: usize,
|
||||||
|
max_stop_sequences: usize,
|
||||||
|
max_input_length: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
|
) -> Self {
|
||||||
// Create channel
|
// Create channel
|
||||||
let (validation_sender, validation_receiver) = mpsc::channel(128);
|
let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
// Launch background validation task
|
// Launch background validation task
|
||||||
tokio::spawn(validation_task(
|
tokio::spawn(validation_task(
|
||||||
workers,
|
workers,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
max_stop_sequences,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
validation_receiver,
|
validation_receiver,
|
||||||
));
|
));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
|
max_best_of,
|
||||||
sender: validation_sender,
|
sender: validation_sender,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -48,12 +60,25 @@ impl Validation {
|
|||||||
// Unwrap is safe here
|
// Unwrap is safe here
|
||||||
self.sender
|
self.sender
|
||||||
.send((request, sender, Span::current()))
|
.send((request, sender, Span::current()))
|
||||||
.await
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// Await on response channel
|
// Await on response channel
|
||||||
// Unwrap is safe here
|
// Unwrap is safe here
|
||||||
receiver.await.unwrap()
|
receiver.await.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validate the best_of parameter
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
|
||||||
|
if self.max_best_of == 1 && best_of != 1 {
|
||||||
|
return Err(ValidationError::BestOfDisabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
if best_of > self.max_best_of {
|
||||||
|
return Err(ValidationError::BestOf(self.max_best_of, best_of));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(best_of)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validation task
|
/// Validation task
|
||||||
@ -61,8 +86,10 @@ impl Validation {
|
|||||||
async fn validation_task(
|
async fn validation_task(
|
||||||
workers: usize,
|
workers: usize,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
|
max_stop_sequences: usize,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
max_total_tokens: usize,
|
||||||
|
mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
|
||||||
) {
|
) {
|
||||||
let mut workers_senders = Vec::with_capacity(workers);
|
let mut workers_senders = Vec::with_capacity(workers);
|
||||||
|
|
||||||
@ -75,7 +102,13 @@ async fn validation_task(
|
|||||||
|
|
||||||
// Spawn worker
|
// Spawn worker
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
validation_worker(tokenizer_clone, max_input_length, worker_receiver)
|
validation_worker(
|
||||||
|
tokenizer_clone,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
worker_receiver,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,7 +128,9 @@ async fn validation_task(
|
|||||||
/// the tokenizer
|
/// the tokenizer
|
||||||
fn validation_worker(
|
fn validation_worker(
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
|
max_stop_sequences: usize,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||||
) {
|
) {
|
||||||
// Seed rng
|
// Seed rng
|
||||||
@ -106,7 +141,16 @@ fn validation_worker(
|
|||||||
parent_span.in_scope(|| {
|
parent_span.in_scope(|| {
|
||||||
response_tx
|
response_tx
|
||||||
.send(
|
.send(
|
||||||
validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| {
|
validate(
|
||||||
|
request,
|
||||||
|
&tokenizer,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
&mut rng,
|
||||||
|
)
|
||||||
|
.map_err(|err| {
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
err
|
err
|
||||||
}),
|
}),
|
||||||
@ -119,21 +163,39 @@ fn validation_worker(
|
|||||||
fn validate(
|
fn validate(
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
tokenizer: &Tokenizer,
|
tokenizer: &Tokenizer,
|
||||||
|
max_stop_sequences: usize,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
rng: &mut ThreadRng,
|
rng: &mut ThreadRng,
|
||||||
) -> Result<ValidGenerateRequest, ValidationError> {
|
) -> Result<ValidGenerateRequest, ValidationError> {
|
||||||
let GenerateParameters {
|
let GenerateParameters {
|
||||||
|
best_of,
|
||||||
temperature,
|
temperature,
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
top_k,
|
top_k,
|
||||||
top_p,
|
top_p,
|
||||||
|
typical_p,
|
||||||
do_sample,
|
do_sample,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
stop: stop_sequences,
|
stop: stop_sequences,
|
||||||
|
truncate,
|
||||||
seed,
|
seed,
|
||||||
|
watermark,
|
||||||
..
|
..
|
||||||
} = request.parameters;
|
} = request.parameters;
|
||||||
|
|
||||||
|
// sampling must be true when best_of > 1
|
||||||
|
let best_of = best_of.unwrap_or(1);
|
||||||
|
let sampling = do_sample
|
||||||
|
|| temperature.is_some()
|
||||||
|
|| top_k.is_some()
|
||||||
|
|| top_p.is_some()
|
||||||
|
|| typical_p.is_some();
|
||||||
|
|
||||||
|
if best_of > 1 && !sampling {
|
||||||
|
return Err(BestOfSampling);
|
||||||
|
}
|
||||||
|
|
||||||
let temperature = temperature.unwrap_or(1.0);
|
let temperature = temperature.unwrap_or(1.0);
|
||||||
if temperature <= 0.0 {
|
if temperature <= 0.0 {
|
||||||
return Err(ValidationError::Temperature);
|
return Err(ValidationError::Temperature);
|
||||||
@@ -144,30 +206,42 @@ fn validate(
         return Err(ValidationError::RepetitionPenalty);
     }
 
-    let top_p = top_p.unwrap_or(1.0);
-    if top_p <= 0.0 || top_p > 1.0 {
-        return Err(ValidationError::TopP);
-    }
-
-    // Different because the proto default value is 0 while it is not a valid value
-    // for the user
-    let top_k: u32 = match top_k {
-        None => Ok(0),
-        Some(top_k) => {
-            if top_k <= 0 {
-                return Err(ValidationError::TopK);
-            }
-            Ok(top_k as u32)
-        }
-    }?;
+    // Different because the proto default value is not a valid value
+    // for the user
+    let top_p = top_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TopP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
+
+    let typical_p = typical_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TypicalP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
+
+    let top_k: u32 = top_k
+        .map(|value| {
+            if value <= 0 {
+                return Err(ValidationError::TopK);
+            }
+            Ok(value as u32)
+        })
+        .unwrap_or(Ok(0))?;
 
-    if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
-        return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
+    if max_new_tokens == 0 {
+        return Err(ValidationError::MaxNewTokens);
     }
 
-    if stop_sequences.len() > MAX_STOP_SEQUENCES {
+    if stop_sequences.len() > max_stop_sequences {
         return Err(ValidationError::StopSequence(
-            MAX_STOP_SEQUENCES,
+            max_stop_sequences,
             stop_sequences.len(),
         ));
     }
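Note: top_p, typical_p and top_k above share one idiom: the protobuf default (0) is not a legal user value, so the field stays an Option and is validated only when the client actually set it, falling back to a server-side sentinel otherwise. A generic sketch of that idiom under those assumptions (helper name is illustrative, not part of the router):

    fn validate_or_default<T, E>(
        value: Option<T>,
        default: T,
        check: impl Fn(&T) -> Result<(), E>,
    ) -> Result<T, E> {
        // Run the check only when the field was provided;
        // otherwise return the default untouched.
        value
            .map(|v| check(&v).map(|_| v))
            .unwrap_or(Ok(default))
    }

    // e.g. top_p: must lie in (0, 1) when provided, defaults to 1.0:
    // let top_p = validate_or_default(top_p, 1.0, |v| {
    //     if *v <= 0.0 || *v >= 1.0 { Err(ValidationError::TopP) } else { Ok(()) }
    // })?;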
@@ -175,41 +249,82 @@ fn validate(
     // If seed is None, assign a random one
     let seed = match seed {
         None => rng.gen(),
-        Some(seed) => seed,
+        Some(seed) => {
+            if best_of > 1 {
+                return Err(BestOfSeed);
+            }
+            seed
+        }
     };
 
+    // Check if inputs is empty
+    if request.inputs.is_empty() {
+        return Err(EmptyInput);
+    }
+
+    // Check if truncate is strictly positive and less than max_input_length
+    let truncate = truncate
+        .map(|value| {
+            if value == 0 || value > max_input_length {
+                return Err(ValidationError::Truncate(max_input_length, value));
+            }
+            Ok(Some(value))
+        })
+        .unwrap_or(Ok(None))?;
+
     // Get the number of tokens in the input
-    match tokenizer.encode(request.inputs.clone(), true) {
-        Ok(encoding) => {
-            let input_length = encoding.len();
+    let mut encoding = tokenizer
+        .encode(request.inputs.clone(), true)
+        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+
+    let (inputs, input_length) = if let Some(truncate) = truncate {
+        // truncate encoding and decode new inputs
+        encoding.truncate(truncate, 0, TruncationDirection::Left);
+        let inputs = tokenizer
+            .decode(Vec::from(encoding.get_ids()), false)
+            .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+        (inputs, encoding.len())
+    } else {
+        (request.inputs, encoding.len())
+    };
 
-            if input_length > max_input_length {
-                Err(ValidationError::InputLength(input_length, max_input_length))
-            } else {
+    if input_length > max_input_length {
+        return Err(ValidationError::InputLength(max_input_length, input_length));
+    }
+
+    let total_tokens = input_length + max_new_tokens as usize;
+    if total_tokens > max_total_tokens {
+        return Err(ValidationError::MaxTotalTokens(
+            max_total_tokens,
+            input_length,
+            max_new_tokens,
+        ));
+    }
+
     // Return ValidGenerateRequest
     let parameters = NextTokenChooserParameters {
         temperature,
         repetition_penalty,
         top_k,
         top_p,
+        typical_p,
         do_sample,
         seed,
+        watermark,
     };
     let stopping_parameters = StoppingCriteriaParameters {
         max_new_tokens,
         stop_sequences,
     };
 
+    metrics::histogram!("tgi_request_input_length", input_length as f64);
+    metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
+
     Ok(ValidGenerateRequest {
-        inputs: request.inputs,
-        input_length: input_length as u32,
+        inputs,
         parameters,
         stopping_parameters,
     })
-        }
-        Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
-    }
 }
 
 type ValidationRequest = (
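Note: the rewritten tail of validate() above enforces a per-request token budget: the (possibly truncated) prompt length plus max_new_tokens must fit within max_total_tokens, which is what the backend will actually have to hold. A bare-bones sketch of that arithmetic, standalone and with illustrative error text:

    fn check_budget(
        input_length: usize,
        max_new_tokens: u32,
        max_total_tokens: usize,
    ) -> Result<usize, String> {
        let total_tokens = input_length + max_new_tokens as usize;
        if total_tokens > max_total_tokens {
            return Err(format!(
                "`inputs` tokens + `max_new_tokens` must be <= {max_total_tokens}, got {total_tokens}"
            ));
        }
        Ok(total_tokens)
    }

    // check_budget(900, 200, 1000) fails; check_budget(700, 200, 1000) returns Ok(900).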
@@ -221,26 +336,43 @@ type ValidationRequest = (
 #[derive(Debug)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
-    pub input_length: u32,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }
 
 #[derive(Error, Debug)]
 pub enum ValidationError {
-    #[error("temperature must be strictly positive")]
+    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
+    BestOf(usize, usize),
+    #[error("`best_of` != 1 is not allowed for this endpoint")]
+    BestOfDisabled,
+    #[error("you must use sampling when `best_of` is > 1")]
+    BestOfSampling,
+    #[error("`seed` must not be set when `best_of` > 1")]
+    BestOfSeed,
+    #[error("`best_of` != 1 is not supported when streaming tokens")]
+    BestOfStream,
+    #[error("`temperature` must be strictly positive")]
     Temperature,
-    #[error("repetition_penalty must be strictly positive")]
+    #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
-    #[error("top_p must be > 0.0 and <= 1.0")]
+    #[error("`top_p` must be > 0.0 and < 1.0")]
     TopP,
-    #[error("top_k must be strictly positive")]
+    #[error("`top_k` must be strictly positive")]
     TopK,
-    #[error("max_new_tokens must be strictly positive and <= {0}")]
-    MaxNewTokens(u32),
-    #[error("inputs must have less than {1} tokens. Given: {0}")]
+    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
+    Truncate(usize, usize),
+    #[error("`typical_p` must be > 0.0 and < 1.0")]
+    TypicalP,
+    #[error("`max_new_tokens` must be strictly positive")]
+    MaxNewTokens,
+    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
+    MaxTotalTokens(usize, usize, u32),
+    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
     InputLength(usize, usize),
-    #[error("stop supports up to {0} stop sequences. Given: {1}")]
+    #[error("`inputs` cannot be empty")]
+    EmptyInput,
+    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
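Note: the extended ValidationError above relies on thiserror's #[error] attribute (already derived via #[derive(Error, Debug)]) to turn each variant into its user-facing message, with {0}/{1} interpolating the variant's fields. A tiny sketch of how such a variant renders, using a toy enum rather than the full router type:

    use thiserror::Error;

    #[derive(Error, Debug)]
    enum DemoError {
        #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
        StopSequence(usize, usize),
    }

    fn main() {
        // Display comes from the #[error] template and prints:
        // "`stop` supports up to 4 stop sequences. Given: 6"
        println!("{}", DemoError::StopSequence(4, 6));
    }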
4
server/.gitignore
vendored
@@ -1,7 +1,7 @@
 # Byte-compiled / optimized / DLL files
 __pycache__/
-text_generation/__pycache__/
-text_generation/pb/__pycache__/
+text_generation_server/__pycache__/
+text_generation_server/pb/__pycache__/
 *.py[cod]
 *$py.class
@@ -1,20 +1,22 @@
+transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
+
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 --no-cache-dir
-	mkdir text_generation/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
-	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch text_generation/pb/__init__.py
+	mkdir text_generation_server/pb || true
+	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
+	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch text_generation_server/pb/__init__.py
 
 install-transformers:
 	# Install specific version of transformers with custom cuda kernels
 	pip uninstall transformers -y || true
 	rm -rf transformers || true
-	rm -rf transformers-text_generation_inference || true
-	curl -L -O https://github.com/OlivierDehaene/transformers/archive/refs/heads/text_generation_inference.zip
-	unzip text_generation_inference.zip
-	rm text_generation_inference.zip
-	mv transformers-text_generation_inference transformers
+	rm -rf transformers-$(transformers_commit) || true
+	curl -L -O https://github.com/OlivierDehaene/transformers/archive/$(transformers_commit).zip
+	unzip $(transformers_commit).zip
+	rm $(transformers_commit).zip
+	mv transformers-$(transformers_commit) transformers
 	cd transformers && python setup.py install
 
 install-torch:
@@ -26,4 +28,4 @@ install: gen-server install-torch install-transformers
 	pip install -e . --no-cache-dir
 
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
417
server/poetry.lock
generated
@@ -145,30 +145,30 @@ testing = ["protobuf (>=3.6.0)"]
 
 [[package]]
 name = "grpcio"
-version = "1.51.1"
+version = "1.51.3"
 description = "HTTP/2-based RPC framework"
 category = "main"
 optional = false
 python-versions = ">=3.7"
 
 [package.extras]
-protobuf = ["grpcio-tools (>=1.51.1)"]
+protobuf = ["grpcio-tools (>=1.51.3)"]
 
 [[package]]
 name = "grpcio-reflection"
-version = "1.51.1"
+version = "1.51.3"
 description = "Standard Protobuf Reflection Service for gRPC"
 category = "main"
 optional = false
 python-versions = ">=3.6"
 
 [package.dependencies]
-grpcio = ">=1.51.1"
+grpcio = ">=1.51.3"
 protobuf = ">=4.21.6"
 
 [[package]]
 name = "grpcio-status"
-version = "1.51.1"
+version = "1.51.3"
 description = "Status proto mapping for gRPC"
 category = "main"
 optional = false
@@ -176,22 +176,30 @@ python-versions = ">=3.6"
 
 [package.dependencies]
 googleapis-common-protos = ">=1.5.5"
-grpcio = ">=1.51.1"
+grpcio = ">=1.51.3"
 protobuf = ">=4.21.6"
 
 [[package]]
 name = "grpcio-tools"
-version = "1.51.1"
+version = "1.51.3"
 description = "Protobuf code generator for gRPC"
 category = "dev"
 optional = false
 python-versions = ">=3.7"
 
 [package.dependencies]
-grpcio = ">=1.51.1"
+grpcio = ">=1.51.3"
 protobuf = ">=4.21.6,<5.0dev"
 setuptools = "*"
 
+[[package]]
+name = "hf-transfer"
+version = "0.1.2"
+description = ""
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
 [[package]]
 name = "idna"
 version = "3.4"
@@ -428,7 +436,7 @@ testing = ["pytest", "pytest-benchmark"]
 
 [[package]]
 name = "protobuf"
-version = "4.21.12"
+version = "4.22.0"
 description = ""
 category = "main"
 optional = false
@@ -511,7 +519,7 @@ torch = ["torch"]
 
 [[package]]
 name = "setuptools"
-version = "67.2.0"
+version = "67.4.0"
 description = "Easily download, build, install, upgrade, and uninstall Python packages"
 category = "main"
 optional = false
@@ -567,7 +575,7 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.
 
 [[package]]
 name = "typing-extensions"
-version = "4.4.0"
+version = "4.5.0"
 description = "Backported and Experimental Type Hints for Python 3.7+"
 category = "main"
 optional = false
@@ -610,7 +618,7 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
 
 [[package]]
 name = "wrapt"
-version = "1.14.1"
+version = "1.15.0"
 description = "Module for decorators, wrappers and monkey patching."
 category = "main"
 optional = false
@@ -622,7 +630,7 @@ bnb = ["bitsandbytes"]
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.9"
-content-hash = "f3cab6881b52045770a90ec9be7415a0ee499d9e980892d544f68073700cf321"
+content-hash = "521dc9f3c283dc56f7d2e2f96759919ff27ab49ffd3ae7cd26317b209e7fa98d"
 
 [metadata.files]
 accelerate = [
@ -760,106 +768,127 @@ grpc-interceptor = [
|
|||||||
{file = "grpc_interceptor-0.15.0-py3-none-any.whl", hash = "sha256:63e390162e64df96c39c40508eb697def76a7cafac32a7eaf9272093eec1109e"},
|
{file = "grpc_interceptor-0.15.0-py3-none-any.whl", hash = "sha256:63e390162e64df96c39c40508eb697def76a7cafac32a7eaf9272093eec1109e"},
|
||||||
]
|
]
|
||||||
grpcio = [
|
grpcio = [
|
||||||
{file = "grpcio-1.51.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:cc2bece1737b44d878cc1510ea04469a8073dbbcdd762175168937ae4742dfb3"},
|
{file = "grpcio-1.51.3-cp310-cp310-linux_armv7l.whl", hash = "sha256:f601aaeae18dab81930fb8d4f916b0da21e89bb4b5f7367ef793f46b4a76b7b0"},
|
||||||
{file = "grpcio-1.51.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:e223a9793522680beae44671b9ed8f6d25bbe5ddf8887e66aebad5e0686049ef"},
|
{file = "grpcio-1.51.3-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:eef0450a4b5ed11feab639bf3eb1b6e23d0efa9b911bf7b06fb60e14f5f8a585"},
|
||||||
{file = "grpcio-1.51.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:24ac1154c4b2ab4a0c5326a76161547e70664cd2c39ba75f00fc8a2170964ea2"},
|
{file = "grpcio-1.51.3-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:82b0ad8ac825d4bb31bff9f638557c045f4a6d824d84b21e893968286f88246b"},
|
||||||
{file = "grpcio-1.51.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4ef09f8997c4be5f3504cefa6b5c6cc3cf648274ce3cede84d4342a35d76db6"},
|
{file = "grpcio-1.51.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3667c06e37d6cd461afdd51cefe6537702f3d1dc5ff4cac07e88d8b4795dc16f"},
|
||||||
{file = "grpcio-1.51.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8a0b77e992c64880e6efbe0086fe54dfc0bbd56f72a92d9e48264dcd2a3db98"},
|
{file = "grpcio-1.51.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3709048fe0aa23dda09b3e69849a12055790171dab9e399a72ea8f9dfbf9ac80"},
|
||||||
{file = "grpcio-1.51.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:eacad297ea60c72dd280d3353d93fb1dcca952ec11de6bb3c49d12a572ba31dd"},
|
{file = "grpcio-1.51.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:200d69857f9910f7458b39b9bcf83ee4a180591b40146ba9e49314e3a7419313"},
|
||||||
{file = "grpcio-1.51.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:16c71740640ba3a882f50b01bf58154681d44b51f09a5728180a8fdc66c67bd5"},
|
{file = "grpcio-1.51.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:cd9a5e68e79c5f031500e67793048a90209711e0854a9ddee8a3ce51728de4e5"},
|
||||||
{file = "grpcio-1.51.1-cp310-cp310-win32.whl", hash = "sha256:29cb97d41a4ead83b7bcad23bdb25bdd170b1e2cba16db6d3acbb090bc2de43c"},
|
{file = "grpcio-1.51.3-cp310-cp310-win32.whl", hash = "sha256:6604f614016127ae10969176bbf12eb0e03d2fb3d643f050b3b69e160d144fb4"},
|
||||||
{file = "grpcio-1.51.1-cp310-cp310-win_amd64.whl", hash = "sha256:9ff42c5620b4e4530609e11afefa4a62ca91fa0abb045a8957e509ef84e54d30"},
|
{file = "grpcio-1.51.3-cp310-cp310-win_amd64.whl", hash = "sha256:e95c7ccd4c5807adef1602005513bf7c7d14e5a41daebcf9d8d30d8bf51b8f81"},
|
||||||
{file = "grpcio-1.51.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:bc59f7ba87972ab236f8669d8ca7400f02a0eadf273ca00e02af64d588046f02"},
|
{file = "grpcio-1.51.3-cp311-cp311-linux_armv7l.whl", hash = "sha256:5e77ee138100f0bb55cbd147840f87ee6241dbd25f09ea7cd8afe7efff323449"},
|
||||||
{file = "grpcio-1.51.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:3c2b3842dcf870912da31a503454a33a697392f60c5e2697c91d133130c2c85d"},
|
{file = "grpcio-1.51.3-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:68a7514b754e38e8de9075f7bb4dee919919515ec68628c43a894027e40ddec4"},
|
||||||
{file = "grpcio-1.51.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22b011674090594f1f3245960ced7386f6af35485a38901f8afee8ad01541dbd"},
|
{file = "grpcio-1.51.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c1b9f8afa62ff265d86a4747a2990ec5a96e4efce5d5888f245a682d66eca47"},
|
||||||
{file = "grpcio-1.51.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49d680356a975d9c66a678eb2dde192d5dc427a7994fb977363634e781614f7c"},
|
{file = "grpcio-1.51.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8de30f0b417744288cec65ec8cf84b8a57995cf7f1e84ccad2704d93f05d0aae"},
|
||||||
{file = "grpcio-1.51.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:094e64236253590d9d4075665c77b329d707b6fca864dd62b144255e199b4f87"},
|
{file = "grpcio-1.51.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b69c7adc7ed60da1cb1b502853db61f453fc745f940cbcc25eb97c99965d8f41"},
|
||||||
{file = "grpcio-1.51.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:257478300735ce3c98d65a930bbda3db172bd4e00968ba743e6a1154ea6edf10"},
|
{file = "grpcio-1.51.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d81528ffe0e973dc840ec73a4132fd18b8203ad129d7410155d951a0a7e4f5d0"},
|
||||||
{file = "grpcio-1.51.1-cp311-cp311-win32.whl", hash = "sha256:5a6ebcdef0ef12005d56d38be30f5156d1cb3373b52e96f147f4a24b0ddb3a9d"},
|
{file = "grpcio-1.51.3-cp311-cp311-win32.whl", hash = "sha256:040eb421613b57c696063abde405916dd830203c184c9000fc8c3b3b3c950325"},
|
||||||
{file = "grpcio-1.51.1-cp311-cp311-win_amd64.whl", hash = "sha256:3f9b0023c2c92bebd1be72cdfca23004ea748be1813a66d684d49d67d836adde"},
|
{file = "grpcio-1.51.3-cp311-cp311-win_amd64.whl", hash = "sha256:2a8e17286c4240137d933b8ca506465472248b4ce0fe46f3404459e708b65b68"},
|
||||||
{file = "grpcio-1.51.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:cd3baccea2bc5c38aeb14e5b00167bd4e2373a373a5e4d8d850bd193edad150c"},
|
{file = "grpcio-1.51.3-cp37-cp37m-linux_armv7l.whl", hash = "sha256:d5cd1389669a847555df54177b911d9ff6f17345b2a6f19388707b7a9f724c88"},
|
||||||
{file = "grpcio-1.51.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:17ec9b13cec4a286b9e606b48191e560ca2f3bbdf3986f91e480a95d1582e1a7"},
|
{file = "grpcio-1.51.3-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:be1bf35ce82cdbcac14e39d5102d8de4079a1c1a6a06b68e41fcd9ef64f9dd28"},
|
||||||
{file = "grpcio-1.51.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:fbdbe9a849854fe484c00823f45b7baab159bdd4a46075302281998cb8719df5"},
|
{file = "grpcio-1.51.3-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:5eed34994c095e2bf7194ffac7381c6068b057ef1e69f8f08db77771350a7566"},
|
||||||
{file = "grpcio-1.51.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:31bb6bc7ff145e2771c9baf612f4b9ebbc9605ccdc5f3ff3d5553de7fc0e0d79"},
|
{file = "grpcio-1.51.3-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f9a7d88082b2a17ae7bd3c2354d13bab0453899e0851733f6afa6918373f476"},
|
||||||
{file = "grpcio-1.51.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e473525c28251558337b5c1ad3fa969511e42304524a4e404065e165b084c9e4"},
|
{file = "grpcio-1.51.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c8abbc5f837111e7bd619612eedc223c290b0903b952ce0c7b00840ea70f14"},
|
||||||
{file = "grpcio-1.51.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:6f0b89967ee11f2b654c23b27086d88ad7bf08c0b3c2a280362f28c3698b2896"},
|
{file = "grpcio-1.51.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:165b05af77e6aecb4210ae7663e25acf234ba78a7c1c157fa5f2efeb0d6ec53c"},
|
||||||
{file = "grpcio-1.51.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:7942b32a291421460d6a07883033e392167d30724aa84987e6956cd15f1a21b9"},
|
{file = "grpcio-1.51.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:54e36c2ee304ff15f2bfbdc43d2b56c63331c52d818c364e5b5214e5bc2ad9f6"},
|
||||||
{file = "grpcio-1.51.1-cp37-cp37m-win32.whl", hash = "sha256:f96ace1540223f26fbe7c4ebbf8a98e3929a6aa0290c8033d12526847b291c0f"},
|
{file = "grpcio-1.51.3-cp37-cp37m-win32.whl", hash = "sha256:cd0daac21d9ef5e033a5100c1d3aa055bbed28bfcf070b12d8058045c4e821b1"},
|
||||||
{file = "grpcio-1.51.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f1fec3abaf274cdb85bf3878167cfde5ad4a4d97c68421afda95174de85ba813"},
|
{file = "grpcio-1.51.3-cp37-cp37m-win_amd64.whl", hash = "sha256:2fdd6333ce96435408565a9dbbd446212cd5d62e4d26f6a3c0feb1e3c35f1cc8"},
|
||||||
{file = "grpcio-1.51.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:0e1a9e1b4a23808f1132aa35f968cd8e659f60af3ffd6fb00bcf9a65e7db279f"},
|
{file = "grpcio-1.51.3-cp38-cp38-linux_armv7l.whl", hash = "sha256:54b0c29bdd9a3b1e1b61443ab152f060fc719f1c083127ab08d03fac5efd51be"},
|
||||||
{file = "grpcio-1.51.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:6df3b63538c362312bc5fa95fb965069c65c3ea91d7ce78ad9c47cab57226f54"},
|
{file = "grpcio-1.51.3-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:ffaaf7e93fcb437356b5a4b23bf36e8a3d0221399ff77fd057e4bc77776a24be"},
|
||||||
{file = "grpcio-1.51.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:172405ca6bdfedd6054c74c62085946e45ad4d9cec9f3c42b4c9a02546c4c7e9"},
|
{file = "grpcio-1.51.3-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:eafbe7501a3268d05f2e450e1ddaffb950d842a8620c13ec328b501d25d2e2c3"},
|
||||||
{file = "grpcio-1.51.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:506b9b7a4cede87d7219bfb31014d7b471cfc77157da9e820a737ec1ea4b0663"},
|
{file = "grpcio-1.51.3-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:881ecb34feabf31c6b3b9bbbddd1a5b57e69f805041e5a2c6c562a28574f71c4"},
|
||||||
{file = "grpcio-1.51.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fb93051331acbb75b49a2a0fd9239c6ba9528f6bdc1dd400ad1cb66cf864292"},
|
{file = "grpcio-1.51.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e860a3222139b41d430939bbec2ec9c3f6c740938bf7a04471a9a8caaa965a2e"},
|
||||||
{file = "grpcio-1.51.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5dca372268c6ab6372d37d6b9f9343e7e5b4bc09779f819f9470cd88b2ece3c3"},
|
{file = "grpcio-1.51.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:49ede0528e9dac7e8a9fe30b16c73b630ddd9a576bf4b675eb6b0c53ee5ca00f"},
|
||||||
{file = "grpcio-1.51.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:471d39d3370ca923a316d49c8aac66356cea708a11e647e3bdc3d0b5de4f0a40"},
|
{file = "grpcio-1.51.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6972b009638b40a448d10e1bc18e2223143b8a7aa20d7def0d78dd4af4126d12"},
|
||||||
{file = "grpcio-1.51.1-cp38-cp38-win32.whl", hash = "sha256:75e29a90dc319f0ad4d87ba6d20083615a00d8276b51512e04ad7452b5c23b04"},
|
{file = "grpcio-1.51.3-cp38-cp38-win32.whl", hash = "sha256:5694448256e3cdfe5bd358f1574a3f2f51afa20cc834713c4b9788d60b7cc646"},
|
||||||
{file = "grpcio-1.51.1-cp38-cp38-win_amd64.whl", hash = "sha256:f1158bccbb919da42544a4d3af5d9296a3358539ffa01018307337365a9a0c64"},
|
{file = "grpcio-1.51.3-cp38-cp38-win_amd64.whl", hash = "sha256:3ea4341efe603b049e8c9a5f13c696ca37fcdf8a23ca35f650428ad3606381d9"},
|
||||||
{file = "grpcio-1.51.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:59dffade859f157bcc55243714d57b286da6ae16469bf1ac0614d281b5f49b67"},
|
{file = "grpcio-1.51.3-cp39-cp39-linux_armv7l.whl", hash = "sha256:6c677581ce129f5fa228b8f418cee10bd28dd449f3a544ea73c8ba590ee49d0b"},
|
||||||
{file = "grpcio-1.51.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:dad6533411d033b77f5369eafe87af8583178efd4039c41d7515d3336c53b4f1"},
|
{file = "grpcio-1.51.3-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:30e09b5e0531685e176f49679b6a3b190762cc225f4565e55a899f5e14b3aa62"},
|
||||||
{file = "grpcio-1.51.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:4c4423ea38a7825b8fed8934d6d9aeebdf646c97e3c608c3b0bcf23616f33877"},
|
{file = "grpcio-1.51.3-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:c831f31336e81243f85b6daff3e5e8a123302ce0ea1f2726ad752fd7a59f3aee"},
|
||||||
{file = "grpcio-1.51.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0dc5354e38e5adf2498312f7241b14c7ce3484eefa0082db4297189dcbe272e6"},
|
{file = "grpcio-1.51.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2cd2e4cefb724cab1ba2df4b7535a9980531b9ec51b4dbb5f137a1f3a3754ef0"},
|
||||||
{file = "grpcio-1.51.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97d67983189e2e45550eac194d6234fc38b8c3b5396c153821f2d906ed46e0ce"},
|
{file = "grpcio-1.51.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7a0d0bf44438869d307f85a54f25a896ad6b4b0ca12370f76892ad732928d87"},
|
||||||
{file = "grpcio-1.51.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:538d981818e49b6ed1e9c8d5e5adf29f71c4e334e7d459bf47e9b7abb3c30e09"},
|
{file = "grpcio-1.51.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c02abd55409bfb293371554adf6a4401197ec2133dd97727c01180889014ba4d"},
|
||||||
{file = "grpcio-1.51.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9235dcd5144a83f9ca6f431bd0eccc46b90e2c22fe27b7f7d77cabb2fb515595"},
|
{file = "grpcio-1.51.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2f8ff75e61e1227ba7a3f16b2eadbcc11d0a54096d52ab75a6b88cfbe56f55d1"},
|
||||||
{file = "grpcio-1.51.1-cp39-cp39-win32.whl", hash = "sha256:aacb54f7789ede5cbf1d007637f792d3e87f1c9841f57dd51abf89337d1b8472"},
|
{file = "grpcio-1.51.3-cp39-cp39-win32.whl", hash = "sha256:6c99a73a6260bdf844b2e5ddad02dcd530310f80e1fa72c300fa19c1c7496962"},
|
||||||
{file = "grpcio-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:2b170eaf51518275c9b6b22ccb59450537c5a8555326fd96ff7391b5dd75303c"},
|
{file = "grpcio-1.51.3-cp39-cp39-win_amd64.whl", hash = "sha256:22bdfac4f7f27acdd4da359b5e7e1973dc74bf1ed406729b07d0759fde2f064b"},
|
||||||
{file = "grpcio-1.51.1.tar.gz", hash = "sha256:e6dfc2b6567b1c261739b43d9c59d201c1b89e017afd9e684d85aa7a186c9f7a"},
|
{file = "grpcio-1.51.3.tar.gz", hash = "sha256:be7b2265b7527bb12109a7727581e274170766d5b3c9258d4e466f4872522d7a"},
|
||||||
]
|
]
|
||||||
grpcio-reflection = [
|
grpcio-reflection = [
|
||||||
{file = "grpcio-reflection-1.51.1.tar.gz", hash = "sha256:c07a93c0c36ef88fe475744289863b4787005eff4de0cc04213ecad718b01aae"},
|
{file = "grpcio-reflection-1.51.3.tar.gz", hash = "sha256:5adca16f0a6cd403efa3b5f8f8a493eea6a37dee9473b178fad0a60efa68bc67"},
|
||||||
{file = "grpcio_reflection-1.51.1-py3-none-any.whl", hash = "sha256:b70af764a83e42a44f65df1edb232e972ab69e72bc7fbbad481e66c29a9d8cb8"},
|
{file = "grpcio_reflection-1.51.3-py3-none-any.whl", hash = "sha256:52b037f831908468afc89c60e591d0a2bbce24a393d908c44a6d53091e90fc41"},
|
||||||
]
|
]
|
||||||
grpcio-status = [
|
grpcio-status = [
|
||||||
{file = "grpcio-status-1.51.1.tar.gz", hash = "sha256:ac2617a3095935ebd785e2228958f24b10a0d527a0c9eb5a0863c784f648a816"},
|
{file = "grpcio-status-1.51.3.tar.gz", hash = "sha256:71792c550356ba94e162c70818719ae6d67d960bdd03a9db5ff68faba2927f6c"},
|
||||||
{file = "grpcio_status-1.51.1-py3-none-any.whl", hash = "sha256:a52cbdc4b18f325bfc13d319ae7c7ae7a0fee07f3d9a005504d6097896d7a495"},
|
{file = "grpcio_status-1.51.3-py3-none-any.whl", hash = "sha256:d68d0956c16b6ea466f13c27075f126ef2cd8f0f97527d70056c64b0084357e3"},
|
||||||
]
|
]
|
||||||
grpcio-tools = [
|
grpcio-tools = [
|
||||||
{file = "grpcio-tools-1.51.1.tar.gz", hash = "sha256:8e62d23d3fed9d4f81738f98dd193dbd2e21aed4a8f0dd715e75b5439e649727"},
|
{file = "grpcio-tools-1.51.3.tar.gz", hash = "sha256:4fea28e3dd31871579a57058796a78093c75b74b74e9de2f2b7a7fd9a443d403"},
|
||||||
{file = "grpcio_tools-1.51.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:ecf1494cb695afead36995534f787761ee33fb9e116b23030113a37fe6057a83"},
|
{file = "grpcio_tools-1.51.3-cp310-cp310-linux_armv7l.whl", hash = "sha256:779ac1ad2258b8debaa45595bfb3814806ed8880e3ea7f194e551d76a6255969"},
|
||||||
{file = "grpcio_tools-1.51.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:16b8b915625dc6eb2ea7efdfb06f1fae44a9066c9016453a2ca120c034f33090"},
|
{file = "grpcio_tools-1.51.3-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:83bf605fe2b3591d3c8a78646f37c72c5832c4dd84b5f92405c17cb10b136be6"},
|
||||||
{file = "grpcio_tools-1.51.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:d5e033c04b416afcddd5231b3ff94a34fb5d26fba2416eb940e69b05f22cfd25"},
|
{file = "grpcio_tools-1.51.3-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:35f885c5afd8e6a77d320f5a9624b439a93f9be2b87fa7b7948c1ad7b2ba0894"},
|
||||||
{file = "grpcio_tools-1.51.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a218f64e667f3332b74080bdc5440aaf0fa6700ae07a0b54ecf085aaef2aa9f"},
|
{file = "grpcio_tools-1.51.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:253b639fb79a4d28ce494ae40e5695bf1e2cb4a05f205fc433c46b2049ab4d99"},
|
||||||
{file = "grpcio_tools-1.51.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7b186183515ad6b8584ffe4bd820b72b00f6e7d121fb1c36294edeea9092313"},
|
{file = "grpcio_tools-1.51.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c6b145587d6062e2335f0b3286501dd6853a1ea50bd466a913351b7c48e5f20"},
|
||||||
{file = "grpcio_tools-1.51.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ccd37165d7a3e93f460096a2eb62b7a9c1ebe5c424eaee42d8e92740d0c8f6bc"},
|
{file = "grpcio_tools-1.51.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:046c0b1e372d4acf552aa0c8f5e830f019d67b75f25aeb0968d15fbdd3eaabd3"},
|
||||||
{file = "grpcio_tools-1.51.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:531586c5598a99658249f3c5e92826d6d2bb117abd6ffc88527d1e1d9eaef924"},
|
{file = "grpcio_tools-1.51.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:efc90b0287908c46281eb61933acaa1b96a575d0160fc98b5c64b9dec46f60d1"},
|
||||||
{file = "grpcio_tools-1.51.1-cp310-cp310-win32.whl", hash = "sha256:392ad4cd004f7b843cf7d916d9a15b2d6585965bfef235be1c88d8f8649777e5"},
|
{file = "grpcio_tools-1.51.3-cp310-cp310-win32.whl", hash = "sha256:8e9df40db7a0edd403b539cc142d6114270e35debf723a5b4a7a93d5c30fffc0"},
|
||||||
{file = "grpcio_tools-1.51.1-cp310-cp310-win_amd64.whl", hash = "sha256:14e82c2b3ee7e300611c2c729d411b3b911e4cca5f4ec14787457a2fb72ff9d4"},
|
{file = "grpcio_tools-1.51.3-cp310-cp310-win_amd64.whl", hash = "sha256:077adaee431c2b040dd77923964577087c32e828908e8fa2e53f8e003ad408c9"},
|
||||||
{file = "grpcio_tools-1.51.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:2281180490c475d09b7aa05dabafa5e09de9902176931e7295113f636c2b5360"},
|
{file = "grpcio_tools-1.51.3-cp311-cp311-linux_armv7l.whl", hash = "sha256:b50f9b8a6482a90c1a41e731a879a130f7dea267065d0a06f47c9160ce5d01c3"},
|
||||||
{file = "grpcio_tools-1.51.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:c4649af7f5d9553975ee66b6bfae20a84be779f13e163fa835e782961895e63c"},
|
{file = "grpcio_tools-1.51.3-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:89a68adcb4238aba69f3a364ac02c9a46e55b9e3fd8af1c6f384079abfa9347c"},
|
||||||
{file = "grpcio_tools-1.51.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f06bb0753b7cecbff154b523cfb8f45dee2c31b0a4c72bed7da44c57f1cba113"},
|
{file = "grpcio_tools-1.51.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d177da43e7f6fde6715df4a3015ae13158166bc2845ac7f9cfb526eafb41b8"},
|
||||||
{file = "grpcio_tools-1.51.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a671466158ed74c07ee070fb940ed783acf59ba6e6e53cb4de8fd63819c6c7f"},
|
{file = "grpcio_tools-1.51.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:793f9edef82f600a3324f8a3d8cd8318a8d02f28fb54f8236cbb35ce0928d186"},
|
||||||
{file = "grpcio_tools-1.51.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:048793747339f327ea091d8f022c6756d89713d8080dffde5ce7380cc348ea8e"},
|
{file = "grpcio_tools-1.51.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f7583735542ced7d30baec6cc21bffeaffcec1523bf807e8f8f0047113b6d30a"},
|
||||||
{file = "grpcio_tools-1.51.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f6caf36e7752728329a28f93afec7c4ec9015fc1c6e4460bd1eb0f3737e1c55a"},
|
{file = "grpcio_tools-1.51.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f2df233a3e7db23d9b516cb5e2bfe029465f40a72978bee0584e44e7860ea73f"},
|
||||||
{file = "grpcio_tools-1.51.1-cp311-cp311-win32.whl", hash = "sha256:67b304282cad38642587ebae68617e450e1ad4fa1c0c8b19e9e30274dbb32716"},
|
{file = "grpcio_tools-1.51.3-cp311-cp311-win32.whl", hash = "sha256:7427939455735fbf2ea88c37f1585c9c8b809eec7b447642f34465eb4d26020b"},
|
||||||
{file = "grpcio_tools-1.51.1-cp311-cp311-win_amd64.whl", hash = "sha256:674b340f2f7bb2adbc3f15144bd37ce5ea83239f78b68dbbd0ea3cba00107e2b"},
|
{file = "grpcio_tools-1.51.3-cp311-cp311-win_amd64.whl", hash = "sha256:ba76d15fd149b575170fa32a1f6a9ff2b38ff9db223229a8ad6f53450a452688"},
|
||||||
{file = "grpcio_tools-1.51.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:055819992ddd30c642a7fd6f344a03747be3afa95cb910f8a2e5efaabd41cde5"},
|
{file = "grpcio_tools-1.51.3-cp37-cp37m-linux_armv7l.whl", hash = "sha256:d2212c682529263b3c9e903092d0ccbb9fc6afba820e4c2fa52c2c27720cdcae"},
|
||||||
{file = "grpcio_tools-1.51.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:4e3249a2ec435b3b972610c66c8a714c188844500d564c910f57a2771dc61978"},
|
{file = "grpcio_tools-1.51.3-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:405656b3cf9639427e6c30a795570cba4a7c06b88a3145866f7d2c05b7e048b4"},
|
||||||
{file = "grpcio_tools-1.51.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:794f26a09b70f4f101df5cf54c6c12dc1b65747ab1dee5bda02c2991389ade56"},
|
{file = "grpcio_tools-1.51.3-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:3c445a064b2ef3d3475e26e2add8ddb4ac2933741ecddf71d5b071a3ad078db4"},
|
||||||
{file = "grpcio_tools-1.51.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4957f1ffa16598aa5379505fcbaeb47d65693a46b0817f4ee61db76707092aeb"},
|
{file = "grpcio_tools-1.51.3-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7b3374f4a6579c58d16a5fab2e6b4e9bb8625a034a7f4cd6024f4d1cc12f2a0"},
|
||||||
{file = "grpcio_tools-1.51.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9906fb6bf6d9c30c23d85153f12d130f44325afe8f9ebe58aa7a6c82ecade9d8"},
|
{file = "grpcio_tools-1.51.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46e8df08b65f9379c3f103147b29542b0141ca84e77d0eee9114ca5f9b3f0d23"},
|
||||||
{file = "grpcio_tools-1.51.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87bc5f3e3698c65907d397003c64d25c3ea84e3d6aa46dac133bd98bf66835ee"},
|
{file = "grpcio_tools-1.51.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:2fade12de08923b350475ca16d0d0bd68578c30fce89147aa0f94ef5759bc5a9"},
|
||||||
{file = "grpcio_tools-1.51.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a66b3a5d18a7615f0f828b72e2d2935751459c89cc4725e56bdfb3d2cd93281f"},
|
{file = "grpcio_tools-1.51.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d4ffb6325ed489065dbdca764cf37c3a29376bc657874116c9af788d7a0d2ee4"},
|
||||||
{file = "grpcio_tools-1.51.1-cp37-cp37m-win32.whl", hash = "sha256:566809d9942e78821b279af70f3cf159a328127f9f3d5fee8d83ad8b2d27b2fe"},
|
{file = "grpcio_tools-1.51.3-cp37-cp37m-win32.whl", hash = "sha256:f8d17271fc58ed3503dd571c79917e126deca51f85f093770a9606e806aac9dc"},
|
||||||
{file = "grpcio_tools-1.51.1-cp37-cp37m-win_amd64.whl", hash = "sha256:aab24a342642329de38139cb26f8492882ca0d8551bb87f6530bcc613945a0d0"},
|
{file = "grpcio_tools-1.51.3-cp37-cp37m-win_amd64.whl", hash = "sha256:ef849687c7f2bd7f3277edc7c7cafc7042823d0fb078e3c01c861eb0c96ed181"},
|
||||||
{file = "grpcio_tools-1.51.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:6b83d7fc2597c6d392c225177d1fbbcff74900f8cc40b33236987fd1ff841330"},
|
{file = "grpcio_tools-1.51.3-cp38-cp38-linux_armv7l.whl", hash = "sha256:7fd18d8d211fbfd337fc12e5bdd57e62368f636addf901d290e68a39f1dfea38"},
|
||||||
{file = "grpcio_tools-1.51.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:79c06d2577cb4d977922bbf01234de3b20f73d1784d3cbe3179deee1bdb9a60b"},
|
{file = "grpcio_tools-1.51.3-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:233fc56f054424232e2086f444004413e33c699174ce6ee0e279c25227243fec"},
|
||||||
{file = "grpcio_tools-1.51.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:e9abc03d67793b1bf33dc766caa69a3333f9db029869ba6e8fc6cd9c251c0080"},
|
{file = "grpcio_tools-1.51.3-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:867fa1973fa8b0772077c15425f122f672a18b1c53709a8a2bff9d056db4c20e"},
|
||||||
{file = "grpcio_tools-1.51.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64d8ad369417759f5fdb8ffb7cbd6374fecc06ab51c9a226dee9bbd7d311c3b5"},
|
{file = "grpcio_tools-1.51.3-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b486a99bdf2722e68a9d59769389e2fb86878b6f293be5111f7678e364a0c359"},
|
||||||
{file = "grpcio_tools-1.51.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de51a0a71845b854f6a5967756c893c96bd03e37f39e5dce87b4f409dac36ee2"},
|
{file = "grpcio_tools-1.51.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8bbf412c357999f88d87f421fd48b4b114fc037fec7bbaed0cb7620c24a5e44"},
|
||||||
{file = "grpcio_tools-1.51.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:9dfe6c12b0e2c07f6a4a91a9912ef4e5bd007672533891a44e6f433ffbf7c3b1"},
|
{file = "grpcio_tools-1.51.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1166744c40821bb0aa605d2af2287fac367756f858a3d18f4c3d25bc0b92757b"},
|
||||||
{file = "grpcio_tools-1.51.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:27113b354f7587684eb55125733e6e5be1f489458abfe12344dabd918d8dcc54"},
|
{file = "grpcio_tools-1.51.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:781896c488e07b9463196045e6725e52d018cd7d0e1062d4ab1eee2647ca9170"},
|
||||||
{file = "grpcio_tools-1.51.1-cp38-cp38-win32.whl", hash = "sha256:98777b5031f1b3c58b688815ffa83435c103b2152c26eb144f80f4a4bb34addb"},
|
{file = "grpcio_tools-1.51.3-cp38-cp38-win32.whl", hash = "sha256:35c1ee7c766eb586f04ba41fa7711eb847767eb277a1737998374ac57768f1f0"},
|
||||||
{file = "grpcio_tools-1.51.1-cp38-cp38-win_amd64.whl", hash = "sha256:1c44b57a6770b78a1eafe355878ff1ec59a2fa07455a2cbd522c071eedae04d4"},
|
{file = "grpcio_tools-1.51.3-cp38-cp38-win_amd64.whl", hash = "sha256:584b201fb39307dcb1affcf2647656a0e6244423ef1659cc6caa3ff85c5ae5c1"},
|
||||||
{file = "grpcio_tools-1.51.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:49624394805568acd7d767dea5a00d970fca5ad8f395fe0161eeea0de5133eba"},
|
{file = "grpcio_tools-1.51.3-cp39-cp39-linux_armv7l.whl", hash = "sha256:e02231e21029f716a1d23a0b5e664fa243d147da33a3f55088a9529b860aa4ac"},
|
||||||
{file = "grpcio_tools-1.51.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:6d6626a6e4dbe843df96dc8c08dd244d2191a75324f54bfa4ebaa3e76b0b1958"},
|
{file = "grpcio_tools-1.51.3-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:fbb742e10bd548031b8d80f7c28eb70c7c3a9850f8e99c98cd496f19a05f9fee"},
|
||||||
{file = "grpcio_tools-1.51.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:b4fb8ed6d29f2d6cf03ef99ffaad635bbc132a59be77013691392fe557e67144"},
|
{file = "grpcio_tools-1.51.3-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:a836a72c657f751244cdb358c3461a89627e6d02654079d2450cfe361800428c"},
|
||||||
{file = "grpcio_tools-1.51.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8cc862a1ad30f94528d66cc6f95fb9e659005e568313e54a23550535b649573"},
|
{file = "grpcio_tools-1.51.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bb554408e0ec5ff5201013f268726d9eef8e5bd1fd4b4e09c46c0b4a9de8b64c"},
|
||||||
{file = "grpcio_tools-1.51.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e72a30be1746ea0749a8486d0ca0120c0b2757fe84fc246a5144b1ef66d7b89"},
|
{file = "grpcio_tools-1.51.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:158c5bfe7e157fd9a944bde9f7dfe3b468416666e4fade77cd17caa3edc8bd81"},
|
||||||
{file = "grpcio_tools-1.51.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:331a897306adeec3c67470431ea8d8b4972b689d32966f94506d91f4dac20952"},
|
{file = "grpcio_tools-1.51.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:715c792679224171c0584e9f235b921d76f8990deb38b0d1215d0469301d9cd9"},
|
||||||
{file = "grpcio_tools-1.51.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f336ad9be661d92fa45940e74e8ff3d78e67ebe9b4f7ea8774b2d680c17aeb6c"},
|
{file = "grpcio_tools-1.51.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ece44f42b10e0bceb49235be1e361e1ee69afee7f938c82fb656a601a4a720e3"},
|
||||||
{file = "grpcio_tools-1.51.1-cp39-cp39-win32.whl", hash = "sha256:40ef70e8c5d0310dedff9af502b520b4c7e215bce94094527fb959150a0c594a"},
|
{file = "grpcio_tools-1.51.3-cp39-cp39-win32.whl", hash = "sha256:980e632710ba05e04364c6f276e905d5d367437f1ce2265ce7b96b5c1eac5693"},
|
||||||
{file = "grpcio_tools-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b8acf4eaa0ebe37e2f69108de49efd935b7abe9c7e58ba737490b99906aa76"},
|
{file = "grpcio_tools-1.51.3-cp39-cp39-win_amd64.whl", hash = "sha256:5f4c47b14e66f80365cd5667ecc2f7fb0eb91e02c4e54362041b758feaa00511"},
|
||||||
|
]
|
||||||
|
hf-transfer = [
|
||||||
|
{file = "hf_transfer-0.1.2-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:2b9189a4a460646ee135ee771f39c0f695d3d5bf08b7ff1dcfe374227520e994"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:654fcaba4e7084caa1e97430982ea968935a72916ee0f4afc60e356f89774099"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp310-none-win_amd64.whl", hash = "sha256:eb29e7b3707b5cac02e689c89111685ebcdaa3cebba02eb7ac1b0f076357da72"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0bfca9bd84e925e978a0f157df488704c17a0b9ad240b2859262faba0c74cd40"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d00c5473b35227b2f113fd43ff13cbac9539f2e6779fa0680a887b0aac31c389"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp311-none-win_amd64.whl", hash = "sha256:1aaf5937aa433b7d09ce5bf60967ec22b7d3982957b00516a8dc2aaa66384372"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:b0aa760a55995ad59ea17e395babafdc56c4e664be0c2d2055664199dd913da1"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:889dd15e8472daf66e266eb056e31a485af3c35f95a483bb43489a0f6e44c359"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp37-none-win_amd64.whl", hash = "sha256:30df586e18ec8a8e67e3201b9038210d94cb3c03c1cbd97673b9c78ede227178"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:cc97eb97f929f96bed896cd3af9bbdf121c15ac6d63524b9fc9312fd2929099a"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:583c2c80210a60dafed9a81ba50c389878aee6c34b2dd375cd84522658f29ad8"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp38-none-win_amd64.whl", hash = "sha256:6dff58f50d1435b0346f31a32f1f9e2301986521c1d0b51e47a3c82b96d02156"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:d6db1a8f539133f7a893bb32721916fe72b4d2aa3eb7604581ba1f03b8167c90"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f284e3f775d215c9a8d3d1c6f6b1001b1e7990d73ae5fd9aea6c9bce9ea79285"},
|
||||||
|
{file = "hf_transfer-0.1.2-cp39-none-win_amd64.whl", hash = "sha256:8625beabebc582eafc4141a5ecb9f1183b728d4f63767f01fdcf1e2fbafe6d43"},
|
||||||
|
{file = "hf_transfer-0.1.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:947dd1b8b22ac10723b2887ed4b5ef929f7d4dd850b0d66c0c6954a9a85afb06"},
|
||||||
|
{file = "hf_transfer-0.1.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90a020f41dfae4629186c284888cd5adbebe402e2497a88351416ab93c7df9a8"},
|
||||||
|
{file = "hf_transfer-0.1.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5eb89698746a29805bfc60126b9a008e6ba08a82ef9bb122a6544e84f748e8a4"},
|
||||||
|
{file = "hf_transfer-0.1.2.tar.gz", hash = "sha256:6bf847f4c19c7d8d9f9bbb8a7ed52e1271bbf0c1bd920357db0c274ccc69f21d"},
|
||||||
]
|
]
|
||||||
idna = [
|
idna = [
|
||||||
{file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
|
{file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
|
||||||
@ -965,20 +994,19 @@ pluggy = [
|
|||||||
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
|
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
|
||||||
]
|
]
|
||||||
protobuf = [
|
protobuf = [
|
||||||
{file = "protobuf-4.21.12-cp310-abi3-win32.whl", hash = "sha256:b135410244ebe777db80298297a97fbb4c862c881b4403b71bac9d4107d61fd1"},
|
{file = "protobuf-4.22.0-cp310-abi3-win32.whl", hash = "sha256:b2fea9dc8e3c0f32c38124790ef16cba2ee0628fe2022a52e435e1117bfef9b1"},
|
||||||
{file = "protobuf-4.21.12-cp310-abi3-win_amd64.whl", hash = "sha256:89f9149e4a0169cddfc44c74f230d7743002e3aa0b9472d8c28f0388102fc4c2"},
|
{file = "protobuf-4.22.0-cp310-abi3-win_amd64.whl", hash = "sha256:a33a273d21852f911b8bda47f39f4383fe7c061eb1814db2c76c9875c89c2491"},
|
||||||
{file = "protobuf-4.21.12-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:299ea899484ee6f44604deb71f424234f654606b983cb496ea2a53e3c63ab791"},
|
{file = "protobuf-4.22.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:e894e9ae603e963f0842498c4cd5d39c6a60f0d7e4c103df50ee939564298658"},
|
||||||
{file = "protobuf-4.21.12-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:d1736130bce8cf131ac7957fa26880ca19227d4ad68b4888b3be0dea1f95df97"},
|
{file = "protobuf-4.22.0-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:7c535d126e7dcc714105ab20b418c4fedbd28f8b8afc42b7350b1e317bbbcc71"},
|
||||||
{file = "protobuf-4.21.12-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:78a28c9fa223998472886c77042e9b9afb6fe4242bd2a2a5aced88e3f4422aa7"},
|
{file = "protobuf-4.22.0-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:86c3d20428b007537ba6792b475c0853bba7f66b1f60e610d913b77d94b486e4"},
|
||||||
{file = "protobuf-4.21.12-cp37-cp37m-win32.whl", hash = "sha256:3d164928ff0727d97022957c2b849250ca0e64777ee31efd7d6de2e07c494717"},
|
{file = "protobuf-4.22.0-cp37-cp37m-win32.whl", hash = "sha256:1669cb7524221a8e2d9008d0842453dbefdd0fcdd64d67672f657244867635fb"},
|
||||||
{file = "protobuf-4.21.12-cp37-cp37m-win_amd64.whl", hash = "sha256:f45460f9ee70a0ec1b6694c6e4e348ad2019275680bd68a1d9314b8c7e01e574"},
|
{file = "protobuf-4.22.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ab4d043865dd04e6b09386981fe8f80b39a1e46139fb4a3c206229d6b9f36ff6"},
|
||||||
{file = "protobuf-4.21.12-cp38-cp38-win32.whl", hash = "sha256:6ab80df09e3208f742c98443b6166bcb70d65f52cfeb67357d52032ea1ae9bec"},
|
{file = "protobuf-4.22.0-cp38-cp38-win32.whl", hash = "sha256:29288813aacaa302afa2381db1d6e0482165737b0afdf2811df5fa99185c457b"},
|
||||||
{file = "protobuf-4.21.12-cp38-cp38-win_amd64.whl", hash = "sha256:1f22ac0ca65bb70a876060d96d914dae09ac98d114294f77584b0d2644fa9c30"},
|
{file = "protobuf-4.22.0-cp38-cp38-win_amd64.whl", hash = "sha256:e474b63bab0a2ea32a7b26a4d8eec59e33e709321e5e16fb66e766b61b82a95e"},
|
||||||
{file = "protobuf-4.21.12-cp39-cp39-win32.whl", hash = "sha256:27f4d15021da6d2b706ddc3860fac0a5ddaba34ab679dc182b60a8bb4e1121cc"},
|
{file = "protobuf-4.22.0-cp39-cp39-win32.whl", hash = "sha256:47d31bdf58222dd296976aa1646c68c6ee80b96d22e0a3c336c9174e253fd35e"},
|
||||||
{file = "protobuf-4.21.12-cp39-cp39-win_amd64.whl", hash = "sha256:237216c3326d46808a9f7c26fd1bd4b20015fb6867dc5d263a493ef9a539293b"},
|
{file = "protobuf-4.22.0-cp39-cp39-win_amd64.whl", hash = "sha256:c27f371f0159feb70e6ea52ed7e768b3f3a4c5676c1900a7e51a24740381650e"},
|
||||||
{file = "protobuf-4.21.12-py2.py3-none-any.whl", hash = "sha256:a53fd3f03e578553623272dc46ac2f189de23862e68565e83dde203d41b76fc5"},
|
{file = "protobuf-4.22.0-py3-none-any.whl", hash = "sha256:c3325803095fb4c2a48649c321d2fbde59f8fbfcb9bfc7a86df27d112831c571"},
|
||||||
{file = "protobuf-4.21.12-py3-none-any.whl", hash = "sha256:b98d0148f84e3a3c569e19f52103ca1feacdac0d2df8d6533cf983d1fda28462"},
|
{file = "protobuf-4.22.0.tar.gz", hash = "sha256:652d8dfece122a24d98eebfef30e31e455d300efa41999d1182e015984ac5930"},
|
||||||
{file = "protobuf-4.21.12.tar.gz", hash = "sha256:7cd532c4566d0e6feafecc1059d04c7915aec8e182d1cf7adee8b24ef1e2e6ab"},
|
|
||||||
]
|
]
|
||||||
psutil = [
|
psutil = [
|
||||||
{file = "psutil-5.9.4-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:c1ca331af862803a42677c120aff8a814a804e09832f166f226bfd22b56feee8"},
|
{file = "psutil-5.9.4-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:c1ca331af862803a42677c120aff8a814a804e09832f166f226bfd22b56feee8"},
|
||||||
@ -1089,8 +1117,8 @@ safetensors = [
|
|||||||
{file = "safetensors-0.2.8.tar.gz", hash = "sha256:2720b20a6a38c799dca79bd76caeeac2f7df585a9d4f7d59fa7e28eff9ccb27f"},
|
{file = "safetensors-0.2.8.tar.gz", hash = "sha256:2720b20a6a38c799dca79bd76caeeac2f7df585a9d4f7d59fa7e28eff9ccb27f"},
|
||||||
]
|
]
|
||||||
setuptools = [
|
setuptools = [
|
||||||
{file = "setuptools-67.2.0-py3-none-any.whl", hash = "sha256:16ccf598aab3b506593c17378473978908a2734d7336755a8769b480906bec1c"},
|
{file = "setuptools-67.4.0-py3-none-any.whl", hash = "sha256:f106dee1b506dee5102cc3f3e9e68137bbad6d47b616be7991714b0c62204251"},
|
||||||
{file = "setuptools-67.2.0.tar.gz", hash = "sha256:b440ee5f7e607bb8c9de15259dba2583dd41a38879a7abc1d43a71c59524da48"},
|
{file = "setuptools-67.4.0.tar.gz", hash = "sha256:e5fd0a713141a4a105412233c63dc4e17ba0090c8e8334594ac790ec97792330"},
|
||||||
]
|
]
|
||||||
tomli = [
|
tomli = [
|
||||||
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
|
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
|
||||||
@ -1124,8 +1152,8 @@ typer = [
|
|||||||
{file = "typer-0.6.1.tar.gz", hash = "sha256:2d5720a5e63f73eaf31edaa15f6ab87f35f0690f8ca233017d7d23d743a91d73"},
|
{file = "typer-0.6.1.tar.gz", hash = "sha256:2d5720a5e63f73eaf31edaa15f6ab87f35f0690f8ca233017d7d23d743a91d73"},
|
||||||
]
|
]
|
||||||
typing-extensions = [
|
typing-extensions = [
|
||||||
{file = "typing_extensions-4.4.0-py3-none-any.whl", hash = "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e"},
|
{file = "typing_extensions-4.5.0-py3-none-any.whl", hash = "sha256:fb33085c39dd998ac16d1431ebc293a8b3eedd00fd4a32de0ff79002c19511b4"},
|
||||||
{file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"},
|
{file = "typing_extensions-4.5.0.tar.gz", hash = "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb"},
|
||||||
]
|
]
|
||||||
urllib3 = [
|
urllib3 = [
|
||||||
{file = "urllib3-1.26.14-py2.py3-none-any.whl", hash = "sha256:75edcdc2f7d85b137124a6c3c9fc3933cdeaa12ecb9a6a959f22797a0feca7e1"},
|
{file = "urllib3-1.26.14-py2.py3-none-any.whl", hash = "sha256:75edcdc2f7d85b137124a6c3c9fc3933cdeaa12ecb9a6a959f22797a0feca7e1"},
|
||||||
@ -1140,68 +1168,79 @@ win32-setctime = [
|
|||||||
{file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
|
{file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
|
||||||
]
|
]
|
||||||
wrapt = [
|
wrapt = [
|
||||||
{file = "wrapt-1.14.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3"},
|
{file = "wrapt-1.15.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ca1cccf838cd28d5a0883b342474c630ac48cac5df0ee6eacc9c7290f76b11c1"},
|
||||||
{file = "wrapt-1.14.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef"},
|
{file = "wrapt-1.15.0-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:e826aadda3cae59295b95343db8f3d965fb31059da7de01ee8d1c40a60398b29"},
|
||||||
{file = "wrapt-1.14.1-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28"},
|
{file = "wrapt-1.15.0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:5fc8e02f5984a55d2c653f5fea93531e9836abbd84342c1d1e17abc4a15084c2"},
|
||||||
{file = "wrapt-1.14.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59"},
|
{file = "wrapt-1.15.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:96e25c8603a155559231c19c0349245eeb4ac0096fe3c1d0be5c47e075bd4f46"},
|
||||||
{file = "wrapt-1.14.1-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87"},
|
{file = "wrapt-1.15.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:40737a081d7497efea35ab9304b829b857f21558acfc7b3272f908d33b0d9d4c"},
|
||||||
{file = "wrapt-1.14.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1"},
|
{file = "wrapt-1.15.0-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:f87ec75864c37c4c6cb908d282e1969e79763e0d9becdfe9fe5473b7bb1e5f09"},
|
||||||
{file = "wrapt-1.14.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b"},
|
{file = "wrapt-1.15.0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:1286eb30261894e4c70d124d44b7fd07825340869945c79d05bda53a40caa079"},
|
||||||
{file = "wrapt-1.14.1-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462"},
|
{file = "wrapt-1.15.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:493d389a2b63c88ad56cdc35d0fa5752daac56ca755805b1b0c530f785767d5e"},
|
||||||
{file = "wrapt-1.14.1-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1"},
|
{file = "wrapt-1.15.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:58d7a75d731e8c63614222bcb21dd992b4ab01a399f1f09dd82af17bbfc2368a"},
|
||||||
{file = "wrapt-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320"},
|
{file = "wrapt-1.15.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:21f6d9a0d5b3a207cdf7acf8e58d7d13d463e639f0c7e01d82cdb671e6cb7923"},
|
||||||
{file = "wrapt-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2"},
|
{file = "wrapt-1.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ce42618f67741d4697684e501ef02f29e758a123aa2d669e2d964ff734ee00ee"},
|
||||||
{file = "wrapt-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4"},
|
{file = "wrapt-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41d07d029dd4157ae27beab04d22b8e261eddfc6ecd64ff7000b10dc8b3a5727"},
|
||||||
{file = "wrapt-1.14.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069"},
|
{file = "wrapt-1.15.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54accd4b8bc202966bafafd16e69da9d5640ff92389d33d28555c5fd4f25ccb7"},
|
||||||
{file = "wrapt-1.14.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310"},
|
{file = "wrapt-1.15.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fbfbca668dd15b744418265a9607baa970c347eefd0db6a518aaf0cfbd153c0"},
|
||||||
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f"},
|
{file = "wrapt-1.15.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:76e9c727a874b4856d11a32fb0b389afc61ce8aaf281ada613713ddeadd1cfec"},
|
||||||
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656"},
|
{file = "wrapt-1.15.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e20076a211cd6f9b44a6be58f7eeafa7ab5720eb796975d0c03f05b47d89eb90"},
|
||||||
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"},
|
{file = "wrapt-1.15.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a74d56552ddbde46c246b5b89199cb3fd182f9c346c784e1a93e4dc3f5ec9975"},
|
||||||
{file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"},
|
{file = "wrapt-1.15.0-cp310-cp310-win32.whl", hash = "sha256:26458da5653aa5b3d8dc8b24192f574a58984c749401f98fff994d41d3f08da1"},
|
||||||
{file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"},
|
{file = "wrapt-1.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:75760a47c06b5974aa5e01949bf7e66d2af4d08cb8c1d6516af5e39595397f5e"},
|
||||||
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"},
|
{file = "wrapt-1.15.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ba1711cda2d30634a7e452fc79eabcadaffedf241ff206db2ee93dd2c89a60e7"},
|
||||||
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"},
|
{file = "wrapt-1.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:56374914b132c702aa9aa9959c550004b8847148f95e1b824772d453ac204a72"},
|
||||||
{file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"},
|
{file = "wrapt-1.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a89ce3fd220ff144bd9d54da333ec0de0399b52c9ac3d2ce34b569cf1a5748fb"},
|
||||||
{file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d"},
|
{file = "wrapt-1.15.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bbe623731d03b186b3d6b0d6f51865bf598587c38d6f7b0be2e27414f7f214e"},
|
||||||
{file = "wrapt-1.14.1-cp35-cp35m-win32.whl", hash = "sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7"},
|
{file = "wrapt-1.15.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3abbe948c3cbde2689370a262a8d04e32ec2dd4f27103669a45c6929bcdbfe7c"},
|
||||||
{file = "wrapt-1.14.1-cp35-cp35m-win_amd64.whl", hash = "sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00"},
|
{file = "wrapt-1.15.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b67b819628e3b748fd3c2192c15fb951f549d0f47c0449af0764d7647302fda3"},
|
||||||
{file = "wrapt-1.14.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4"},
|
{file = "wrapt-1.15.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7eebcdbe3677e58dd4c0e03b4f2cfa346ed4049687d839adad68cc38bb559c92"},
|
||||||
{file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1"},
|
{file = "wrapt-1.15.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:74934ebd71950e3db69960a7da29204f89624dde411afbfb3b4858c1409b1e98"},
|
||||||
{file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1"},
|
{file = "wrapt-1.15.0-cp311-cp311-win32.whl", hash = "sha256:bd84395aab8e4d36263cd1b9308cd504f6cf713b7d6d3ce25ea55670baec5416"},
|
||||||
{file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff"},
|
{file = "wrapt-1.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:a487f72a25904e2b4bbc0817ce7a8de94363bd7e79890510174da9d901c38705"},
|
||||||
{file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d"},
|
{file = "wrapt-1.15.0-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:4ff0d20f2e670800d3ed2b220d40984162089a6e2c9646fdb09b85e6f9a8fc29"},
|
||||||
{file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1"},
|
{file = "wrapt-1.15.0-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:9ed6aa0726b9b60911f4aed8ec5b8dd7bf3491476015819f56473ffaef8959bd"},
|
||||||
{file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569"},
|
{file = "wrapt-1.15.0-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:896689fddba4f23ef7c718279e42f8834041a21342d95e56922e1c10c0cc7afb"},
|
||||||
{file = "wrapt-1.14.1-cp36-cp36m-win32.whl", hash = "sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed"},
|
{file = "wrapt-1.15.0-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:75669d77bb2c071333417617a235324a1618dba66f82a750362eccbe5b61d248"},
|
||||||
{file = "wrapt-1.14.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471"},
|
{file = "wrapt-1.15.0-cp35-cp35m-win32.whl", hash = "sha256:fbec11614dba0424ca72f4e8ba3c420dba07b4a7c206c8c8e4e73f2e98f4c559"},
|
||||||
{file = "wrapt-1.14.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248"},
|
{file = "wrapt-1.15.0-cp35-cp35m-win_amd64.whl", hash = "sha256:fd69666217b62fa5d7c6aa88e507493a34dec4fa20c5bd925e4bc12fce586639"},
|
||||||
{file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68"},
|
{file = "wrapt-1.15.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b0724f05c396b0a4c36a3226c31648385deb6a65d8992644c12a4963c70326ba"},
|
||||||
{file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d"},
|
{file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbeccb1aa40ab88cd29e6c7d8585582c99548f55f9b2581dfc5ba68c59a85752"},
|
||||||
{file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77"},
|
{file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:38adf7198f8f154502883242f9fe7333ab05a5b02de7d83aa2d88ea621f13364"},
|
||||||
{file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7"},
|
{file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:578383d740457fa790fdf85e6d346fda1416a40549fe8db08e5e9bd281c6a475"},
|
||||||
{file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015"},
|
{file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:a4cbb9ff5795cd66f0066bdf5947f170f5d63a9274f99bdbca02fd973adcf2a8"},
|
||||||
{file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a"},
|
{file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:af5bd9ccb188f6a5fdda9f1f09d9f4c86cc8a539bd48a0bfdc97723970348418"},
|
||||||
{file = "wrapt-1.14.1-cp37-cp37m-win32.whl", hash = "sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853"},
|
{file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:b56d5519e470d3f2fe4aa7585f0632b060d532d0696c5bdfb5e8319e1d0f69a2"},
|
||||||
{file = "wrapt-1.14.1-cp37-cp37m-win_amd64.whl", hash = "sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c"},
|
{file = "wrapt-1.15.0-cp36-cp36m-win32.whl", hash = "sha256:77d4c1b881076c3ba173484dfa53d3582c1c8ff1f914c6461ab70c8428b796c1"},
|
||||||
{file = "wrapt-1.14.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456"},
|
{file = "wrapt-1.15.0-cp36-cp36m-win_amd64.whl", hash = "sha256:077ff0d1f9d9e4ce6476c1a924a3332452c1406e59d90a2cf24aeb29eeac9420"},
|
||||||
{file = "wrapt-1.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f"},
|
{file = "wrapt-1.15.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5c5aa28df055697d7c37d2099a7bc09f559d5053c3349b1ad0c39000e611d317"},
|
||||||
{file = "wrapt-1.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc"},
|
{file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a8564f283394634a7a7054b7983e47dbf39c07712d7b177b37e03f2467a024e"},
|
||||||
{file = "wrapt-1.14.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1"},
|
{file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780c82a41dc493b62fc5884fb1d3a3b81106642c5c5c78d6a0d4cbe96d62ba7e"},
|
||||||
{file = "wrapt-1.14.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af"},
|
{file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e169e957c33576f47e21864cf3fc9ff47c223a4ebca8960079b8bd36cb014fd0"},
|
||||||
{file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b"},
|
{file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b02f21c1e2074943312d03d243ac4388319f2456576b2c6023041c4d57cd7019"},
|
||||||
{file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0"},
|
{file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f2e69b3ed24544b0d3dbe2c5c0ba5153ce50dcebb576fdc4696d52aa22db6034"},
|
||||||
{file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57"},
|
{file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d787272ed958a05b2c86311d3a4135d3c2aeea4fc655705f074130aa57d71653"},
|
||||||
{file = "wrapt-1.14.1-cp38-cp38-win32.whl", hash = "sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5"},
|
{file = "wrapt-1.15.0-cp37-cp37m-win32.whl", hash = "sha256:02fce1852f755f44f95af51f69d22e45080102e9d00258053b79367d07af39c0"},
|
||||||
{file = "wrapt-1.14.1-cp38-cp38-win_amd64.whl", hash = "sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d"},
|
{file = "wrapt-1.15.0-cp37-cp37m-win_amd64.whl", hash = "sha256:abd52a09d03adf9c763d706df707c343293d5d106aea53483e0ec8d9e310ad5e"},
|
||||||
{file = "wrapt-1.14.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383"},
|
{file = "wrapt-1.15.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cdb4f085756c96a3af04e6eca7f08b1345e94b53af8921b25c72f096e704e145"},
|
||||||
{file = "wrapt-1.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7"},
|
{file = "wrapt-1.15.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:230ae493696a371f1dbffaad3dafbb742a4d27a0afd2b1aecebe52b740167e7f"},
|
||||||
{file = "wrapt-1.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86"},
|
{file = "wrapt-1.15.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63424c681923b9f3bfbc5e3205aafe790904053d42ddcc08542181a30a7a51bd"},
|
||||||
{file = "wrapt-1.14.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735"},
|
{file = "wrapt-1.15.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6bcbfc99f55655c3d93feb7ef3800bd5bbe963a755687cbf1f490a71fb7794b"},
|
||||||
{file = "wrapt-1.14.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b"},
|
{file = "wrapt-1.15.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c99f4309f5145b93eca6e35ac1a988f0dc0a7ccf9ccdcd78d3c0adf57224e62f"},
|
||||||
{file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3"},
|
{file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b130fe77361d6771ecf5a219d8e0817d61b236b7d8b37cc045172e574ed219e6"},
|
||||||
{file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3"},
|
{file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:96177eb5645b1c6985f5c11d03fc2dbda9ad24ec0f3a46dcce91445747e15094"},
|
||||||
{file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe"},
|
{file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5fe3e099cf07d0fb5a1e23d399e5d4d1ca3e6dfcbe5c8570ccff3e9208274f7"},
|
||||||
{file = "wrapt-1.14.1-cp39-cp39-win32.whl", hash = "sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5"},
|
{file = "wrapt-1.15.0-cp38-cp38-win32.whl", hash = "sha256:abd8f36c99512755b8456047b7be10372fca271bf1467a1caa88db991e7c421b"},
|
||||||
{file = "wrapt-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb"},
|
{file = "wrapt-1.15.0-cp38-cp38-win_amd64.whl", hash = "sha256:b06fa97478a5f478fb05e1980980a7cdf2712015493b44d0c87606c1513ed5b1"},
|
||||||
{file = "wrapt-1.14.1.tar.gz", hash = "sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d"},
|
{file = "wrapt-1.15.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2e51de54d4fb8fb50d6ee8327f9828306a959ae394d3e01a1ba8b2f937747d86"},
|
||||||
|
{file = "wrapt-1.15.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0970ddb69bba00670e58955f8019bec4a42d1785db3faa043c33d81de2bf843c"},
|
||||||
|
{file = "wrapt-1.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76407ab327158c510f44ded207e2f76b657303e17cb7a572ffe2f5a8a48aa04d"},
|
||||||
|
{file = "wrapt-1.15.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd525e0e52a5ff16653a3fc9e3dd827981917d34996600bbc34c05d048ca35cc"},
|
||||||
|
{file = "wrapt-1.15.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d37ac69edc5614b90516807de32d08cb8e7b12260a285ee330955604ed9dd29"},
|
||||||
|
{file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:078e2a1a86544e644a68422f881c48b84fef6d18f8c7a957ffd3f2e0a74a0d4a"},
|
||||||
|
{file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:2cf56d0e237280baed46f0b5316661da892565ff58309d4d2ed7dba763d984b8"},
|
||||||
|
{file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7dc0713bf81287a00516ef43137273b23ee414fe41a3c14be10dd95ed98a2df9"},
|
||||||
|
{file = "wrapt-1.15.0-cp39-cp39-win32.whl", hash = "sha256:46ed616d5fb42f98630ed70c3529541408166c22cdfd4540b88d5f21006b0eff"},
|
||||||
|
{file = "wrapt-1.15.0-cp39-cp39-win_amd64.whl", hash = "sha256:eef4d64c650f33347c1f9266fa5ae001440b232ad9b98f1f43dfe7a79435c0a6"},
|
||||||
|
{file = "wrapt-1.15.0-py3-none-any.whl", hash = "sha256:64b1df0f83706b4ef4cfb4fb0e4c2669100fd7ecacfb59e091fad300d4e04640"},
|
||||||
|
{file = "wrapt-1.15.0.tar.gz", hash = "sha256:d06730c6aed78cee4126234cf2d071e01b44b915e725a6cb439a879ec9754a3a"},
|
||||||
]
|
]
|
||||||
|
@@ -1,11 +1,11 @@
[tool.poetry]
- name = "text-generation"
+ name = "text-generation-server"
- version = "0.2.1"
+ version = "0.4.0"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

[tool.poetry.scripts]
- text-generation-server = 'text_generation.cli:app'
+ text-generation-server = 'text_generation_server.cli:app'

[tool.poetry.dependencies]
python = "^3.9"
@@ -22,6 +22,7 @@ loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
+ hf-transfer = "^0.1.2"

[tool.poetry.extras]
bnb = ["bitsandbytes"]
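One practical consequence of the rename above: the `text-generation-server` console script now points at `text_generation_server.cli:app` instead of `text_generation.cli:app`. As a rough sketch (the module path is taken from this diff; wiring it up through Typer like this is an assumption, not repository code), the installed entry point resolves to roughly:

# Hypothetical equivalent of the poetry entry point after the rename.
# `app` is the typer.Typer() instance defined in the new cli.py shown near the end of this diff.
from text_generation_server.cli import app

if __name__ == "__main__":
    app()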
@@ -1,6 +1,6 @@
import pytest

- from text_generation.pb import generate_pb2
+ from text_generation_server.pb import generate_pb2


@pytest.fixture
@@ -10,6 +10,7 @@ def default_pb_parameters():
        repetition_penalty=1.0,
        top_k=0,
        top_p=1.0,
+       typical_p=1.0,
        do_sample=False,
    )

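The fixture change adds typical decoding support alongside the existing sampling knobs. A minimal sketch of what the updated fixture now builds, assuming the message is `generate_pb2.NextTokenChooserParameters` (the field names come from the fixture; the message name and the `temperature` default are assumptions, since neither is shown in this hunk):

from text_generation_server.pb import generate_pb2

# Sketch only - mirrors the fixture above.
parameters = generate_pb2.NextTokenChooserParameters(
    temperature=1.0,        # assumed default, not shown in this hunk
    repetition_penalty=1.0,
    top_k=0,
    top_p=1.0,
    typical_p=1.0,          # field introduced by this change
    do_sample=False,
)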
@@ -4,9 +4,9 @@ import torch
from copy import copy
from transformers import AutoTokenizer

- from text_generation.pb import generate_pb2
+ from text_generation_server.pb import generate_pb2
- from text_generation.models.causal_lm import CausalLMBatch
+ from text_generation_server.models.causal_lm import CausalLMBatch
- from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
+ from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM


@pytest.fixture(scope="session")
@@ -24,7 +24,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
    return generate_pb2.Request(
        id=0,
        inputs="Test",
-       input_length=1,
        parameters=default_pb_parameters,
        stopping_parameters=default_pb_stop_parameters,
    )
@@ -65,8 +64,8 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
    assert batch.input_ids[0][-1] == 10264
    assert torch.all(batch.input_ids[0][:-1] == 3)

-   assert batch.attention_mask[0][-1] == 1
+   assert batch.attention_mask[0][0] == 1
-   assert torch.all(batch.attention_mask[0][:-1] == 0)
+   assert torch.all(batch.attention_mask[0][1:] == 0)

    assert batch.past_key_values is None

@@ -77,7 +76,7 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
    assert batch.size == default_pb_batch.size
    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size

-   assert batch.max_sequence_length == batch.input_lengths[0]
+   assert batch.max_input_length == batch.input_lengths[0]


def test_batch_concatenate_no_prefill(default_bloom_batch):
@@ -98,22 +97,19 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
    assert not next_batch.keys_head_dim_last

    assert len(next_batch.all_input_ids) == next_batch.size
-   assert (
-       len(next_batch.all_input_ids[0])
-       == len(next_batch.attention_mask[0])
-       == sequence_length + 1
-   )
+   assert len(next_batch.all_input_ids[0]) == sequence_length + 1
+   assert len(next_batch.attention_mask[0]) == 11
    assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
    assert torch.all(next_batch.all_input_ids[0][:-2] == 3)

-   assert torch.all(next_batch.attention_mask[0][-2:] == 1)
+   assert torch.all(next_batch.attention_mask[0][:2] == 1)
-   assert torch.all(next_batch.attention_mask[0][:-2] == 0)
+   assert torch.all(next_batch.attention_mask[0][2:] == 0)

    assert next_batch.input_ids.shape == (next_batch.size, 1)
    assert next_batch.input_ids[0, 0] == 10264

    assert next_batch.input_lengths == [2]
-   assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+   assert next_batch.max_input_length == next_batch.input_lengths[0]

    assert next_batch.past_key_values is not None
    assert all(
@@ -213,15 +209,19 @@ def test_batch_concatenate(
    assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
    assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])

-   assert torch.all(next_batch.attention_mask[0] == 1)
-   assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
-   assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
+   assert torch.all(
+       next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
+   )
+   assert torch.all(
+       next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
+   )
+   assert torch.all(next_batch.attention_mask[1:, 3:] == 0)

    assert next_batch.batch_id == 0
    assert torch.all(next_batch.input_ids == 10264)

    assert next_batch.input_lengths == [3, 2, 2]
-   assert next_batch.max_sequence_length == 3
+   assert next_batch.max_input_length == 3

    assert next_batch.requests[0] == next_batch_0.requests[0]
    assert next_batch.requests[1:] == next_batch_1.requests
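The rewritten assertions encode a change in padding layout: real tokens are now flagged from position 0 of the attention mask, and a `padding_right_offset` worth of trailing slots is kept free for tokens that will be generated later. A small, self-contained sketch of that layout (the helper below is illustrative only, not repository code):

import torch

# Illustrative only: build a mask the way the new tests expect it to look.
def build_attention_mask(input_length: int, max_input_length: int, padding_right_offset: int) -> torch.Tensor:
    # total width = longest prompt in the batch + room for future generated tokens
    mask = torch.zeros(max_input_length + padding_right_offset, dtype=torch.int64)
    # real prompt tokens are flagged at the start of the row...
    mask[:input_length] = 1
    # ...and the trailing `padding_right_offset` slots stay 0 until tokens are generated.
    return mask

mask = build_attention_mask(input_length=1, max_input_length=1, padding_right_offset=10)
assert mask[0] == 1
assert torch.all(mask[1:] == 0)  # matches the updated test_batch_from_pb assertions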
@@ -4,8 +4,8 @@ import torch
from copy import copy
from transformers import AutoTokenizer

- from text_generation.pb import generate_pb2
+ from text_generation_server.pb import generate_pb2
- from text_generation.models.causal_lm import CausalLM, CausalLMBatch
+ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch


@pytest.fixture(scope="session")
@@ -25,7 +25,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
    return generate_pb2.Request(
        id=0,
        inputs="Test",
-       input_length=1,
        parameters=default_pb_parameters,
        stopping_parameters=default_pb_stop_parameters,
    )
@@ -62,8 +61,8 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
    assert batch.input_ids[0][-1] == 14402
    assert torch.all(batch.input_ids[0][:-1] == 50256)

-   assert batch.attention_mask[0][-1] == 1
+   assert batch.attention_mask[0, 0] == 1
-   assert torch.all(batch.attention_mask[0][:-1] == 0)
+   assert torch.all(batch.attention_mask[0, 1:] == 0)

    assert batch.past_key_values is None

@@ -74,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
    assert batch.size == default_pb_batch.size
    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size

-   assert batch.max_sequence_length == batch.input_lengths[0]
+   assert batch.max_input_length == batch.input_lengths[0]


def test_batch_concatenate_no_prefill(default_causal_lm_batch):
@@ -94,23 +93,20 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
    assert isinstance(next_batch, CausalLMBatch)

    assert len(next_batch.all_input_ids) == next_batch.size
-   assert (
-       len(next_batch.all_input_ids[0])
-       == len(next_batch.attention_mask[0])
-       == sequence_length + 1
-   )
+   assert len(next_batch.all_input_ids[0]) == sequence_length + 1
+   assert len(next_batch.attention_mask[0]) == 11
    assert next_batch.all_input_ids[0][-1] == 13
    assert next_batch.all_input_ids[0][-2] == 14402
    assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)

-   assert torch.all(next_batch.attention_mask[0][-2:] == 1)
+   assert torch.all(next_batch.attention_mask[0][0:2] == 1)
-   assert torch.all(next_batch.attention_mask[0][:-2] == 0)
+   assert torch.all(next_batch.attention_mask[0][2:] == 0)

    assert next_batch.input_ids.shape == (next_batch.size, 1)
    assert next_batch.input_ids[0, 0] == 13

    assert next_batch.input_lengths == [2]
-   assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+   assert next_batch.max_input_length == next_batch.input_lengths[0]

    assert next_batch.past_key_values is not None
    assert all(
@@ -210,16 +206,20 @@ def test_batch_concatenate(
    assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
    assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])

-   assert torch.all(next_batch.attention_mask[0] == 1)
-   assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
-   assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
+   assert torch.all(
+       next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
+   )
+   assert torch.all(
+       next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
+   )
+   assert torch.all(next_batch.attention_mask[1:, 3:] == 0)

    assert next_batch.batch_id == 0
    assert next_batch.input_ids[0, 0] == 12355
    assert torch.all(next_batch.input_ids[1:] == 13)

    assert next_batch.input_lengths == [3, 2, 2]
-   assert next_batch.max_sequence_length == 3
+   assert next_batch.max_input_length == 3

    assert next_batch.requests[0] == next_batch_0.requests[0]
    assert next_batch.requests[1:] == next_batch_1.requests
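After the first `generate_token` pass these tests expect `next_batch.input_ids` to collapse to a single column per sequence: only the freshly sampled token is fed back, while the full history lives in `all_input_ids` and the renamed `max_input_length` tracks the longest sequence. A toy illustration of that bookkeeping (plain tensors, not repository code):

import torch

# Toy state for one sequence after a prefill of 1 token plus 1 generated token.
all_input_ids = torch.tensor([14402, 13])      # full history kept per sequence
input_ids = all_input_ids[-1:].view(1, 1)      # only the last token feeds the next forward pass
input_lengths = [all_input_ids.shape[0]]
max_input_length = max(input_lengths)

assert input_ids.shape == (1, 1)
assert input_lengths == [2]
assert max_input_length == input_lengths[0]    # mirrors the updated assertions above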
@@ -1,8 +1,8 @@
import pytest

- from text_generation.pb import generate_pb2
+ from text_generation_server.pb import generate_pb2
- from text_generation.models.causal_lm import CausalLMBatch
+ from text_generation_server.models.causal_lm import CausalLMBatch
- from text_generation.models.santacoder import SantaCoder
+ from text_generation_server.models.santacoder import SantaCoder


@pytest.fixture(scope="session")
@@ -15,7 +15,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
    return generate_pb2.Request(
        id=0,
        inputs="def",
-       input_length=1,
        parameters=default_pb_parameters,
        stopping_parameters=default_pb_stop_parameters,
    )
@@ -31,7 +30,6 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
    return generate_pb2.Request(
        id=0,
        inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
-       input_length=5,
        parameters=default_pb_parameters,
        stopping_parameters=default_pb_stop_parameters,
    )
@@ -5,8 +5,8 @@ from copy import copy

from transformers import AutoTokenizer

- from text_generation.pb import generate_pb2
+ from text_generation_server.pb import generate_pb2
- from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
+ from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch


@pytest.fixture(scope="session")
@@ -28,7 +28,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
    return generate_pb2.Request(
        id=0,
        inputs="Test",
-       input_length=2,
        parameters=default_pb_parameters,
        stopping_parameters=default_pb_stop_parameters,
    )
@@ -106,7 +105,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
    assert len(generations) == len(next_batch)
    assert isinstance(next_batch, Seq2SeqLMBatch)

-   assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
+   assert next_batch.input_ids is None
    assert torch.equal(
        next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
    )
@@ -148,7 +147,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
    assert all([generation.generated_text is None for generation in generations])
    assert all([len(generation.prefill_tokens) == 1 for generation in generations])
    assert all([generation.token_id.item() == 259 for generation in generations])
-   assert all([generation.token_text == "" for generation in generations])
+   assert all([generation.token_text == " " for generation in generations])
    assert generations[0].request_id == 0


@@ -220,11 +219,6 @@ def test_batch_concatenate(

    assert next_batch.batch_id == 0

-   assert torch.all(next_batch.input_ids[:, 0] == 4268)
-   assert torch.all(next_batch.input_ids[:, 1] == 1)
-
-   assert torch.all(next_batch.attention_mask == 1)
-
    assert torch.equal(
        next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
    )
@@ -233,9 +227,10 @@ def test_batch_concatenate(
        next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
    )

-   assert torch.all(next_batch.decoder_attention_mask[0] == 1)
+   assert torch.all(next_batch.decoder_attention_mask[0, :3] == 1)
+   assert torch.all(next_batch.decoder_attention_mask[0, 3:] == 0)
    assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
-   assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1)
+   assert torch.all(next_batch.decoder_attention_mask[1:, 1:3] == 1)

    assert torch.equal(
        next_batch.encoder_last_hidden_state[0],
21 server/tests/utils/test_convert.py Normal file
@@ -0,0 +1,21 @@
from text_generation_server.utils.hub import (
    download_weights,
    weight_hub_files,
    weight_files,
)

from text_generation_server.utils.convert import convert_files


def test_convert_files():
    model_id = "bigscience/bloom-560m"
    pt_filenames = weight_hub_files(model_id, extension=".bin")
    local_pt_files = download_weights(pt_filenames, model_id)
    local_st_files = [
        p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
    ]
    convert_files(local_pt_files, local_st_files)

    found_st_files = weight_files(model_id)

    assert all([p in found_st_files for p in local_st_files])
40 server/tests/utils/test_hub.py Normal file
@@ -0,0 +1,40 @@
import pytest

from text_generation_server.utils.hub import (
    weight_hub_files,
    download_weights,
    weight_files,
    EntryNotFoundError,
    LocalEntryNotFoundError,
    RevisionNotFoundError,
)


def test_weight_hub_files():
    filenames = weight_hub_files("bigscience/bloom-560m")
    assert filenames == ["model.safetensors"]


def test_weight_hub_files_llm():
    filenames = weight_hub_files("bigscience/bloom")
    assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]


def test_weight_hub_files_empty():
    with pytest.raises(EntryNotFoundError):
        weight_hub_files("bigscience/bloom", extension=".errors")


def test_download_weights():
    model_id = "bigscience/bloom-560m"
    filenames = weight_hub_files(model_id)
    files = download_weights(filenames, model_id)
    local_files = weight_files("bigscience/bloom-560m")
    assert files == local_files


def test_weight_files_error():
    with pytest.raises(RevisionNotFoundError):
        weight_files("bigscience/bloom-560m", revision="error")
    with pytest.raises(LocalEntryNotFoundError):
        weight_files("bert-base-uncased")
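Together, these tests pin down the new hub API surface in `text_generation_server.utils.hub`: `download_weights` now takes the list of filenames plus the model id, and missing files or bad revisions surface as typed errors. A short sketch of the intended call sequence, assuming only the functions and error types the tests themselves import:

from text_generation_server.utils.hub import (
    weight_hub_files,
    download_weights,
    weight_files,
    LocalEntryNotFoundError,
)

model_id = "bigscience/bloom-560m"

# 1. list the .safetensors files advertised on the Hub
filenames = weight_hub_files(model_id)

# 2. download them into the local cache (note the new signature: filenames first, then model id)
download_weights(filenames, model_id)

# 3. resolve the cached paths; presumably raises LocalEntryNotFoundError if step 2 was skipped
try:
    local_files = weight_files(model_id)
except LocalEntryNotFoundError:
    local_files = []
print(local_files)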
@@ -1,14 +1,6 @@
- import pytest
-
- from huggingface_hub.utils import RevisionNotFoundError
-
- from text_generation.utils import (
-     weight_hub_files,
-     download_weights,
-     weight_files,
+ from text_generation_server.utils.tokens import (
    StopSequenceCriteria,
    StoppingCriteria,
-     LocalEntryNotFoundError,
    FinishReason,
)

@@ -41,31 +33,3 @@ def test_stopping_criteria_max():
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
-
-
- def test_weight_hub_files():
-     filenames = weight_hub_files("bigscience/bloom-560m")
-     assert filenames == ["model.safetensors"]
-
-
- def test_weight_hub_files_llm():
-     filenames = weight_hub_files("bigscience/bloom")
-     assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
-
-
- def test_weight_hub_files_empty():
-     filenames = weight_hub_files("bigscience/bloom", extension=".errors")
-     assert filenames == []
-
-
- def test_download_weights():
-     files = download_weights("bigscience/bloom-560m")
-     local_files = weight_files("bigscience/bloom-560m")
-     assert files == local_files
-
-
- def test_weight_files_error():
-     with pytest.raises(RevisionNotFoundError):
-         weight_files("bigscience/bloom-560m", revision="error")
-     with pytest.raises(LocalEntryNotFoundError):
-         weight_files("bert-base-uncased")
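With the hub tests relocated to server/tests/utils/test_hub.py, this module now only covers the token-stopping logic. As a reminder of how those criteria behave (a minimal sketch based on the old implementation visible further down in this diff, which the relocated classes presumably keep):

from text_generation_server.utils.tokens import (
    StoppingCriteria,
    StopSequenceCriteria,
    FinishReason,
)

# Toy setup (assumed values): eos id 0, stop on the literal string "/stop", at most 5 new tokens.
criteria = StoppingCriteria(
    eos_token_id=0,
    stop_sequence_criterias=[StopSequenceCriteria("/stop")],
    max_new_tokens=5,
)

print(criteria(1, "Hello"))   # (False, None): no EOS, no stop sequence, length not reached
print(criteria(1, "/stop"))   # (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)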
@@ -1,68 +0,0 @@
import os
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional

from text_generation import server, utils
from text_generation.tracing import setup_tracing

app = typer.Typer()


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )
    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    server.serve(model_id, revision, sharded, quantize, uds_path)


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
):
    utils.download_weights(model_id, revision, extension)


if __name__ == "__main__":
    app()
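The deleted module above (and its replacement at the bottom of this diff) only checks that the usual torch.distributed variables are present before serving a sharded model. A hedged illustration of the environment a two-shard, single-node launch would need (the variable names are the ones asserted by `serve`; the values are examples, not taken from the repository):

import os

# Example values for a 2-shard, single-node launch.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "2")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")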
@@ -1,24 +0,0 @@
import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase

from text_generation.models.types import Batch, GeneratedText

B = TypeVar("B", bound=Batch)


class Model(ABC):
    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
        self.tokenizer = tokenizer
        self.device = device

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        raise NotImplementedError
@ -1,283 +0,0 @@
|
|||||||
import concurrent
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
|
||||||
from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
|
|
||||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
|
||||||
from huggingface_hub.utils import LocalEntryNotFoundError
|
|
||||||
from tqdm import tqdm
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
|
||||||
from transformers.generation.logits_process import (
|
|
||||||
LogitsProcessorList,
|
|
||||||
RepetitionPenaltyLogitsProcessor,
|
|
||||||
TemperatureLogitsWarper,
|
|
||||||
TopPLogitsWarper,
|
|
||||||
TopKLogitsWarper,
|
|
||||||
)
|
|
||||||
|
|
||||||
from text_generation.pb import generate_pb2
|
|
||||||
from text_generation.pb.generate_pb2 import FinishReason
|
|
||||||
|
|
||||||
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
|
||||||
|
|
||||||
|
|
||||||
class Sampling:
|
|
||||||
def __init__(self, seed: int, device: str = "cpu"):
|
|
||||||
self.generator = torch.Generator(device)
|
|
||||||
self.generator.manual_seed(seed)
|
|
||||||
self.seed = seed
|
|
||||||
|
|
||||||
def __call__(self, logits):
|
|
||||||
probs = torch.nn.functional.softmax(logits)
|
|
||||||
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
|
|
||||||
return next_tokens
|
|
||||||
|
|
||||||
|
|
||||||
class Greedy:
|
|
||||||
def __call__(self, logits):
|
|
||||||
return logits.argmax()
|
|
||||||
|
|
||||||
|
|
||||||
class NextTokenChooser:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
temperature=1.0,
|
|
||||||
repetition_penalty=1.0,
|
|
||||||
top_k=None,
|
|
||||||
top_p=None,
|
|
||||||
do_sample=False,
|
|
||||||
seed=0,
|
|
||||||
device="cpu",
|
|
||||||
):
|
|
||||||
warpers = LogitsProcessorList()
|
|
||||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
|
||||||
# all samplers can be found in `generation_utils_samplers.py`
|
|
||||||
sampling = do_sample
|
|
||||||
if temperature is not None and temperature != 1.0:
|
|
||||||
temperature = float(temperature)
|
|
||||||
warpers.append(TemperatureLogitsWarper(temperature))
|
|
||||||
sampling = True
|
|
||||||
if top_k is not None and top_k != 0:
|
|
||||||
warpers.append(TopKLogitsWarper(top_k=top_k))
|
|
||||||
sampling = True
|
|
||||||
if top_p is not None and top_p < 1.0:
|
|
||||||
warpers.append(TopPLogitsWarper(top_p=top_p))
|
|
||||||
sampling = True
|
|
||||||
if repetition_penalty is not None and repetition_penalty != 1.0:
|
|
||||||
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
|
||||||
|
|
||||||
self.warpers = warpers
|
|
||||||
self.choice = Sampling(seed, device) if sampling else Greedy()
|
|
||||||
|
|
||||||
def __call__(self, input_ids, scores):
|
|
||||||
# Warp logits
|
|
||||||
scores = self.warpers(input_ids, scores)
|
|
||||||
|
|
||||||
# Compute logprobs
|
|
||||||
logprobs = torch.log_softmax(scores, -1)
|
|
||||||
|
|
||||||
# Choose tokens
|
|
||||||
next_id = self.choice(scores[-1])
|
|
||||||
|
|
||||||
return next_id.view(1, 1), logprobs
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pb(
|
|
||||||
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
|
|
||||||
) -> "NextTokenChooser":
|
|
||||||
return NextTokenChooser(
|
|
||||||
temperature=pb.temperature,
|
|
||||||
repetition_penalty=pb.repetition_penalty,
|
|
||||||
top_k=pb.top_k,
|
|
||||||
top_p=pb.top_p,
|
|
||||||
do_sample=pb.do_sample,
|
|
||||||
seed=pb.seed,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StopSequenceCriteria:
|
|
||||||
def __init__(self, stop_sequence: str):
|
|
||||||
self.regex = re.compile(f".*{stop_sequence}$")
|
|
||||||
|
|
||||||
def __call__(self, output: str) -> bool:
|
|
||||||
if self.regex.findall(output):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class StoppingCriteria:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
eos_token_id: int,
|
|
||||||
stop_sequence_criterias: List[StopSequenceCriteria],
|
|
||||||
max_new_tokens=20,
|
|
||||||
):
|
|
||||||
self.eos_token_id = eos_token_id
|
|
||||||
self.stop_sequence_criterias = stop_sequence_criterias
|
|
||||||
self.max_new_tokens = max_new_tokens
|
|
||||||
self.current_tokens = 0
|
|
||||||
self.current_output = ""
|
|
||||||
|
|
||||||
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
|
|
||||||
self.current_tokens += 1
|
|
||||||
if self.current_tokens >= self.max_new_tokens:
|
|
||||||
return True, FinishReason.FINISH_REASON_LENGTH
|
|
||||||
|
|
||||||
if last_token == self.eos_token_id:
|
|
||||||
return True, FinishReason.FINISH_REASON_EOS_TOKEN
|
|
||||||
|
|
||||||
self.current_output += last_output
|
|
||||||
for stop_sequence_criteria in self.stop_sequence_criterias:
|
|
||||||
if stop_sequence_criteria(self.current_output):
|
|
||||||
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
|
|
||||||
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pb(
|
|
||||||
cls,
|
|
||||||
pb: generate_pb2.StoppingCriteriaParameters,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
) -> "StoppingCriteria":
|
|
||||||
stop_sequence_criterias = [
|
|
||||||
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
|
||||||
]
|
|
||||||
return StoppingCriteria(
|
|
||||||
            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
        )


def initialize_torch_distributed():
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        # Set the device id.
        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
        device = rank % torch.cuda.device_count()
        torch.cuda.set_device(device)
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=60)
    else:
        backend = "gloo"
        options = None

    # Call the init process.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
        pg_options=options,
    )

    return torch.distributed.group.WORLD, rank, world_size


def weight_hub_files(model_id, revision=None, extension=".safetensors"):
    """Get the safetensors filenames on the hub"""
    api = HfApi()
    info = api.model_info(model_id, revision=revision)
    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
    return filenames


def try_to_load_from_cache(model_id, revision, filename):
    """Try to load a file from the Hugging Face cache"""
    if revision is None:
        revision = "main"

    object_id = model_id.replace("/", "--")
    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"

    if not repo_cache.is_dir():
        # No cache for this model
        return None

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"
    no_exist_dir = repo_cache / ".no_exist"

    # Resolve refs (for instance to convert main to the associated commit sha)
    if refs_dir.is_dir():
        revision_file = refs_dir / revision
        if revision_file.exists():
            with revision_file.open() as f:
                revision = f.read()

    # Check if file is cached as "no_exist"
    if (no_exist_dir / revision / filename).is_file():
        return _CACHED_NO_EXIST

    # Check if revision folder exists
    if not snapshots_dir.exists():
        return None
    cached_shas = os.listdir(snapshots_dir)
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        return None

    # Check if file exists in cache
    cached_file = snapshots_dir / revision / filename
    return str(cached_file) if cached_file.is_file() else None


def weight_files(model_id, revision=None, extension=".safetensors"):
    """Get the local safetensors filenames"""
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

    filenames = weight_hub_files(model_id, revision, extension)
    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
            model_id, revision=revision, filename=filename
        )
        if cache_file is None:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_id} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_id}` first."
            )
        files.append(cache_file)

    return files


def download_weights(model_id, revision=None, extension=".safetensors"):
    """Download the safetensors files from the hub"""
    if WEIGHTS_CACHE_OVERRIDE is not None:
        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

    filenames = weight_hub_files(model_id, revision, extension)

    download_function = partial(
        hf_hub_download,
        repo_id=model_id,
        local_files_only=False,
    )

    executor = ThreadPoolExecutor(max_workers=5)
    futures = [
        executor.submit(download_function, filename=filename, revision=revision)
        for filename in filenames
    ]
    files = [
        future.result()
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
    ]

    return files
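The download helper above fans out one download per file over a small thread pool and reports progress as futures complete. A minimal, generic sketch of that pattern follows; the `fetch` function and the file names are placeholders, not part of the repository:

import concurrent.futures
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm


def fetch(filename, revision=None):
    # Placeholder for hf_hub_download(repo_id=..., filename=filename, revision=revision)
    return f"/cache/{filename}"


filenames = ["model-00001.safetensors", "model-00002.safetensors"]

executor = ThreadPoolExecutor(max_workers=5)
futures = [executor.submit(fetch, filename=filename) for filename in filenames]
# as_completed yields each future as soon as it finishes, so tqdm tracks real progress
files = [
    future.result()
    for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
]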
@@ -1,6 +1,6 @@
 from typing import Dict, Optional, TypeVar

-from text_generation.models.types import Batch
+from text_generation_server.models.types import Batch

 B = TypeVar("B", bound=Batch)
115 server/text_generation_server/cli.py Normal file
@@ -0,0 +1,115 @@
import os
import sys
import typer

from pathlib import Path
from loguru import logger
from typing import Optional

from text_generation_server import server, utils
from text_generation_server.tracing import setup_tracing

app = typer.Typer()


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: bool = False,
    uds_path: Path = "/tmp/text-generation",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )
    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

    server.serve(model_id, revision, sharded, quantize, uds_path)


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    logger_level: str = "INFO",
    json_output: bool = False,
):
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Test if files were already download
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info(
            "Files are already present in the local cache. " "Skipping download."
        )
        return
    # Local files not found
    except utils.LocalEntryNotFoundError:
        pass

    # Download weights directly
    try:
        filenames = utils.weight_hub_files(model_id, revision, extension)
        utils.download_weights(filenames, model_id, revision)
    except utils.EntryNotFoundError as e:
        if not extension == ".safetensors":
            raise e

        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights instead."
        )

        # Try to see if there are pytorch weights
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
        local_st_files = [
            p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
            for p in local_pt_files
        ]
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files)


if __name__ == "__main__":
    app()
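These commands are what the `text-generation-server` console script exposes (the error message in `weight_files` points users at `text-generation-server download-weights <model_id>`). A hedged sketch of exercising the Typer app in-process with `typer.testing.CliRunner`; the model id is only an example and the call would actually download weights:

from typer.testing import CliRunner

from text_generation_server.cli import app

runner = CliRunner()

# Equivalent to: text-generation-server download-weights bigscience/bloom-560m
result = runner.invoke(app, ["download-weights", "bigscience/bloom-560m"])
print(result.exit_code)
print(result.output)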
@@ -3,14 +3,14 @@ import torch
 from transformers import AutoConfig
 from typing import Optional

-from text_generation.models.model import Model
-from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOM, BLOOMSharded
-from text_generation.models.seq2seq_lm import Seq2SeqLM
-from text_generation.models.galactica import Galactica, GalacticaSharded
-from text_generation.models.santacoder import SantaCoder
-from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
-from text_generation.models.t5 import T5Sharded
+from text_generation_server.models.model import Model
+from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.bloom import BLOOM, BLOOMSharded
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.galactica import Galactica, GalacticaSharded
+from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.gpt_neox import GPTNeoxSharded
+from text_generation_server.models.t5 import T5Sharded

 __all__ = [
     "Model",
@@ -19,7 +19,6 @@ __all__ = [
     "CausalLM",
     "Galactica",
     "GalacticaSharded",
-    "GPTNeox",
     "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",
@@ -41,6 +40,15 @@ torch.set_grad_enabled(False)
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
+    if "facebook/galactica" in model_id:
+        if sharded:
+            return GalacticaSharded(model_id, revision, quantize=quantize)
+        else:
+            return Galactica(model_id, revision, quantize=quantize)
+
+    if "santacoder" in model_id:
+        return SantaCoder(model_id, revision, quantize)
+
     config = AutoConfig.from_pretrained(model_id, revision=revision)

     if config.model_type == "bloom":
@@ -48,24 +56,19 @@ def get_model(
             return BLOOMSharded(model_id, revision, quantize=quantize)
         else:
             return BLOOM(model_id, revision, quantize=quantize)
-    elif config.model_type == "gpt_neox":
+
+    if config.model_type == "gpt_neox":
         if sharded:
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
-            return GPTNeox(model_id, revision, quantize=quantize)
-    elif config.model_type == "t5":
+            return CausalLM(model_id, revision, quantize=quantize)
+
+    if config.model_type == "t5":
         if sharded:
             return T5Sharded(model_id, revision, quantize=quantize)
         else:
             return Seq2SeqLM(model_id, revision, quantize=quantize)
-    elif model_id.startswith("facebook/galactica"):
-        if sharded:
-            return GalacticaSharded(model_id, revision, quantize=quantize)
-        else:
-            return Galactica(model_id, revision, quantize=quantize)
-    elif "santacoder" in model_id:
-        return SantaCoder(model_id, revision, quantize)
-    else:
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
     try:
@@ -17,13 +17,12 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelRowLinear,
 )

-from text_generation.models import CausalLM
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.pb import generate_pb2
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )

 HAS_BITS_AND_BYTES = True
@@ -59,9 +58,6 @@ class BLOOMSharded(BLOOM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_id.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_id} is not supported")
-
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
@@ -80,14 +76,8 @@ class BLOOMSharded(BLOOM):
         )
         config.pad_token_id = 3

-        # Only download weights for small models
-        if self.master and model_id == "bigscience/bloom-560m":
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")

         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
@@ -5,10 +5,15 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type

-from text_generation.models import Model
-from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

 tracer = trace.get_tracer(__name__)

@@ -36,7 +41,8 @@ class CausalLMBatch(Batch):

     # Metadata used for padding
     size: int
-    max_sequence_length: int
+    max_input_length: int
+    padding_right_offset: int

     # Past metadata
     keys_head_dim_last: bool = True
@@ -61,22 +67,36 @@ class CausalLMBatch(Batch):
         input_lengths = []

         # Parse batch
+        padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
-            input_lengths.append(r.input_length)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
-            stopping_criterias.append(
-                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
+            stopping_criteria = StoppingCriteria.from_pb(
+                r.stopping_parameters, tokenizer
+            )
+            stopping_criterias.append(stopping_criteria)
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
             )

-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)

+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
+        input_ids = tokenized_inputs["input_ids"]
+        # Allocate maximum attention_mask
+        attention_mask = input_ids.new_zeros(
+            (pb.size, max_input_length + padding_right_offset)
+        )
+        # Copy tokenizer attention_mask into fully allocated attention_mask
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
+
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
@@ -84,24 +104,30 @@ class CausalLMBatch(Batch):
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
-            input_ids=tokenized_inputs["input_ids"],
-            attention_mask=tokenized_inputs["attention_mask"],
+            input_ids=input_ids,
+            attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
             all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths.tolist(),
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max(input_lengths),
+            max_input_length=max_input_length.item(),
+            padding_right_offset=padding_right_offset,
         )

     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
         # Used for padding
-        total_batch_size = sum(batch.size for batch in batches)
-        max_sequence_length = max(batch.max_sequence_length for batch in batches)
+        total_batch_size = 0
+        max_input_length = 0
+        padding_right_offset = 0
+        for batch in batches:
+            total_batch_size += batch.size
+            max_input_length = max(max_input_length, batch.max_input_length)
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

         # Batch attributes
         requests = []
@@ -144,13 +170,24 @@ class CausalLMBatch(Batch):
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_sequence_length),
+                    (total_batch_size, max_input_length + padding_right_offset),
                 )

             # We need to slice the attention mask to remove padding from previous steps
+            # and to remove unused allocated space
+            left_offset = max_input_length - batch.max_input_length
+            batch_left_offset = (
+                batch.attention_mask.shape[1]
+                - batch.max_input_length
+                - batch.padding_right_offset
+            )
             attention_mask[
-                start_index:end_index, -batch.max_sequence_length :
-            ] = batch.attention_mask[:, -batch.max_sequence_length :]
+                start_index:end_index,
+                left_offset:-padding_right_offset,
+            ] = batch.attention_mask[
+                :,
+                batch_left_offset : -batch.padding_right_offset,
+            ]

             # Create empty tensor
             # position_ids is always of shape [batch_size, 1]
@@ -172,7 +209,7 @@ class CausalLMBatch(Batch):
                 padded_past_values_shape = (
                     total_batch_size,
                     num_heads,
-                    max_sequence_length - 1,
+                    max_input_length - 1,
                     head_dim,
                 )

@@ -184,7 +221,7 @@ class CausalLMBatch(Batch):
                     total_batch_size,
                     num_heads,
                     head_dim,
-                    max_sequence_length - 1,
+                    max_input_length - 1,
                 )

             # This will run only once per layer
@@ -198,20 +235,20 @@ class CausalLMBatch(Batch):
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
-                        -(batch.max_sequence_length - 1) :,
+                        -(batch.max_input_length - 1) :,
                         :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
+                    ] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
                 else:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
-                        -(batch.max_sequence_length - 1) :,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
+                        -(batch.max_input_length - 1) :,
+                    ] = past_keys[:, :, :, -(batch.max_input_length - 1) :]

                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
+                    start_index:end_index, :, -(batch.max_input_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_input_length - 1) :, :]

             start_index += batch.size

@@ -227,7 +264,8 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length,
+            padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
         )

@@ -294,9 +332,12 @@ class CausalLM(Model):
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+        # slice the attention mask to the correct shape
+        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
+
         logits, past = self.forward(
             batch.input_ids,
-            batch.attention_mask,
+            attention_mask,
             batch.position_ids,
             batch.past_key_values,
         )
@@ -311,7 +352,7 @@ class CausalLM(Model):

         # Metadata
         next_batch_size = 0
-        next_batch_max_sequence_length = 0
+        next_batch_max_input_length = 0

         # Results
         generations: List[Generation] = []
@@ -347,10 +388,8 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(
+            next_token_text = self.decode_token(
                 next_token_id_squeezed,
-                clean_up_tokenization_spaces=False,
-                skip_special_tokens=False,
             )

             # Evaluate stopping criteria
@@ -381,8 +420,8 @@ class CausalLM(Model):
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_size += 1
                 next_batch_input_lengths.append(new_input_length)
-                next_batch_max_sequence_length = max(
-                    next_batch_max_sequence_length, new_input_length
+                next_batch_max_input_length = max(
+                    next_batch_max_input_length, new_input_length
                 )

             # Prefill
@@ -409,6 +448,7 @@ class CausalLM(Model):
                 next_token_id_squeezed,
                 next_token_logprob,
                 next_token_text,
+                next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
             )

@@ -448,14 +488,8 @@ class CausalLM(Model):
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

-        # Update attention_mask with padding as we added a new token to input_ids
-        next_batch_attention_mask = torch.cat(
-            [
-                next_batch_attention_mask,
-                next_batch_attention_mask.new_ones(next_batch_size, 1),
-            ],
-            dim=1,
-        )
+        # Update attention_mask as we added a new token to input_ids
+        next_batch_attention_mask[:, -batch.padding_right_offset] = 1

         # Update position_ids
         next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
@@ -472,7 +506,8 @@ class CausalLM(Model):
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
-            max_sequence_length=next_batch_max_sequence_length,
+            max_input_length=next_batch_max_input_length,
+            padding_right_offset=batch.padding_right_offset - 1,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
         return generations, next_batch
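The `padding_right_offset` changes above stop growing the attention mask with a `torch.cat` at every decoding step: the mask is allocated once with room for `max_new_tokens` extra columns, sliced to the used width before each forward pass, and a single reserved column is flipped to 1 after each new token. A minimal sketch of that idea follows; the batch size and lengths are illustrative, not taken from the commit:

import torch

batch_size, max_input_length, max_new_tokens = 2, 5, 3

# Allocate once: prompt width plus every token that may still be generated.
attention_mask = torch.zeros(batch_size, max_input_length + max_new_tokens, dtype=torch.long)
attention_mask[:, :max_input_length] = 1  # prompt positions are visible

padding_right_offset = max_new_tokens
for step in range(max_new_tokens):
    # What the model sees this step: drop the still-unused right padding.
    step_mask = attention_mask[:, : -padding_right_offset]
    # ... forward pass with step_mask would happen here ...

    # The freshly generated token occupies the next reserved column.
    attention_mask[:, -padding_right_offset] = 1
    padding_right_offset -= 1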
@@ -2,7 +2,7 @@ import re
 import torch
 import torch.distributed

-from typing import List, Optional, Type
+from typing import List, Optional, Type, Tuple

 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -18,15 +18,14 @@ from transformers.models.opt.parallel_layers import (
     TensorParallelRowLinear,
 )

-from text_generation.models import CausalLM
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.utils import (
     NextTokenChooser,
     StoppingCriteria,
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )

 HAS_BITS_AND_BYTES = True
@@ -97,24 +96,37 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         input_lengths = []

         # Parse batch
+        max_sequence_length = 0
+        padding_right_offset = 0
         for r in pb.requests:
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
-            stopping_criterias.append(
-                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
+            stopping_criteria = StoppingCriteria.from_pb(
+                r.stopping_parameters, tokenizer
+            )
+            stopping_criterias.append(stopping_criteria)
+            max_sequence_length = max(max_sequence_length, r.input_length)
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
             )

         # Tokenize batch
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)
+        input_ids = tokenized_inputs["input_ids"]
+        # Allocate maximum attention_mask
+        attention_mask = input_ids.new_zeros(
+            (pb.size, max_sequence_length + padding_right_offset)
+        )
+        # Copy tokenizer attention_mask into fully allocated attention_mask
+        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]

         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
@@ -122,8 +134,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
-            input_ids=tokenized_inputs["input_ids"],
-            attention_mask=tokenized_inputs["attention_mask"],
+            input_ids=input_ids,
+            attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
             all_input_ids=all_input_ids,
@@ -131,7 +143,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max(input_lengths),
+            max_sequence_length=max_sequence_length,
+            padding_right_offset=padding_right_offset,
         )


@@ -146,14 +159,25 @@ class Galactica(CausalLM):
             generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
         )

+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        """Overwrite forward to ignore position_ids"""
+
+        # Model Forward
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+        return outputs.logits, outputs.past_key_values
+
+
 class GalacticaSharded(Galactica):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_id.startswith("facebook/galactica"):
-            raise ValueError(f"Model {model_id} is not supported")
-
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
@@ -172,14 +196,8 @@ class GalacticaSharded(Galactica):
         )
         tokenizer.pad_token_id = config.pad_token_id

-        # Only download weights for small models
-        if self.master and model_id == "facebook/galactica-125m":
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")

         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
@@ -329,7 +347,6 @@ class GalacticaSharded(Galactica):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=True,
         )
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed

-from typing import List, Optional, Tuple
+from typing import List, Optional

 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -16,11 +16,10 @@ from transformers.models.gpt_neox.parallel_layers import (
     TensorParallelRowLinear,
 )

-from text_generation.models import CausalLM
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )

 HAS_BITS_AND_BYTES = True
@@ -31,23 +30,7 @@ except Exception as e:
     HAS_BITS_AND_BYTES = False


-class GPTNeox(CausalLM):
+class GPTNeoxSharded(CausalLM):
-    def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
-        """Overwrite forward to ignore position_ids"""
-
-        # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-        return outputs.logits, outputs.past_key_values
-
-
-class GPTNeoxSharded(GPTNeox):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
@@ -69,14 +52,8 @@ class GPTNeoxSharded(GPTNeox):
             model_id, revision=revision, tp_parallel=True
         )

-        # Only master download weights
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")

         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
@@ -231,6 +208,7 @@ class GPTNeoxSharded(GPTNeox):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=True,
         )
43 server/text_generation_server/models/model.py Normal file
@@ -0,0 +1,43 @@
import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase

from text_generation_server.models.types import Batch, GeneratedText

B = TypeVar("B", bound=Batch)


class Model(ABC):
    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
        self.tokenizer = tokenizer
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.device = device

        # see `decode_token` method
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": ["<decode-token>"]}
        )
        self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids(
            "<decode-token>"
        )
        self.special_decode_token_length = len("<decode-token>")

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        raise NotImplementedError

    def decode_token(self, token_id: int) -> str:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
        # append token to special decode token and decode both
        result = self.tokenizer.decode(
            [self.special_decode_token_id, token_id], skip_special_tokens=False
        )
        # slice to remove special decode token
        return result[self.special_decode_token_length :]
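`decode_token` above decodes a freshly sampled token together with a sentinel special token and then slices the sentinel's text away, so tokenizers that would otherwise strip a leading space when a token is decoded on its own can still stream that space. A small sketch of the trick; the tokenizer name is only an example of a SentencePiece-style tokenizer and exact outputs depend on the tokenizer and library version:

from transformers import AutoTokenizer

# Example model; any SentencePiece-style tokenizer illustrates the issue.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

tokenizer.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})
sentinel_id = tokenizer.convert_tokens_to_ids("<decode-token>")
sentinel_len = len("<decode-token>")

token_id = tokenizer(" world", add_special_tokens=False)["input_ids"][0]

# Decoding the token on its own may drop the leading space...
print(repr(tokenizer.decode([token_id], skip_special_tokens=False)))
# ...decoding it after the sentinel and slicing the sentinel text off keeps it.
decoded = tokenizer.decode([sentinel_id, token_id], skip_special_tokens=False)
print(repr(decoded[sentinel_len:]))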
@@ -1,10 +1,10 @@
 import torch
 import torch.distributed

-from typing import Optional, List, Tuple
+from typing import Optional, List
 from transformers import AutoTokenizer, AutoModelForCausalLM

-from text_generation.models import CausalLM
+from text_generation_server.models import CausalLM

 FIM_PREFIX = "<fim-prefix>"
 FIM_MIDDLE = "<fim-middle>"
@@ -5,10 +5,15 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type

-from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    GeneratedText,
+    Batch,
+    Generation,
+    PrefillTokens,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

 tracer = trace.get_tracer(__name__)

@@ -42,9 +47,10 @@ class Seq2SeqLMBatch(Batch):
     size: int
     max_input_length: int
     max_decoder_input_length: int
+    padding_right_offset: int

     def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
@@ -58,36 +64,41 @@ class Seq2SeqLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "Seq2SeqLMBatch":
-        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
+        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        input_lengths = []

         decoder_input_ids = []
         decoder_input_lengths = []

         # Parse batch
+        padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
-            input_lengths.append(r.input_length)
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
-            stopping_criterias.append(
-                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
+            stopping_criteria = StoppingCriteria.from_pb(
+                r.stopping_parameters, tokenizer
+            )
+            stopping_criterias.append(stopping_criteria)
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
             )

         # Tokenize batch
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)

+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

@@ -100,13 +111,14 @@ class Seq2SeqLMBatch(Batch):
             decoder_attention_mask=None,
             encoder_last_hidden_state=None,
             past_key_values=None,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
-            max_input_length=max(input_lengths),
+            max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
+            padding_right_offset=padding_right_offset,
         )

     @classmethod
@@ -115,11 +127,17 @@ class Seq2SeqLMBatch(Batch):
         """Concatenate multiple batches together by padding internal torch tensors"""

         # Used for padding
-        total_batch_size = sum(batch.size for batch in batches)
-        max_input_length = max(batch.max_input_length for batch in batches)
+        total_batch_size = 0
+        max_input_length = 0
+        max_decoder_input_length = 0
+        padding_right_offset = 0
+        for batch in batches:
+            total_batch_size += batch.size
+            max_input_length = max(max_input_length, batch.max_input_length)
             max_decoder_input_length = max(
-                batch.max_decoder_input_length for batch in batches
+                max_decoder_input_length, batch.max_decoder_input_length
             )
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

         # Batch attributes
         requests = []
@@ -129,7 +147,6 @@ class Seq2SeqLMBatch(Batch):
         stopping_criterias = []

         # Batch tensors
-        input_ids = None
         attention_mask = None
         decoder_input_ids = None
         decoder_attention_mask = None
@@ -155,16 +172,6 @@ class Seq2SeqLMBatch(Batch):
             if batch.encoder_last_hidden_state is None:
                 raise ValueError("Batch encoder_last_hidden_state cannot be None")

-            # Create padded tensor
-            if input_ids is None:
-                input_ids = batch.input_ids.new_zeros(
-                    (total_batch_size, max_input_length),
-                )
-            # Copy to correct indices
-            input_ids[
-                start_index:end_index, -batch.max_input_length :
-            ] = batch.input_ids[:, -batch.max_input_length :]
-
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
@@ -189,19 +196,30 @@ class Seq2SeqLMBatch(Batch):
             if decoder_attention_mask is None:
                 # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
                 decoder_attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_decoder_input_length),
+                    (total_batch_size, max_decoder_input_length + padding_right_offset),
                 )
             # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
             # this batch. All generations are of length `batch.max_decoder_input_length`.
+            left_offset = max_decoder_input_length - batch.max_decoder_input_length
             if batch.decoder_attention_mask is None:
                 decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length :
+                    start_index:end_index,
+                    left_offset:-padding_right_offset,
                 ] = 1
             # If it exists, we need to index
             else:
+                batch_left_offset = (
+                    batch.decoder_attention_mask.shape[1]
+                    - batch.max_decoder_input_length
+                    - batch.padding_right_offset
+                )
                 decoder_attention_mask[
-                    start_index:end_index, -batch.max_decoder_input_length :
-                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
+                    start_index:end_index,
+                    left_offset:-padding_right_offset,
+                ] = batch.decoder_attention_mask[
+                    :,
+                    batch_left_offset : -batch.padding_right_offset,
+                ]

             # Create padded tensor
             if encoder_last_hidden_state is None:
@@ -273,7 +291,7 @@ class Seq2SeqLMBatch(Batch):
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
-            input_ids=input_ids,
+            input_ids=None,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
@@ -286,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
             size=total_batch_size,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
+            padding_right_offset=padding_right_offset,
         )

     def __len__(self):
@@ -326,7 +345,9 @@ class Seq2SeqLM(Model):
         return Seq2SeqLMBatch

     def decode(self, decoder_ids: List[int]) -> str:
-        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
+        return self.tokenizer.decode(
+            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )

     def forward(
         self,
@@ -342,14 +363,6 @@ class Seq2SeqLM(Model):
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
-
-        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
-        # internally...
-        if encoder_last_hidden_state is not None:
-            encoder_last_hidden_state = [encoder_last_hidden_state]
-
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -369,12 +382,34 @@ class Seq2SeqLM(Model):
     def generate_token(
         self, batch: Seq2SeqLMBatch
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
+        if batch.decoder_attention_mask is not None:
+            # slice to the correct shape
+            decoder_attention_mask = batch.decoder_attention_mask[
+                :, : -batch.padding_right_offset
+            ]
+        else:
+            decoder_attention_mask = None
+
+        # check if first forward or not
+        if batch.past_key_values is not None:
+            # Only take the last token
+            decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
+        else:
+            decoder_input_ids = batch.decoder_input_ids
+
+        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
+        # internally...
+        if batch.encoder_last_hidden_state is not None:
+            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
+        else:
+            encoder_last_hidden_state = batch.encoder_last_hidden_state
+
         logits, encoder_last_hidden_state, past = self.forward(
             batch.input_ids,
             batch.attention_mask,
-            batch.decoder_input_ids,
-            batch.decoder_attention_mask,
-            batch.encoder_last_hidden_state,
+            decoder_input_ids,
+            decoder_attention_mask,
+            encoder_last_hidden_state,
             batch.past_key_values,
         )

@@ -402,7 +437,6 @@ class Seq2SeqLM(Model):
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
-            batch.input_ids,
             batch.decoder_input_ids,
         )

@@ -414,7 +448,6 @@ class Seq2SeqLM(Model):
             logits,
             next_token_chooser,
             stopping_criteria,
-            input_tokens,
             decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
@@ -429,10 +462,8 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(
+            next_token_text = self.decode_token(
                 next_token_id_squeezed,
-                clean_up_tokenization_spaces=False,
-                skip_special_tokens=False,
             )

             # Evaluate stopping criteria
@@ -469,14 +500,10 @@ class Seq2SeqLM(Model):

             # Prefill
             if stopping_criteria.current_tokens == 1:
-                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
                 prefill_tokens = PrefillTokens(
-                    prefill_token_ids, [float("nan")], prefill_texts
+                    [self.tokenizer.bos_token_id],
+                    [float("nan")],
+                    [self.tokenizer.bos_token],
                 )
             else:
                 prefill_tokens = None
@@ -487,6 +514,7 @@ class Seq2SeqLM(Model):
                 next_token_id_squeezed,
                 next_token_logprob,
                 next_token_text,
+                next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
             )

@@ -500,10 +528,8 @@ class Seq2SeqLM(Model):
         # If we finished at least one generation, we need to evict the indices of the generations that finished
         # from the values of the next batch
         if len(next_batch_keep_indices) != len(batch):
-            # Apply indices to attention mask, past key values and other items that need to be cached
-            next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
+            # Apply indices to decoder_attention mask, past key values and other items that need to be cached
             next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]

             if batch.decoder_attention_mask is not None:
                 next_batch_decoder_attention_mask = batch.decoder_attention_mask[
                     next_batch_keep_indices
@@ -526,7 +552,6 @@ class Seq2SeqLM(Model):
                 batch.stopping_criterias[i] for i in next_batch_keep_indices
             ]
         else:
-            next_batch_input_ids = batch.input_ids
             next_batch_attention_mask = batch.attention_mask
             next_batch_decoder_attention_mask = batch.decoder_attention_mask
             next_batch_encoder_last_hidden_state = encoder_last_hidden_state
@@ -536,20 +561,14 @@ class Seq2SeqLM(Model):
             next_batch_next_token_choosers = batch.next_token_choosers
             next_batch_stopping_criterias = batch.stopping_criterias

-        # Update decoder_attention_mask with padding as we added a new token to input_ids
+        # Update decoder_attention_mask as we added a new token to input_ids
         if next_batch_decoder_attention_mask is not None:
-            next_batch_decoder_attention_mask = torch.cat(
-                [
-                    next_batch_decoder_attention_mask,
-                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
-                ],
-                dim=1,
-            )
+            next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1

         next_batch = Seq2SeqLMBatch(
             batch_id=batch.batch_id,
             requests=next_batch_requests,
-            input_ids=next_batch_input_ids,
+            input_ids=None,
             attention_mask=next_batch_attention_mask,
             decoder_input_ids=next_batch_decoder_input_ids,
             decoder_attention_mask=next_batch_decoder_attention_mask,
@@ -562,5 +581,6 @@ class Seq2SeqLM(Model):
             size=next_batch_size,
             max_input_length=next_batch_max_input_length,
             max_decoder_input_length=next_batch_max_decoder_input_length,
|
padding_right_offset=batch.padding_right_offset - 1,
|
||||||
)
|
)
|
||||||
return generations, next_batch
|
return generations, next_batch
|
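Note (not part of the diff): the hunks above replace the per-step torch.cat on the decoder attention mask with a mask that is pre-allocated to its final width and unmasked one column per step via padding_right_offset. A minimal standalone sketch of that idea, with made-up sizes and variable names:

import torch

# Illustrative sizes only; `padding_right_offset` plays the role of batch.padding_right_offset.
batch_size, current_length, padding_right_offset = 2, 3, 4

# Old approach: grow the mask with a torch.cat on every decoding step.
mask = torch.ones(batch_size, current_length)
mask = torch.cat([mask, mask.new_ones(batch_size, 1)], dim=1)

# New approach: allocate the final width once, then flip one slot per step,
# mirroring `next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1`.
mask_prealloc = torch.zeros(batch_size, current_length + padding_right_offset)
mask_prealloc[:, :current_length] = 1
for offset in range(padding_right_offset, 0, -1):
    mask_prealloc[:, -offset] = 1

assert mask.shape == (batch_size, current_length + 1)
assert mask_prealloc.sum().item() == batch_size * (current_length + padding_right_offset)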
@@ -16,11 +16,10 @@ from transformers.models.t5.parallel_layers import (
     TensorParallelRowLinear,
 )

-from text_generation.models import Seq2SeqLM
+from text_generation_server.models import Seq2SeqLM
-from text_generation.utils import (
+from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )

 HAS_BITS_AND_BYTES = True
@@ -53,14 +52,8 @@ class T5Sharded(Seq2SeqLM):
         )
         tokenizer.bos_token_id = config.decoder_start_token_id

-        # Only master download weights
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")

         with init_empty_weights():
             model = AutoModelForSeq2SeqLM.from_config(config)
@@ -228,14 +221,6 @@ class T5Sharded(Seq2SeqLM):
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        if past_key_values is not None:
-            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)

-        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
-        # internally...
-        if encoder_last_hidden_state is not None:
-            encoder_last_hidden_state = [encoder_last_hidden_state]

         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -6,8 +6,8 @@ from typing import List, Optional

 from transformers import PreTrainedTokenizerBase

-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
+from text_generation_server.pb.generate_pb2 import FinishReason


 class Batch(ABC):
@@ -73,6 +73,7 @@ class Generation:
     token_id: int
     token_logprob: float
     token_text: str
+    token_is_special: bool
     generated_text: Optional[GeneratedText]

     def to_pb(self) -> generate_pb2.Generation:
@@ -84,6 +85,7 @@ class Generation:
             token_id=self.token_id,
             token_logprob=self.token_logprob,
             token_text=self.token_text,
+            token_is_special=self.token_is_special,
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
@@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
 from pathlib import Path
 from typing import List, Optional

-from text_generation.cache import Cache
+from text_generation_server.cache import Cache
-from text_generation.interceptor import ExceptionInterceptor
+from text_generation_server.interceptor import ExceptionInterceptor
-from text_generation.models import Model, get_model
+from text_generation_server.models import Model, get_model
-from text_generation.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
-from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
server/text_generation_server/utils/__init__.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.hub import (
    weight_files,
    weight_hub_files,
    download_weights,
    EntryNotFoundError,
    LocalEntryNotFoundError,
    RevisionNotFoundError,
)
from text_generation_server.utils.tokens import (
    Greedy,
    NextTokenChooser,
    Sampling,
    StoppingCriteria,
    StopSequenceCriteria,
    FinishReason,
)

__all__ = [
    "convert_file",
    "convert_files",
    "initialize_torch_distributed",
    "weight_files",
    "weight_hub_files",
    "download_weights",
    "EntryNotFoundError",
    "LocalEntryNotFoundError",
    "RevisionNotFoundError",
    "Greedy",
    "NextTokenChooser",
    "Sampling",
    "StoppingCriteria",
    "StopSequenceCriteria",
    "FinishReason",
]
server/text_generation_server/utils/convert.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import concurrent
import time
import torch

from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from datetime import timedelta
from loguru import logger
from pathlib import Path
from safetensors.torch import load_file, save_file
from typing import Dict, List


def check_file_size(source_file: Path, target_file: Path):
    """
    Check that two files are close in size
    """
    source_file_size = source_file.stat().st_size
    target_file_size = target_file.stat().st_size

    if (source_file_size - target_file_size) / source_file_size > 0.01:
        raise RuntimeError(
            f"""The file size different is more than 1%:
         - {source_file}: {source_file_size}
         - {target_file}: {target_file_size}
         """
        )


def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
    """
    For a Dict of tensors, check if two or more tensors point to the same underlying memory and
    remove them
    """
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)

    # Iterate over all found memory addresses
    for ptr, names in ptrs.items():
        if len(names) > 1:
            # Multiple tensors are point to the same memory
            # Only keep the first tensor
            for name in names[1:]:
                tensors.pop(name)


def convert_file(pt_file: Path, st_file: Path):
    """
    Convert a pytorch file to a safetensors file
    """
    logger.info(f"Convert {pt_file} to {st_file}.")

    pt_state = torch.load(pt_file, map_location="cpu")
    if "state_dict" in pt_state:
        pt_state = pt_state["state_dict"]

    remove_shared_pointers(pt_state)

    # Tensors need to be contiguous
    pt_state = {k: v.contiguous() for k, v in pt_state.items()}

    st_file.parent.mkdir(parents=True, exist_ok=True)
    save_file(pt_state, str(st_file), metadata={"format": "pt"})

    # Check that both files are close in size
    check_file_size(pt_file, st_file)

    # Load safetensors state
    st_state = load_file(str(st_file))
    for k in st_state:
        pt_tensor = pt_state[k]
        st_tensor = st_state[k]
        if not torch.equal(pt_tensor, st_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


def convert_files(pt_files: List[Path], st_files: List[Path]):
    assert len(pt_files) == len(st_files)

    executor = ThreadPoolExecutor(max_workers=5)
    futures = [
        executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
        for pt_file, st_file in zip(pt_files, st_files)
    ]

    # We do this instead of using tqdm because we want to parse the logs with the launcher
    start_time = time.time()
    for i, future in enumerate(concurrent.futures.as_completed(futures)):
        elapsed = timedelta(seconds=int(time.time() - start_time))
        remaining = len(futures) - (i + 1)
        eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0

        logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")
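Note (not part of the diff): a minimal usage sketch for the conversion helpers above; the checkpoint paths are hypothetical.

from pathlib import Path

from text_generation_server.utils.convert import convert_files

# Hypothetical checkpoint paths, for illustration only.
pt_files = [Path("/data/my-model/pytorch_model.bin")]
st_files = [Path("/data/my-model/model.safetensors")]

# Converts each file in a thread pool, checks the resulting file sizes,
# and verifies that every tensor round-trips through safetensors.
convert_files(pt_files, st_files)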
server/text_generation_server/utils/dist.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import os
import torch

from datetime import timedelta


def initialize_torch_distributed():
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        # Set the device id.
        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
        device = rank % torch.cuda.device_count()
        torch.cuda.set_device(device)
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=60)
    else:
        backend = "gloo"
        options = None

    # Call the init process.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
        pg_options=options,
    )

    return torch.distributed.group.WORLD, rank, world_size
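Note (not part of the diff): initialize_torch_distributed reads the usual torch.distributed environment variables; a single-process sketch, assuming no launcher has set them yet.

import os

# Assumed environment for a one-process group; a real launcher sets these variables.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

from text_generation_server.utils.dist import initialize_torch_distributed

process_group, rank, world_size = initialize_torch_distributed()
print(rank, world_size)  # 0 1 in this single-process setup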
server/text_generation_server/utils/hub.py (new file, 165 lines)
@@ -0,0 +1,165 @@
import time
import os

from datetime import timedelta
from loguru import logger
from pathlib import Path
from typing import Optional, List

from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import (
    LocalEntryNotFoundError,
    EntryNotFoundError,
    RevisionNotFoundError,  # Import here to ease try/except in other part of the lib
)

WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)


def weight_hub_files(
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
    """Get the weights filenames on the hub"""
    api = HfApi()
    info = api.model_info(model_id, revision=revision)
    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]

    if not filenames:
        raise EntryNotFoundError(
            f"No {extension} weights found for model {model_id} and revision {revision}.",
            None,
        )

    return filenames


def try_to_load_from_cache(
    model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
    """Try to load a file from the Hugging Face cache"""
    if revision is None:
        revision = "main"

    object_id = model_id.replace("/", "--")
    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"

    if not repo_cache.is_dir():
        # No cache for this model
        return None

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"
    no_exist_dir = repo_cache / ".no_exist"

    # Resolve refs (for instance to convert main to the associated commit sha)
    if refs_dir.is_dir():
        revision_file = refs_dir / revision
        if revision_file.exists():
            with revision_file.open() as f:
                revision = f.read()

    # Check if file is cached as "no_exist"
    if (no_exist_dir / revision / filename).is_file():
        return None

    # Check if revision folder exists
    if not snapshots_dir.exists():
        return None
    cached_shas = os.listdir(snapshots_dir)
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        return None

    # Check if file exists in cache
    cached_file = snapshots_dir / revision / filename
    return cached_file if cached_file.is_file() else None


def weight_files(
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[Path]:
    """Get the local files"""
    # Local model
    if Path(model_id).exists() and Path(model_id).is_dir():
        return list(Path(model_id).glob(f"*{extension}"))

    try:
        filenames = weight_hub_files(model_id, revision, extension)
    except EntryNotFoundError as e:
        if extension != ".safetensors":
            raise e
        # Try to see if there are pytorch weights
        pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
        # Change pytorch extension to safetensors extension
        # It is possible that we have safetensors weights locally even though they are not on the
        # hub if we converted weights locally without pushing them
        filenames = [
            f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
        ]

    if WEIGHTS_CACHE_OVERRIDE is not None:
        files = []
        for filename in filenames:
            p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
            if not p.exists():
                raise LocalEntryNotFoundError(
                    f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
                )
            files.append(p)
        return files

    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
            model_id, revision=revision, filename=filename
        )
        if cache_file is None:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_id} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_id}` first."
            )
        files.append(cache_file)

    return files


def download_weights(
    filenames: List[str], model_id: str, revision: Optional[str] = None
) -> List[Path]:
    """Download the safetensors files from the hub"""

    def download_file(filename):
        local_file = try_to_load_from_cache(model_id, revision, filename)
        if local_file is not None:
            logger.info(f"File {filename} already present in cache.")
            return Path(local_file)

        logger.info(f"Download file: {filename}")
        start_time = time.time()
        local_file = hf_hub_download(
            filename=filename,
            repo_id=model_id,
            revision=revision,
            local_files_only=False,
        )
        logger.info(
            f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
        )
        return Path(local_file)

    # We do this instead of using tqdm because we want to parse the logs with the launcher
    start_time = time.time()
    files = []
    for i, filename in enumerate(filenames):
        file = download_file(filename)

        elapsed = timedelta(seconds=int(time.time() - start_time))
        remaining = len(filenames) - (i + 1)
        eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0

        logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
        files.append(file)

    return files
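Note (not part of the diff): a sketch of how the hub helpers above are meant to chain together; the model id is only an example.

from text_generation_server.utils.hub import (
    weight_hub_files,
    download_weights,
    weight_files,
)

model_id = "bigscience/bloom-560m"  # example model id

# 1. List the .safetensors filenames published on the Hub for this model.
filenames = weight_hub_files(model_id, extension=".safetensors")

# 2. Download them into the local Hugging Face cache (skips files already cached).
download_weights(filenames, model_id)

# 3. Resolve the cached local paths; raises LocalEntryNotFoundError if anything is missing.
local_paths = weight_files(model_id, extension=".safetensors")
print(local_paths)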
server/text_generation_server/utils/tokens.py (new file, 160 lines)
@@ -0,0 +1,160 @@
import re
import torch

from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
    RepetitionPenaltyLogitsProcessor,
    PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional

from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor


class Sampling:
    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits)
        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
        return next_tokens


class Greedy:
    def __call__(self, logits):
        return logits.argmax()


class NextTokenChooser:
    def __init__(
        self,
        watermark=False,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        sampling = do_sample

        if watermark:
            warpers.append(WatermarkLogitsProcessor(device=device))
        if repetition_penalty is not None and repetition_penalty != 1.0:
            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            warpers.append(TemperatureLogitsWarper(temperature))
            sampling = True
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True
        if typical_p is not None and typical_p < 1.0:
            warpers.append(TypicalLogitsWarper(mass=typical_p))
            sampling = True

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        # Warp logits
        if scores.shape[0] > 1:
            # only warp the last token logits
            scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
        else:
            scores = self.warpers(input_ids, scores)

        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)

        # Choose tokens
        next_id = self.choice(scores[-1])

        return next_id.view(1, 1), logprobs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        device: torch.device,
    ) -> "NextTokenChooser":
        return NextTokenChooser(
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
        )


class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        self.regex = re.compile(f".*{stop_sequence}$")

    def __call__(self, output: str) -> bool:
        if self.regex.findall(output):
            return True
        return False


class StoppingCriteria:
    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if last_token == self.eos_token_id:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        return StoppingCriteria(
            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
        )
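Note (not part of the diff): a small sketch driving NextTokenChooser and StoppingCriteria from the new module with dummy tensors; all values are arbitrary.

import torch

from text_generation_server.utils.tokens import (
    NextTokenChooser,
    StoppingCriteria,
    StopSequenceCriteria,
)

input_ids = torch.tensor([[1, 2, 3]])  # dummy prompt ids
scores = torch.randn(1, 32)            # dummy logits for the last position, vocab size 32

# Temperature + top-k implies sampling; the seed makes it reproducible.
chooser = NextTokenChooser(temperature=0.7, top_k=10, seed=42)
next_id, logprobs = chooser(input_ids, scores)  # next_id has shape (1, 1)

# Stop after 5 new tokens, on eos id 0, or when the output ends with "###".
criteria = StoppingCriteria(
    eos_token_id=0,
    stop_sequence_criterias=[StopSequenceCriteria("###")],
    max_new_tokens=5,
)
stop, reason = criteria(next_id.item(), "some text")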
server/text_generation_server/utils/watermark.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
from transformers import LogitsProcessor

GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
DELTA = os.getenv("WATERMARK_DELTA", 2.0)


class WatermarkLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        gamma: float = GAMMA,
        delta: float = DELTA,
        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
        device: str = "cpu",
    ):
        # watermarking parameters
        self.gamma = gamma
        self.delta = delta
        self.rng = torch.Generator(device=device)
        self.hash_key = hash_key

    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
        assert (
            input_ids.shape[-1] >= 1
        ), "requires at least a 1 token prefix sequence to seed rng"
        prev_token = input_ids[-1].item()
        self.rng.manual_seed(self.hash_key * prev_token)

    def _get_greenlist_ids(
        self, input_ids: torch.LongTensor, max_value: int
    ) -> list[int]:
        # seed the rng using the previous tokens/prefix
        self._seed_rng(input_ids)

        greenlist_size = int(max_value * self.gamma)
        vocab_permutation = torch.randperm(
            max_value, device=input_ids.device, generator=self.rng
        )
        greenlist_ids = vocab_permutation[:greenlist_size]
        return greenlist_ids

    @staticmethod
    def _calc_greenlist_mask(
        scores: torch.FloatTensor, greenlist_token_ids
    ) -> torch.BoolTensor:
        green_tokens_mask = torch.zeros_like(scores)
        green_tokens_mask[-1, greenlist_token_ids] = 1
        final_mask = green_tokens_mask.bool()
        return final_mask

    @staticmethod
    def _bias_greenlist_logits(
        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
    ) -> torch.Tensor:
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        assert len(input_ids) == 1
        greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1])
        green_tokens_mask = self._calc_greenlist_mask(
            scores=scores, greenlist_token_ids=greenlist_ids
        )

        scores = self._bias_greenlist_logits(
            scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
        )
        return scores
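Note (not part of the diff): applying the watermark processor above to a single decoding step; shapes follow the assertions in the class (a batch of one, 2D scores), and the tensors are random.

import torch

from text_generation_server.utils.watermark import WatermarkLogitsProcessor

input_ids = torch.tensor([[5, 9, 11]])  # batch of one prefix, as the assert requires
scores = torch.randn(1, 50)             # logits for the next token, vocab size 50

processor = WatermarkLogitsProcessor(gamma=0.5, delta=2.0, device="cpu")
original = scores.clone()
biased = processor(input_ids, scores)  # biases roughly gamma * vocab_size "green" logits by +delta

print((biased != original).sum())  # number of logits that received the bias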
supported_models.json (new file, 9 lines)
@@ -0,0 +1,9 @@
[
  "bigscience/bloom",
  "bigscience/bloomz",
  "EleutherAI/gpt-neox-20b",
  "google/flan-ul2",
  "google/flan-t5-xxl",
  "OpenAssistant/oasst-sft-1-pythia-12b",
  "olivierdehaene/optimized-santacoder"
]