Merge branch 'main' into lewtun-patch-1

OlivierDehaene 2023-03-23 18:03:33 +01:00 committed by GitHub
commit c07acd4fea
76 changed files with 5799 additions and 1414 deletions


@ -8,6 +8,15 @@ on:
tags:
- 'v*'
pull_request:
paths:
- ".github/workflows/build.yaml"
- "server/**"
- "proto/**"
- "router/**"
- "launcher/**"
- "Cargo.lock"
- "rust-toolchain.toml"
- "Dockerfile"
branches:
- 'main'
@ -15,6 +24,10 @@ jobs:
build-and-push-image:
runs-on: ubuntu-latest
steps:
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v2.0.0
with:
install: true
- name: Tailscale
uses: tailscale/github-action@v1
with:
@ -65,5 +78,5 @@ jobs:
platforms: 'linux/amd64'
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=ghcr.io/huggingface/text-generation-inference:latest
cache-to: type=inline
cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max


@ -3,14 +3,23 @@ name: Server Tests
on:
pull_request:
paths:
- ".github/workflows/tests.yaml"
- "server/**"
- "proto/**"
- "router/**"
- "launcher/**"
- "Cargo.lock"
- "rust-toolchain.toml"
jobs:
run_tests:
runs-on: ubuntu-20.04
env:
SCCACHE_GHA_ENABLED: "on"
RUSTC_WRAPPER: /usr/local/bin/sccache
SCCACHE: 0.3.3
steps:
- uses: actions/checkout@v2
- name: Set up Python
@ -25,19 +34,38 @@ jobs:
components: rustfmt, clippy
- name: Install Protoc
uses: arduino/setup-protoc@v1
- name: Loading cache.
uses: actions/cache@v2
id: model_cache
- name: Install sccache
run: |
curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
chmod +x /usr/local/bin/sccache
- name: configure sccache
uses: actions/github-script@v6
with:
path: ~/.cache/huggingface/
key: models
script: |
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}');
core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-');
- name: cargo registry cache
uses: actions/cache@v3
with:
key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }}
restore-keys: |
cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-
cargo-${{ runner.os }}-
path: |
~/.cargo/registry
~/.cargo/git
- name: Install
run: |
make install
- name: Run server tests
run: |
pip install pytest
pytest -sv server/tests
HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
- name: Run Rust tests
run: |
cargo test
- name: sccache stats
run: |
/usr/local/bin/sccache --show-stats

Cargo.lock (generated): 317 changed lines

@ -8,6 +8,17 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "ahash"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [
"getrandom",
"once_cell",
"version_check",
]
[[package]]
name = "aho-corasick"
version = "0.7.20"
@ -34,19 +45,20 @@ checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800"
[[package]]
name = "async-stream"
version = "0.3.3"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e"
checksum = "ad445822218ce64be7a341abfb0b1ea43b5c23aa83902542a4542e78309d8e5e"
dependencies = [
"async-stream-impl",
"futures-core",
"pin-project-lite",
]
[[package]]
name = "async-stream-impl"
version = "0.3.3"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27"
checksum = "e4655ae1a7b0cdf149156f780c5bf3f1352bc53cbd9e0a361a7ef7b22947e965"
dependencies = [
"proc-macro2",
"quote",
@ -83,9 +95,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.6.4"
version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc"
checksum = "6137c6234afb339e75e764c866e3594900f0211e1315d33779f269bbe2ec6967"
dependencies = [
"async-trait",
"axum-core",
@ -109,7 +121,7 @@ dependencies = [
"sync_wrapper",
"tokio",
"tower",
"tower-http",
"tower-http 0.4.0",
"tower-layer",
"tower-service",
]
@ -142,7 +154,7 @@ dependencies = [
"http",
"opentelemetry",
"tower",
"tower-http",
"tower-http 0.3.5",
"tracing",
"tracing-opentelemetry",
]
@ -265,9 +277,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.1.4"
version = "4.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f13b9c79b5d1dd500d20ef541215a6423c75829ef43117e1b4d17fd8af0b5d76"
checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5"
dependencies = [
"bitflags",
"clap_derive",
@ -280,9 +292,9 @@ dependencies = [
[[package]]
name = "clap_derive"
version = "4.1.0"
version = "4.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "684a277d672e91966334af371f1a7b5833f9aa00b07c84e92fbce95e00208ce8"
checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0"
dependencies = [
"heck",
"proc-macro-error",
@ -293,9 +305,9 @@ dependencies = [
[[package]]
name = "clap_lex"
version = "0.3.1"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "783fe232adfca04f90f56201b26d79682d4cd2625e0bc7290b95123afe558ade"
checksum = "350b9cf31731f9957399229e9b2adc51eeabdfbe9d71d9a0552275fd12710d09"
dependencies = [
"os_str_bytes",
]
@ -349,9 +361,9 @@ dependencies = [
[[package]]
name = "crossbeam-channel"
version = "0.5.6"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2dd04ddaf88237dc3b8d8f9a3c1004b506b54b3313403944054d23c0870c521"
checksum = "cf2b3e8478797446514c91ef04bafcb59faba183e621ad488df88983cc14128c"
dependencies = [
"cfg-if",
"crossbeam-utils",
@ -359,9 +371,9 @@ dependencies = [
[[package]]
name = "crossbeam-deque"
version = "0.8.2"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc"
checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
dependencies = [
"cfg-if",
"crossbeam-epoch",
@ -370,9 +382,9 @@ dependencies = [
[[package]]
name = "crossbeam-epoch"
version = "0.9.13"
version = "0.9.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01a9af1f4c2ef74bb8aa1f7e19706bc72d03598c8a570bb5de72243c7a9d9d5a"
checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695"
dependencies = [
"autocfg",
"cfg-if",
@ -383,9 +395,9 @@ dependencies = [
[[package]]
name = "crossbeam-utils"
version = "0.8.14"
version = "0.8.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f"
checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b"
dependencies = [
"cfg-if",
]
@ -575,9 +587,9 @@ dependencies = [
[[package]]
name = "fastrand"
version = "1.8.0"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a407cfaa3385c4ae6b23e84623d48c2798d06e3e6a1878f7f59f17b3f86499"
checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be"
dependencies = [
"instant",
]
@ -774,7 +786,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]]
name = "grpc-metadata"
version = "0.1.0"
version = "0.4.0"
dependencies = [
"opentelemetry",
"tonic",
@ -784,9 +796,9 @@ dependencies = [
[[package]]
name = "h2"
version = "0.3.15"
version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4"
checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d"
dependencies = [
"bytes",
"fnv",
@ -806,6 +818,9 @@ name = "hashbrown"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
dependencies = [
"ahash",
]
[[package]]
name = "heck"
@ -839,9 +854,9 @@ checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
[[package]]
name = "http"
version = "0.2.8"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399"
checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482"
dependencies = [
"bytes",
"fnv",
@ -1004,9 +1019,9 @@ checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146"
[[package]]
name = "is-terminal"
version = "0.4.3"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22e18b0a45d56fe973d6db23972bf5bc46f988a4a2385deac9cc29572f09daef"
checksum = "21b6b32576413a8e69b90e952e4a026476040d81017b80445deda5f2d3921857"
dependencies = [
"hermit-abi 0.3.1",
"io-lifetimes",
@ -1093,6 +1108,15 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "mach"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa"
dependencies = [
"libc",
]
[[package]]
name = "macro_rules_attribute"
version = "0.1.3"
@ -1132,13 +1156,71 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
[[package]]
name = "memoffset"
version = "0.7.1"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1"
dependencies = [
"autocfg",
]
[[package]]
name = "metrics"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b9b8653cec6897f73b519a43fba5ee3d50f62fe9af80b428accdcc093b4a849"
dependencies = [
"ahash",
"metrics-macros",
"portable-atomic",
]
[[package]]
name = "metrics-exporter-prometheus"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8603921e1f54ef386189335f288441af761e0fc61bcb552168d9cedfe63ebc70"
dependencies = [
"hyper",
"indexmap",
"ipnet",
"metrics",
"metrics-util",
"parking_lot",
"portable-atomic",
"quanta",
"thiserror",
"tokio",
"tracing",
]
[[package]]
name = "metrics-macros"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "731f8ecebd9f3a4aa847dfe75455e4757a45da40a7793d2f0b1f9b6ed18b23f3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "metrics-util"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7d24dc2dbae22bff6f1f9326ffce828c9f07ef9cc1e8002e5279f845432a30a"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
"hashbrown",
"metrics",
"num_cpus",
"parking_lot",
"portable-atomic",
"quanta",
"sketches-ddsketch",
]
[[package]]
name = "mime"
version = "0.3.16"
@ -1172,14 +1254,14 @@ dependencies = [
[[package]]
name = "mio"
version = "0.8.5"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de"
checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9"
dependencies = [
"libc",
"log",
"wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.42.0",
"windows-sys 0.45.0",
]
[[package]]
@ -1268,9 +1350,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "once_cell"
version = "1.17.0"
version = "1.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66"
checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
[[package]]
name = "onig"
@ -1514,6 +1596,12 @@ version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
[[package]]
name = "portable-atomic"
version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
@ -1565,9 +1653,9 @@ dependencies = [
[[package]]
name = "prost"
version = "0.11.6"
version = "0.11.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21dc42e00223fc37204bd4aa177e69420c604ca4a183209a8f9de30c6d934698"
checksum = "e48e50df39172a3e7eb17e14642445da64996989bc212b583015435d39a58537"
dependencies = [
"bytes",
"prost-derive",
@ -1575,9 +1663,9 @@ dependencies = [
[[package]]
name = "prost-build"
version = "0.11.6"
version = "0.11.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3f8ad728fb08fe212df3c05169e940fbb6d9d16a877ddde14644a983ba2012e"
checksum = "2c828f93f5ca4826f97fedcbd3f9a536c16b12cff3dbbb4a007f932bbad95b12"
dependencies = [
"bytes",
"heck",
@ -1597,9 +1685,9 @@ dependencies = [
[[package]]
name = "prost-derive"
version = "0.11.6"
version = "0.11.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bda8c0881ea9f722eb9629376db3d0b903b462477c1aafcb0566610ac28ac5d"
checksum = "4ea9b0f8cbe5e15a8a042d030bd96668db28ecb567ec37d691971ff5731d2b1b"
dependencies = [
"anyhow",
"itertools 0.10.5",
@ -1610,14 +1698,29 @@ dependencies = [
[[package]]
name = "prost-types"
version = "0.11.6"
version = "0.11.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e0526209433e96d83d750dd81a99118edbc55739e7e61a46764fd2ad537788"
checksum = "379119666929a1afd7a043aa6cf96fa67a6dce9af60c88095a4686dbce4c9c88"
dependencies = [
"bytes",
"prost",
]
[[package]]
name = "quanta"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7e31331286705f455e56cca62e0e717158474ff02b7936c1fa596d983f4ae27"
dependencies = [
"crossbeam-utils",
"libc",
"mach",
"once_cell",
"raw-cpuid",
"wasi 0.10.2+wasi-snapshot-preview1",
"web-sys",
"winapi",
]
[[package]]
name = "quote"
version = "1.0.23"
@ -1657,6 +1760,15 @@ dependencies = [
"getrandom",
]
[[package]]
name = "raw-cpuid"
version = "10.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332"
dependencies = [
"bitflags",
]
[[package]]
name = "rayon"
version = "1.6.1"
@ -1736,15 +1848,6 @@ version = "0.6.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
[[package]]
name = "remove_dir_all"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
dependencies = [
"winapi",
]
[[package]]
name = "reqwest"
version = "0.11.14"
@ -1973,18 +2076,24 @@ dependencies = [
[[package]]
name = "signal-hook-registry"
version = "1.4.0"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0"
checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
dependencies = [
"libc",
]
[[package]]
name = "slab"
version = "0.4.7"
name = "sketches-ddsketch"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef"
checksum = "ceb945e54128e09c43d8e4f1277851bd5044c6fc540bbaa2ad888f60b3da9ae7"
[[package]]
name = "slab"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d"
dependencies = [
"autocfg",
]
@ -1997,9 +2106,9 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
[[package]]
name = "socket2"
version = "0.4.7"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd"
checksum = "95a21dcece9b5991cfd1ece74654c8e3d0d5aab499d359b0395e38229c0bb5a3"
dependencies = [
"libc",
"winapi",
@ -2053,9 +2162,9 @@ dependencies = [
[[package]]
name = "syn"
version = "1.0.107"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
@ -2081,16 +2190,15 @@ dependencies = [
[[package]]
name = "tempfile"
version = "3.3.0"
version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4"
checksum = "af18f7ae1acd354b992402e9ec5864359d693cd8a79dcbef59f76891701c1e95"
dependencies = [
"cfg-if",
"fastrand",
"libc",
"redox_syscall",
"remove_dir_all",
"winapi",
"rustix",
"windows-sys 0.42.0",
]
[[package]]
@ -2104,7 +2212,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "0.2.1"
version = "0.4.0"
dependencies = [
"futures",
"grpc-metadata",
@ -2121,9 +2229,9 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "0.2.1"
version = "0.4.0"
dependencies = [
"clap 4.1.4",
"clap 4.1.8",
"ctrlc",
"float_eq",
"reqwest",
@ -2136,18 +2244,21 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "0.2.1"
version = "0.4.0"
dependencies = [
"async-stream",
"axum",
"axum-tracing-opentelemetry",
"clap 4.1.4",
"clap 4.1.8",
"futures",
"metrics",
"metrics-exporter-prometheus",
"nohash-hasher",
"opentelemetry",
"opentelemetry-otlp",
"parking_lot",
"rand",
"reqwest",
"serde",
"serde_json",
"text-generation-client",
@ -2155,6 +2266,7 @@ dependencies = [
"tokenizers",
"tokio",
"tokio-stream",
"tower-http 0.3.5",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
@ -2193,9 +2305,9 @@ dependencies = [
[[package]]
name = "thread_local"
version = "1.1.6"
version = "1.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f297120ff9d4efe680df143d5631bba9c75fa371992b7fcb33eb3453cb0a07"
checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152"
dependencies = [
"cfg-if",
"once_cell",
@ -2203,12 +2315,11 @@ dependencies = [
[[package]]
name = "time"
version = "0.1.45"
version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a"
checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438"
dependencies = [
"libc",
"wasi 0.10.0+wasi-snapshot-preview1",
"winapi",
]
@ -2264,9 +2375,9 @@ dependencies = [
[[package]]
name = "tokio"
version = "1.25.0"
version = "1.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af"
checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64"
dependencies = [
"autocfg",
"bytes",
@ -2279,7 +2390,7 @@ dependencies = [
"signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys 0.42.0",
"windows-sys 0.45.0",
]
[[package]]
@ -2315,9 +2426,9 @@ dependencies = [
[[package]]
name = "tokio-stream"
version = "0.1.11"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce"
checksum = "8fb52b74f05dbf495a8fba459fdc331812b96aa086d9eb78101fa0d4569c3313"
dependencies = [
"futures-core",
"pin-project-lite",
@ -2326,9 +2437,9 @@ dependencies = [
[[package]]
name = "tokio-util"
version = "0.7.6"
version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6a3b08b64e6dfad376fa2432c7b1f01522e37a623c3050bc95db2d3ff21583"
checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2"
dependencies = [
"bytes",
"futures-core",
@ -2417,12 +2528,30 @@ dependencies = [
"http-body",
"http-range-header",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-http"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d1d42a9b3f3ec46ba828e8d376aec14592ea199f70a06a548587ecd1c4ab658"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.2"
@ -2627,9 +2756,9 @@ dependencies = [
[[package]]
name = "utoipa"
version = "3.0.1"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3920fa753064b1be7842bea26175ffa0dfc4a8f30bcb52b8ff03fddf8889914c"
checksum = "a15f6da6a2b471134ca44b7d18e8a76d73035cf8b3ed24c4dd5ca6a63aa439c5"
dependencies = [
"indexmap",
"serde",
@ -2639,9 +2768,9 @@ dependencies = [
[[package]]
name = "utoipa-gen"
version = "3.0.1"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "720298fac6efca20df9e457e67a1eab41a20d1c3101380b5c4dca1ca60ae0062"
checksum = "6f2e33027986a4707b3f5c37ed01b33d0e5a53da30204b52ff18f80600f1d0ec"
dependencies = [
"proc-macro-error",
"proc-macro2",
@ -2712,9 +2841,9 @@ dependencies = [
[[package]]
name = "wasi"
version = "0.10.0+wasi-snapshot-preview1"
version = "0.10.2+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f"
checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
[[package]]
name = "wasi"


@ -1,4 +1,15 @@
FROM rust:1.67 as router-builder
FROM lukemathwalker/cargo-chef:latest-rust-1.67 AS chef
WORKDIR /usr/src
FROM chef as planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
@ -6,26 +17,15 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
WORKDIR /usr/src
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY router router
WORKDIR /usr/src/router
RUN cargo install --path .
FROM rust:1.67 as launcher-builder
WORKDIR /usr/src
COPY rust-toolchain.toml rust-toolchain.toml
COPY launcher launcher
WORKDIR /usr/src/launcher
RUN cargo install --path .
RUN cargo build --release
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
@ -33,6 +33,7 @@ ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive \
HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
MODEL_ID=bigscience/bloom-560m \
QUANTIZE=false \
NUM_SHARD=1 \
@ -68,9 +69,9 @@ RUN cd server && \
/opt/miniconda/envs/text-generation/bin/pip install ".[bnb]" --no-cache-dir
# Install router
COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]


@ -13,25 +13,25 @@ server-dev:
cd server && make run-dev
router-dev:
cd router && cargo run
cd router && cargo run -- --port 8080
integration-tests: install-router install-launcher
cargo test
python-tests:
cd server && pytest tests
cd server && HF_HUB_ENABLE_HF_TRANSFER=1 pytest tests
run-bloom-560m:
text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --port 8080
run-bloom-560m-quantize:
text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize --port 8080
download-bloom:
text-generation-server download-weights bigscience/bloom
HF_HUB_ENABLE_HF_TRANSFER=1 text-generation-server download-weights bigscience/bloom
run-bloom:
text-generation-launcher --model-id bigscience/bloom --num-shard 8
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --port 8080
run-bloom-quantize:
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080


@ -39,27 +39,30 @@ to power LLMs api-inference widgets.
## Features
- Serve the most popular Large Language Models with a simple launcher
- Tensor Parallelism for faster inference on multiple GPUs
- Token streaming using Server-Sent Events (SSE)
- [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- 45ms per-token generation latency for BLOOM on 8x A100 80GB
- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
- Logits warpers (temperature scaling, topk, repetition penalty ...)
- Stop sequences
- Log probabilities
- Distributed tracing with Open Telemetry
- Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
## Officially supported models
## Officially supported architectures
- [BLOOM](https://huggingface.co/bigscience/bloom)
- [BLOOMZ](https://huggingface.co/bigscience/bloomz)
- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
- [SantaCoder](https://huggingface.co/bigcode/santacoder)
- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
- [FLAN-UL2](https://huggingface.co/google/flan-ul2)
Other models are supported on a best effort basis using:
Other architectures are supported on a best effort basis using:
`AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
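For reference, a minimal sketch of that best-effort path, assuming only the public `transformers` API (the model id below is an illustrative placeholder, and `device_map="auto"` requires `accelerate`):

```python
# Hypothetical sketch of the best-effort fallback described above.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-neox-20b"  # placeholder: any causal-LM checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("What is Deep Learning?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=17)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```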
@ -80,24 +83,42 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
```
**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
You can then query the model using either the `/generate` or `/generate_stream` routes:
```shell
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
-H 'Content-Type: application/json'
```
```shell
curl 127.0.0.1:8080/generate_stream \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
-H 'Content-Type: application/json'
```
**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
or from Python:
```shell
pip install text-generation
```
```python
from text_generation import Client
client = Client("http://127.0.0.1:8080")
print(client.generate("What is Deep Learning?", max_new_tokens=17).generated_text)
text = ""
for response in client.generate_stream("What is Deep Learning?", max_new_tokens=17):
if not response.token.special:
text += response.token.text
print(text)
```
### API documentation
@ -191,7 +212,7 @@ Be aware that the official Docker image has them enabled by default.
### Download
First you need to download the weights:
It is advised to download the weights ahead of time with the following command:
```shell
make download-bloom

clients/python/.gitignore (vendored, new file): 158 lines

@ -0,0 +1,158 @@
# Byte-compiled / optimized / DLL files
__pycache__/
text_generation/__pycache__/
text_generation/pb/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
transformers
safetensors

clients/python/Makefile (new file): 6 lines

@ -0,0 +1,6 @@
unit-tests:
python -m pytest --cov=text_generation tests
install:
pip install pip --upgrade
pip install -e .

clients/python/README.md (new file): 196 lines

@ -0,0 +1,196 @@
# Text Generation
The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
`text-generation-inference` instance running on
[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub.
## Get Started
### Install
```shell
pip install text-generation
```
### Inference API Usage
```python
from text_generation import InferenceAPIClient
client = InferenceAPIClient("bigscience/bloomz")
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
or with the asynchronous client:
```python
from text_generation import InferenceAPIAsyncClient
client = InferenceAPIAsyncClient("bigscience/bloomz")
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
### Hugging Face Inference Endpoint usage
```python
from text_generation import Client
endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
client = Client(endpoint_url)
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
or with the asynchronous client:
```python
from text_generation import AsyncClient
endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
client = AsyncClient(endpoint_url)
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
### Types
```python
# Prompt tokens
class PrefillToken:
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
# Optional since the logprob of the first token cannot be computed
logprob: Optional[float]
# Generated tokens
class Token:
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
logprob: float
# Is the token a special token
# Can be used to ignore tokens when concatenating
special: bool
# Generation finish reason
class FinishReason(Enum):
# number of generated tokens == `max_new_tokens`
Length = "length"
# the model generated its end of sequence token
EndOfSequenceToken = "eos_token"
# the model generated a text included in `stop_sequences`
StopSequence = "stop_sequence"
# Additional sequences when using the `best_of` parameter
class BestOfSequence:
# Generated text
generated_text: str
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Generated tokens
tokens: List[Token]
# `generate` details
class Details:
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Generated tokens
tokens: List[Token]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
# `generate` return value
class Response:
# Generated text
generated_text: str
# Generation details
details: Details
# `generate_stream` details
class StreamDetails:
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# `generate_stream` return value
class StreamResponse:
# Generated token
token: Token
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]
# Generation details
# Only available when the generation is finished
details: Optional[StreamDetails]
```
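To make the relationship between these types concrete, here is a small usage sketch (the URL is a placeholder for a running `text-generation-inference` instance):

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # placeholder URL
response = client.generate("Why is the sky blue?", max_new_tokens=5)

# `Response` fields
print(response.generated_text)

# `Details` fields
print(response.details.finish_reason, response.details.generated_tokens)

# Each generated `Token` carries its id, text, logprob and special flag
for token in response.details.tokens:
    print(token.id, repr(token.text), token.logprob, token.special)
```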

clients/python/poetry.lock (generated, new file): 1038 lines. File diff suppressed because it is too large.


@ -0,0 +1,26 @@
[tool.poetry]
name = "text-generation"
version = "0.3.1"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
maintainers = ["Olivier Dehaene <olivier@huggingface.co>"]
readme = "README.md"
homepage = "https://github.com/huggingface/text-generation-inference"
repository = "https://github.com/huggingface/text-generation-inference"
[tool.poetry.dependencies]
python = "^3.7"
pydantic = "^1.10"
aiohttp = "^3.8"
huggingface-hub = ">= 0.12, < 1.0"
[tool.poetry.dev-dependencies]
pytest = "^6.2.5"
pytest-asyncio = "^0.17.2"
pytest-cov = "^3.0.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"


@ -0,0 +1,51 @@
import pytest
from text_generation import __version__
from huggingface_hub.utils import build_hf_headers
@pytest.fixture
def flan_t5_xxl():
return "google/flan-t5-xxl"
@pytest.fixture
def fake_model():
return "fake/model"
@pytest.fixture
def unsupported_model():
return "gpt2"
@pytest.fixture
def base_url():
return "https://api-inference.huggingface.co/models"
@pytest.fixture
def bloom_url(base_url, bloom_model):
return f"{base_url}/{bloom_model}"
@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
return f"{base_url}/{flan_t5_xxl}"
@pytest.fixture
def fake_url(base_url, fake_model):
return f"{base_url}/{fake_model}"
@pytest.fixture
def unsupported_url(base_url, unsupported_model):
return f"{base_url}/{unsupported_model}"
@pytest.fixture(scope="session")
def hf_headers():
return build_hf_headers(
library_name="text-generation-tests", library_version=__version__
)


@ -0,0 +1,133 @@
import pytest
from text_generation import Client, AsyncClient
from text_generation.errors import NotFoundError, ValidationError
from text_generation.types import FinishReason, PrefillToken, Token
def test_generate(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate("test", max_new_tokens=1)
assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
assert len(response.details.tokens) == 1
assert response.details.tokens[0] == Token(
id=3, text=" ", logprob=-1.984375, special=False
)
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
assert response.details.seed is not None
assert response.details.best_of_sequences is not None
assert len(response.details.best_of_sequences) == 1
assert response.details.best_of_sequences[0].seed is not None
def test_generate_not_found(fake_url, hf_headers):
client = Client(fake_url, hf_headers)
with pytest.raises(NotFoundError):
client.generate("test")
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
client.generate("test", max_new_tokens=10_000)
def test_generate_stream(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
responses = [
response for response in client.generate_stream("test", max_new_tokens=1)
]
assert len(responses) == 1
response = responses[0]
assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
def test_generate_stream_not_found(fake_url, hf_headers):
client = Client(fake_url, hf_headers)
with pytest.raises(NotFoundError):
list(client.generate_stream("test"))
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
list(client.generate_stream("test", max_new_tokens=10_000))
@pytest.mark.asyncio
async def test_generate_async(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
response = await client.generate("test", max_new_tokens=1)
assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
assert len(response.details.tokens) == 1
assert response.details.tokens[0] == Token(
id=3, text=" ", logprob=-1.984375, special=False
)
@pytest.mark.asyncio
async def test_generate_async_not_found(fake_url, hf_headers):
client = AsyncClient(fake_url, hf_headers)
with pytest.raises(NotFoundError):
await client.generate("test")
@pytest.mark.asyncio
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
await client.generate("test", max_new_tokens=10_000)
@pytest.mark.asyncio
async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
responses = [
response async for response in client.generate_stream("test", max_new_tokens=1)
]
assert len(responses) == 1
response = responses[0]
assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
@pytest.mark.asyncio
async def test_generate_stream_async_not_found(fake_url, hf_headers):
client = AsyncClient(fake_url, hf_headers)
with pytest.raises(NotFoundError):
async for _ in client.generate_stream("test"):
pass
@pytest.mark.asyncio
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
async for _ in client.generate_stream("test", max_new_tokens=10_000):
pass


@ -0,0 +1,64 @@
from text_generation.errors import (
parse_error,
GenerationError,
IncompleteGenerationError,
OverloadedError,
ValidationError,
BadRequestError,
ShardNotReadyError,
ShardTimeoutError,
NotFoundError,
RateLimitExceededError,
UnknownError,
)
def test_generation_error():
payload = {"error_type": "generation", "error": "test"}
assert isinstance(parse_error(400, payload), GenerationError)
def test_incomplete_generation_error():
payload = {"error_type": "incomplete_generation", "error": "test"}
assert isinstance(parse_error(400, payload), IncompleteGenerationError)
def test_overloaded_error():
payload = {"error_type": "overloaded", "error": "test"}
assert isinstance(parse_error(400, payload), OverloadedError)
def test_validation_error():
payload = {"error_type": "validation", "error": "test"}
assert isinstance(parse_error(400, payload), ValidationError)
def test_bad_request_error():
payload = {"error": "test"}
assert isinstance(parse_error(400, payload), BadRequestError)
def test_shard_not_ready_error():
payload = {"error": "test"}
assert isinstance(parse_error(403, payload), ShardNotReadyError)
assert isinstance(parse_error(424, payload), ShardNotReadyError)
def test_shard_timeout_error():
payload = {"error": "test"}
assert isinstance(parse_error(504, payload), ShardTimeoutError)
def test_not_found_error():
payload = {"error": "test"}
assert isinstance(parse_error(404, payload), NotFoundError)
def test_rate_limit_exceeded_error():
payload = {"error": "test"}
assert isinstance(parse_error(429, payload), RateLimitExceededError)
def test_unknown_error():
payload = {"error": "test"}
assert isinstance(parse_error(500, payload), UnknownError)
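Taken together, these tests pin down the status-code mapping expected from `parse_error`. A hypothetical sketch consistent with them (the real implementation lives in `text_generation/errors.py`; the error constructors are assumed here to accept the error message) could look like:

```python
from text_generation.errors import (
    GenerationError, IncompleteGenerationError, OverloadedError, ValidationError,
    BadRequestError, ShardNotReadyError, ShardTimeoutError, NotFoundError,
    RateLimitExceededError, UnknownError,
)

def parse_error(status_code: int, payload: dict) -> Exception:
    # Assumption: every error type accepts the server-provided message string.
    message = payload.get("error", "")
    if status_code == 400:
        error_type = payload.get("error_type")
        if error_type == "generation":
            return GenerationError(message)
        if error_type == "incomplete_generation":
            return IncompleteGenerationError(message)
        if error_type == "overloaded":
            return OverloadedError(message)
        if error_type == "validation":
            return ValidationError(message)
        return BadRequestError(message)
    if status_code in (403, 424):
        return ShardNotReadyError(message)
    if status_code == 504:
        return ShardTimeoutError(message)
    if status_code == 404:
        return NotFoundError(message)
    if status_code == 429:
        return RateLimitExceededError(message)
    return UnknownError(message)
```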


@ -0,0 +1,34 @@
import pytest
from text_generation import (
InferenceAPIClient,
InferenceAPIAsyncClient,
Client,
AsyncClient,
)
from text_generation.errors import NotSupportedError
from text_generation.inference_api import get_supported_models
def test_get_supported_models():
assert isinstance(get_supported_models(), list)
def test_client(flan_t5_xxl):
client = InferenceAPIClient(flan_t5_xxl)
assert isinstance(client, Client)
def test_client_unsupported_model(unsupported_model):
with pytest.raises(NotSupportedError):
InferenceAPIClient(unsupported_model)
def test_async_client(flan_t5_xxl):
client = InferenceAPIAsyncClient(flan_t5_xxl)
assert isinstance(client, AsyncClient)
def test_async_client_unsupported_model(unsupported_model):
with pytest.raises(NotSupportedError):
InferenceAPIAsyncClient(unsupported_model)


@ -0,0 +1,82 @@
import pytest
from text_generation.types import Parameters, Request
from text_generation.errors import ValidationError
def test_parameters_validation():
# Test best_of
Parameters(best_of=1)
with pytest.raises(ValidationError):
Parameters(best_of=0)
with pytest.raises(ValidationError):
Parameters(best_of=-1)
Parameters(best_of=2, do_sample=True)
with pytest.raises(ValidationError):
Parameters(best_of=2)
# Test repetition_penalty
Parameters(repetition_penalty=1)
with pytest.raises(ValidationError):
Parameters(repetition_penalty=0)
with pytest.raises(ValidationError):
Parameters(repetition_penalty=-1)
# Test seed
Parameters(seed=1)
with pytest.raises(ValidationError):
Parameters(seed=-1)
# Test temperature
Parameters(temperature=1)
with pytest.raises(ValidationError):
Parameters(temperature=0)
with pytest.raises(ValidationError):
Parameters(temperature=-1)
# Test top_k
Parameters(top_k=1)
with pytest.raises(ValidationError):
Parameters(top_k=0)
with pytest.raises(ValidationError):
Parameters(top_k=-1)
# Test top_p
Parameters(top_p=0.5)
with pytest.raises(ValidationError):
Parameters(top_p=0)
with pytest.raises(ValidationError):
Parameters(top_p=-1)
with pytest.raises(ValidationError):
Parameters(top_p=1)
# Test truncate
Parameters(truncate=1)
with pytest.raises(ValidationError):
Parameters(truncate=0)
with pytest.raises(ValidationError):
Parameters(truncate=-1)
# Test typical_p
Parameters(typical_p=0.5)
with pytest.raises(ValidationError):
Parameters(typical_p=0)
with pytest.raises(ValidationError):
Parameters(typical_p=-1)
with pytest.raises(ValidationError):
Parameters(typical_p=1)
def test_request_validation():
Request(inputs="test")
with pytest.raises(ValidationError):
Request(inputs="")
Request(inputs="test", stream=True)
Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))
with pytest.raises(ValidationError):
Request(
inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
)
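For orientation, here is a standalone sketch of the same validation rules, using plain pydantic v1 validators (the client pins `pydantic = "^1.10"`); the real `text_generation.types.Parameters` covers more fields and raises the library's own `ValidationError`:

```python
from typing import Optional
from pydantic import BaseModel, validator

class SketchParameters(BaseModel):
    # Hypothetical subset of the fields exercised by the tests above.
    do_sample: bool = False
    best_of: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    @validator("best_of")
    def check_best_of(cls, value, values):
        if value is not None:
            if value <= 0:
                raise ValueError("`best_of` must be strictly positive")
            if value > 1 and not values.get("do_sample", False):
                raise ValueError("`do_sample` must be True when `best_of` > 1")
        return value

    @validator("temperature")
    def check_temperature(cls, value):
        if value is not None and value <= 0:
            raise ValueError("`temperature` must be strictly positive")
        return value

    @validator("top_p")
    def check_top_p(cls, value):
        if value is not None and not 0 < value < 1:
            raise ValueError("`top_p` must be > 0 and < 1")
        return value

# pydantic converts a ValueError raised in a validator into a ValidationError:
# SketchParameters(best_of=2)                  -> raises pydantic.ValidationError
# SketchParameters(best_of=2, do_sample=True)  -> ok
```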


@ -0,0 +1,18 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.3.0"
from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient


@ -0,0 +1,487 @@
import json
import requests
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator
from text_generation.types import (
StreamResponse,
Response,
Request,
Parameters,
)
from text_generation.errors import parse_error
class Client:
"""Client to make calls to a text-generation-inference instance
Example:
```python
>>> from text_generation import Client
>>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
>>> client.generate("Why is the sky blue?").generated_text
' Rayleigh scattering'
>>> result = ""
>>> for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(
self,
base_url: str,
headers: Optional[Dict[str, str]] = None,
cookies: Optional[Dict[str, str]] = None,
timeout: int = 10,
):
"""
Args:
base_url (`str`):
text-generation-inference instance base url
headers (`Optional[Dict[str, str]]`):
Additional headers
cookies (`Optional[Dict[str, str]]`):
Cookies to include in the requests
timeout (`int`):
Timeout in seconds
"""
self.base_url = base_url
self.headers = headers
self.cookies = cookies
self.timeout = timeout
def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
) -> Response:
"""
Given a prompt, generate the following text
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
best_of (`int`):
Generate best_of sequences and return the one with the highest token logprobs
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
truncate (`int`):
Truncate inputs tokens to the given size
typical_p (`float`):
Typical Decoding mass
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Response: generated response
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
resp = requests.post(
self.base_url,
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return Response(**payload[0])
def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
best_of (`int`):
Generate best_of sequences and return the one with the highest token logprobs
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
truncate (`int`):
Truncate inputs tokens to the given size
typical_p (`float`):
Typical Decoding mass
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Iterator[StreamResponse]: stream of generated tokens
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
resp = requests.post(
self.base_url,
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
stream=True,
)
if resp.status_code != 200:
raise parse_error(resp.status_code, resp.json())
# Parse ServerSentEvents
for byte_payload in resp.iter_lines():
# Skip line
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
# Event data
if payload.startswith("data:"):
# Decode payload
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
# Parse payload
try:
response = StreamResponse(**json_payload)
except ValidationError:
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status_code, json_payload)
yield response
class AsyncClient:
"""Asynchronous Client to make calls to a text-generation-inference instance
Example:
```python
>>> from text_generation import AsyncClient
>>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
>>> response = await client.generate("Why is the sky blue?")
>>> response.generated_text
' Rayleigh scattering'
>>> result = ""
>>> async for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(
self,
base_url: str,
headers: Optional[Dict[str, str]] = None,
cookies: Optional[Dict[str, str]] = None,
timeout: int = 10,
):
"""
Args:
base_url (`str`):
text-generation-inference instance base url
headers (`Optional[Dict[str, str]]`):
Additional headers
cookies (`Optional[Dict[str, str]]`):
Cookies to include in the requests
timeout (`int`):
Timeout in seconds
"""
self.base_url = base_url
self.headers = headers
self.cookies = cookies
self.timeout = ClientTimeout(timeout * 60)
async def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
best_of (`int`):
Generate best_of sequences and return the one with the highest token logprobs
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
truncate (`int`):
Truncate input tokens to the given size
typical_p (`float`):
Typical Decoding mass
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Response: generated response
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return Response(**payload[0])
async def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
best_of (`int`):
Generate best_of sequences and return the one with the highest token logprobs
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
truncate (`int`):
Truncate input tokens to the given size
typical_p (`float`):
Typical Decoding mass
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
if resp.status != 200:
raise parse_error(resp.status, await resp.json())
# Parse ServerSentEvents
async for byte_payload in resp.content:
# Skip line
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
# Event data
if payload.startswith("data:"):
# Decode payload
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
# Parse payload
try:
response = StreamResponse(**json_payload)
except ValidationError:
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status, json_payload)
yield response
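A rough sketch of driving the asynchronous streaming method from a plain script; `asyncio.run` is only one way to execute the coroutine, and the endpoint is illustrative.

```python
import asyncio

from text_generation import AsyncClient


async def main() -> str:
    client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
    text = ""
    async for response in client.generate_stream("Why is the sky blue?"):
        # Accumulate only non-special tokens
        if not response.token.special:
            text += response.token.text
    return text


print(asyncio.run(main()))
```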

View File

@ -0,0 +1,106 @@
from typing import Dict
# Text Generation Inference Errors
class ValidationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class GenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class OverloadedError(Exception):
def __init__(self, message: str):
super().__init__(message)
class IncompleteGenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
# API Inference Errors
class BadRequestError(Exception):
def __init__(self, message: str):
super().__init__(message)
class ShardNotReadyError(Exception):
def __init__(self, message: str):
super().__init__(message)
class ShardTimeoutError(Exception):
def __init__(self, message: str):
super().__init__(message)
class NotFoundError(Exception):
def __init__(self, message: str):
super().__init__(message)
class RateLimitExceededError(Exception):
def __init__(self, message: str):
super().__init__(message)
class NotSupportedError(Exception):
def __init__(self, model_id: str):
message = (
f"Model `{model_id}` is not available for inference with this client. \n"
"Use `huggingface_hub.inference_api.InferenceApi` instead."
)
super(NotSupportedError, self).__init__(message)
# Unknown error
class UnknownError(Exception):
def __init__(self, message: str):
super().__init__(message)
def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
"""
Parse error given an HTTP status code and a json payload
Args:
status_code (`int`):
HTTP status code
payload (`Dict[str, str]`):
Json payload
Returns:
Exception: parsed exception
"""
# Try to parse a Text Generation Inference error
message = payload["error"]
if "error_type" in payload:
error_type = payload["error_type"]
if error_type == "generation":
return GenerationError(message)
if error_type == "incomplete_generation":
return IncompleteGenerationError(message)
if error_type == "overloaded":
return OverloadedError(message)
if error_type == "validation":
return ValidationError(message)
# Try to parse an API Inference error
if status_code == 400:
return BadRequestError(message)
if status_code == 403 or status_code == 424:
return ShardNotReadyError(message)
if status_code == 504:
return ShardTimeoutError(message)
if status_code == 404:
return NotFoundError(message)
if status_code == 429:
return RateLimitExceededError(message)
# Fallback to an unknown error
return UnknownError(message)
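As a quick illustration of the mapping above: when the payload carries an `error_type`, it wins over the status-code fallbacks; otherwise the HTTP status decides the class. The payloads below are hypothetical.

```python
from text_generation.errors import (
    RateLimitExceededError,
    ValidationError,
    parse_error,
)

# `error_type` takes precedence over the status code.
err = parse_error(422, {"error": "Input validation error", "error_type": "validation"})
assert isinstance(err, ValidationError)

# Without `error_type`, the status code decides the exception class.
err = parse_error(429, {"error": "Model is overloaded"})
assert isinstance(err, RateLimitExceededError)
```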

View File

@ -0,0 +1,154 @@
import os
import requests
import base64
import json
import warnings
from typing import List, Optional
from huggingface_hub.utils import build_hf_headers
from text_generation import Client, AsyncClient, __version__
from text_generation.errors import NotSupportedError
INFERENCE_ENDPOINT = os.environ.get(
"HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)
SUPPORTED_MODELS = None
def get_supported_models() -> Optional[List[str]]:
"""
Get the list of supported text-generation models from GitHub
Returns:
Optional[List[str]]: supported models list or None if unable to get the list from GitHub
"""
global SUPPORTED_MODELS
if SUPPORTED_MODELS is not None:
return SUPPORTED_MODELS
response = requests.get(
"https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
timeout=5,
)
if response.status_code == 200:
file_content = response.json()["content"]
SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
return SUPPORTED_MODELS
warnings.warn("Could not retrieve list of supported models.")
return None
class InferenceAPIClient(Client):
"""Client to make calls to the HuggingFace Inference API.
Only supports a subset of the available text-generation or text2text-generation models that are served using
text-generation-inference
Example:
```python
>>> from text_generation import InferenceAPIClient
>>> client = InferenceAPIClient("bigscience/bloomz")
>>> client.generate("Why is the sky blue?").generated_text
' Rayleigh scattering'
>>> result = ""
>>> for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
"""
Init headers and API information
Args:
repo_id (`str`):
Id of repository (e.g. `bigscience/bloom`).
token (`str`, `optional`):
The API token to use as HTTP bearer authorization. This is not
the authentication token. You can find the token in
https://huggingface.co/settings/token. Alternatively, you can
find both your organizations and personal API tokens using
`HfApi().whoami(token)`.
timeout (`int`):
Timeout in seconds
"""
# Text Generation Inference client only supports a subset of the available hub models
supported_models = get_supported_models()
if supported_models is not None and repo_id not in supported_models:
raise NotSupportedError(repo_id)
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIClient, self).__init__(
base_url, headers=headers, timeout=timeout
)
class InferenceAPIAsyncClient(AsyncClient):
"""Aynschronous Client to make calls to the HuggingFace Inference API.
Only supports a subset of the available text-generation or text2text-generation models that are served using
text-generation-inference
Example:
```python
>>> from text_generation import InferenceAPIAsyncClient
>>> client = InferenceAPIAsyncClient("bigscience/bloomz")
>>> response = await client.generate("Why is the sky blue?")
>>> response.generated_text
' Rayleigh scattering'
>>> result = ""
>>> async for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
"""
Init headers and API information
Args:
repo_id (`str`):
Id of repository (e.g. `bigscience/bloom`).
token (`str`, `optional`):
The API token to use as HTTP bearer authorization. This is not
the authentication token. You can find the token in
https://huggingface.co/settings/token. Alternatively, you can
find both your organizations and personal API tokens using
`HfApi().whoami(token)`.
timeout (`int`):
Timeout in seconds
"""
# Text Generation Inference client only supports a subset of the available hub models
supported_models = get_supported_models()
if supported_models is not None and repo_id not in supported_models:
raise NotSupportedError(repo_id)
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIAsyncClient, self).__init__(
base_url, headers=headers, timeout=timeout
)
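The docstring examples above cover anonymous use; when authentication is needed, a token can be passed to the constructor and is forwarded through `build_hf_headers`. A short sketch, where the token string is a placeholder only:

```python
from text_generation import InferenceAPIClient

# "hf_xxx" is a placeholder; see https://huggingface.co/settings/token.
client = InferenceAPIClient("bigscience/bloomz", token="hf_xxx", timeout=30)
print(client.generate("Why is the sky blue?").generated_text)
```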

View File

@ -0,0 +1,223 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List
from text_generation.errors import ValidationError
class Parameters(BaseModel):
# Activate logits sampling
do_sample: bool = False
# Maximum number of generated tokens
max_new_tokens: int = 20
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None
# Whether to prepend the prompt to the generated text
return_full_text: bool = False
# Stop generating tokens if a member of `stop_sequences` is generated
stop: List[str] = []
# Random sampling seed
seed: Optional[int]
# The value used to modulate the logits distribution.
temperature: Optional[float]
# The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_k: Optional[int]
# If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
# higher are kept for generation.
top_p: Optional[float]
# Truncate input tokens to the given size
truncate: Optional[int]
# Typical Decoding mass
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
typical_p: Optional[float]
# Generate best_of sequences and return the one with the highest token logprobs
best_of: Optional[int]
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
watermark: bool = False
# Get generation details
details: bool = False
@validator("best_of")
def valid_best_of(cls, field_value, values):
if field_value is not None:
if field_value <= 0:
raise ValidationError("`best_of` must be strictly positive")
sampling = (
values["do_sample"]
| (values["temperature"] is not None)
| (values["top_k"] is not None)
| (values["top_p"] is not None)
| (values["typical_p"] is not None)
)
if field_value > 1 and not sampling:
raise ValidationError("you must use sampling when `best_of` is > 1")
return field_value
@validator("repetition_penalty")
def valid_repetition_penalty(cls, v):
if v is not None and v <= 0:
raise ValidationError("`repetition_penalty` must be strictly positive")
return v
@validator("seed")
def valid_seed(cls, v):
if v is not None and v < 0:
raise ValidationError("`seed` must be positive")
return v
@validator("temperature")
def valid_temp(cls, v):
if v is not None and v <= 0:
raise ValidationError("`temperature` must be strictly positive")
return v
@validator("top_k")
def valid_top_k(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_k` must be strictly positive")
return v
@validator("top_p")
def valid_top_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`top_p` must be > 0.0 and < 1.0")
return v
@validator("truncate")
def valid_truncate(cls, v):
if v is not None and v <= 0:
raise ValidationError("`truncate` must be strictly positive")
return v
@validator("typical_p")
def valid_typical_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v
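The validators above fail fast on out-of-range values. A short sketch of what that looks like in practice; the `text_generation.types` import path is assumed to match the rest of the package, and the numbers are arbitrary.

```python
from text_generation.types import Parameters  # assumed module path

# Within range, and best_of > 1 combined with sampling: accepted.
Parameters(top_p=0.95, best_of=2, do_sample=True)

# Out of range: the `valid_top_p` validator rejects it (must be > 0.0 and < 1.0).
try:
    Parameters(top_p=1.5)
except Exception as err:  # exact class depends on how pydantic surfaces validator errors
    print(type(err).__name__, err)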
class Request(BaseModel):
# Prompt
inputs: str
# Generation parameters
parameters: Optional[Parameters]
# Whether to stream output tokens
stream: bool = False
@validator("inputs")
def valid_input(cls, v):
if not v:
raise ValidationError("`inputs` cannot be empty")
return v
@validator("stream")
def valid_best_of_stream(cls, field_value, values):
parameters = values["parameters"]
if (
parameters is not None
and parameters.best_of is not None
and parameters.best_of > 1
and field_value
):
raise ValidationError(
"`best_of` != 1 is not supported when `stream` == True"
)
return field_value
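Similarly, the `stream`/`best_of` cross-check above can be exercised directly; again the import path is an assumption.

```python
from text_generation.types import Parameters, Request  # assumed module path

# Streaming a single sequence is fine...
Request(inputs="Why is the sky blue?", stream=True, parameters=Parameters())

# ...but best_of > 1 together with stream=True is rejected by the validator above.
try:
    Request(
        inputs="Why is the sky blue?",
        stream=True,
        parameters=Parameters(best_of=2, do_sample=True),
    )
except Exception as err:
    print(err)  # `best_of` != 1 is not supported when `stream` == True
```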
# Prompt tokens
class PrefillToken(BaseModel):
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
# Optional since the logprob of the first token cannot be computed
logprob: Optional[float]
# Generated tokens
class Token(BaseModel):
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
logprob: float
# Is the token a special token
# Can be used to ignore tokens when concatenating
special: bool
# Generation finish reason
class FinishReason(Enum):
# number of generated tokens == `max_new_tokens`
Length = "length"
# the model generated its end of sequence token
EndOfSequenceToken = "eos_token"
# the model generated a text included in `stop_sequences`
StopSequence = "stop_sequence"
# Additional sequences when using the `best_of` parameter
class BestOfSequence(BaseModel):
# Generated text
generated_text: str
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Generated tokens
tokens: List[Token]
# `generate` details
class Details(BaseModel):
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Generated tokens
tokens: List[Token]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
# `generate` return value
class Response(BaseModel):
# Generated text
generated_text: str
# Generation details
details: Details
# `generate_stream` details
class StreamDetails(BaseModel):
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# `generate_stream` return value
class StreamResponse(BaseModel):
# Generated token
token: Token
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]
# Generation details
# Only available when the generation is finished
details: Optional[StreamDetails]
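To make the stream models concrete: only the final stream event carries `generated_text` and `details`. A hypothetical final payload parsed into the model above (import path assumed):

```python
from text_generation.types import StreamResponse  # assumed module path

# Hypothetical final stream event.
payload = {
    "token": {"id": 42, "text": " blue", "logprob": -0.12, "special": False},
    "generated_text": "Why is the sky blue",
    "details": {"finish_reason": "length", "generated_tokens": 20, "seed": None},
}
response = StreamResponse(**payload)
print(response.details.finish_reason)  # FinishReason.Length
```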

View File

@ -11,7 +11,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "0.2.1"
"version": "0.4.0"
},
"paths": {
"/generate": {
@ -38,10 +38,7 @@
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/GenerateResponse"
}
"$ref": "#/components/schemas/GenerateResponse"
}
}
}
@ -51,10 +48,7 @@
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error"
@ -67,10 +61,7 @@
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation"
@ -83,10 +74,7 @@
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded"
@ -99,10 +87,7 @@
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation"
@ -136,12 +121,9 @@
"200": {
"description": "Generated Text",
"content": {
"text/event-stream ": {
"text/event-stream": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/StreamResponse"
}
"$ref": "#/components/schemas/StreamResponse"
}
}
}
@ -149,12 +131,9 @@
"422": {
"description": "Input validation error",
"content": {
"text/event-stream ": {
"text/event-stream": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error"
@ -165,12 +144,9 @@
"424": {
"description": "Generation Error",
"content": {
"text/event-stream ": {
"text/event-stream": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation"
@ -181,12 +157,9 @@
"429": {
"description": "Model is overloaded",
"content": {
"text/event-stream ": {
"text/event-stream": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded"
@ -197,12 +170,9 @@
"500": {
"description": "Incomplete generation",
"content": {
"text/event-stream ": {
"text/event-stream": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation"
@ -213,17 +183,90 @@
},
"deprecated": false
}
},
"/metrics": {
"get": {
"tags": [
"Text Generation Inference"
],
"summary": "Prometheus metrics scrape endpoint",
"description": "Prometheus metrics scrape endpoint",
"operationId": "metrics",
"responses": {
"200": {
"description": "Prometheus Metrics",
"content": {
"text/plain": {
"schema": {
"type": "string"
}
}
}
}
},
"deprecated": false
}
}
},
"components": {
"schemas": {
"BestOfSequence": {
"type": "object",
"required": [
"generated_text",
"finish_reason",
"generated_tokens",
"prefill",
"tokens"
],
"properties": {
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
"generated_text": {
"type": "string",
"example": "test"
},
"generated_tokens": {
"type": "integer",
"format": "int32",
"example": 1
},
"prefill": {
"type": "array",
"items": {
"$ref": "#/components/schemas/PrefillToken"
}
},
"seed": {
"type": "integer",
"format": "int64",
"example": 42,
"nullable": true
},
"tokens": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Token"
}
}
}
},
"Details": {
"type": "object",
"required": [
"finish_reason",
"generated_tokens"
"generated_tokens",
"prefill",
"tokens"
],
"properties": {
"best_of_sequences": {
"type": "array",
"items": {
"$ref": "#/components/schemas/BestOfSequence"
}
},
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
@ -235,13 +278,14 @@
"prefill": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Token"
"$ref": "#/components/schemas/PrefillToken"
}
},
"seed": {
"type": "integer",
"format": "int64",
"example": 42
"example": 42,
"nullable": true
},
"tokens": {
"type": "array",
@ -254,11 +298,15 @@
"ErrorResponse": {
"type": "object",
"required": [
"error"
"error",
"error_type"
],
"properties": {
"error": {
"type": "string"
},
"error_type": {
"type": "string"
}
}
},
@ -273,6 +321,13 @@
"GenerateParameters": {
"type": "object",
"properties": {
"best_of": {
"type": "integer",
"default": "null",
"example": 1,
"nullable": true,
"exclusiveMinimum": 0.0
},
"details": {
"type": "boolean",
"default": "true"
@ -297,9 +352,19 @@
"nullable": true,
"exclusiveMinimum": 0.0
},
"return_full_text": {
"type": "boolean",
"default": "null",
"example": false,
"nullable": true
},
"seed": {
"type": "integer",
"format": "int64"
"format": "int64",
"default": "null",
"example": "null",
"nullable": true,
"exclusiveMinimum": 0.0
},
"stop": {
"type": "array",
@ -335,6 +400,26 @@
"nullable": true,
"maximum": 1.0,
"exclusiveMinimum": 0.0
},
"truncate": {
"type": "integer",
"default": "null",
"example": "null",
"nullable": true
},
"typical_p": {
"type": "number",
"format": "float",
"default": "null",
"example": 0.95,
"nullable": true,
"maximum": 1.0,
"exclusiveMinimum": 0.0
},
"watermark": {
"type": "boolean",
"default": "false",
"example": true
}
}
},
@ -368,6 +453,31 @@
}
}
},
"PrefillToken": {
"type": "object",
"required": [
"id",
"text",
"logprob"
],
"properties": {
"id": {
"type": "integer",
"format": "int32",
"example": 0
},
"logprob": {
"type": "number",
"format": "float",
"example": -0.34,
"nullable": true
},
"text": {
"type": "string",
"example": "test"
}
}
},
"StreamDetails": {
"type": "object",
"required": [
@ -386,7 +496,8 @@
"seed": {
"type": "integer",
"format": "int64",
"example": 42
"example": 42,
"nullable": true
}
}
},
@ -415,7 +526,8 @@
"required": [
"id",
"text",
"logprob"
"logprob",
"special"
],
"properties": {
"id": {
@ -429,6 +541,10 @@
"example": -0.34,
"nullable": true
},
"special": {
"type": "boolean",
"example": "false"
},
"text": {
"type": "string",
"example": "test"

View File

@ -1,6 +1,6 @@
[package]
name = "text-generation-launcher"
version = "0.2.1"
version = "0.4.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Launcher"

View File

@ -1,6 +1,7 @@
use clap::Parser;
use serde_json::Value;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Read};
use std::path::Path;
use std::process::ExitCode;
@ -12,7 +13,7 @@ use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use subprocess::{Popen, PopenConfig, PopenError, Redirection};
use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
/// App Configuration
#[derive(Parser, Debug)]
@ -23,13 +24,21 @@ struct Args {
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env)]
sharded: Option<bool>,
#[clap(long, env)]
num_shard: Option<usize>,
#[clap(long, env)]
quantize: bool,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "1000", long, env)]
max_input_length: usize,
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(default_value = "20", long, env)]
@ -43,38 +52,112 @@ struct Args {
#[clap(default_value = "29500", long, env)]
master_port: usize,
#[clap(long, env)]
huggingface_hub_cache: Option<String>,
#[clap(long, env)]
weights_cache_override: Option<String>,
#[clap(long, env)]
disable_custom_kernels: bool,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Vec<String>,
#[clap(long, env)]
watermark_gamma: Option<f32>,
#[clap(long, env)]
watermark_delta: Option<f32>,
}
fn main() -> ExitCode {
// Pattern match configuration
let args = Args::parse();
if args.json_output {
tracing_subscriber::fmt().json().init();
} else {
tracing_subscriber::fmt().compact().init();
}
tracing::info!("{:?}", args);
let Args {
model_id,
revision,
sharded,
num_shard,
quantize,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
max_waiting_tokens,
port,
shard_uds_path,
master_addr,
master_port,
huggingface_hub_cache,
weights_cache_override,
disable_custom_kernels,
json_output,
otlp_endpoint,
} = Args::parse();
cors_allow_origin,
watermark_gamma,
watermark_delta,
} = args;
if json_output {
tracing_subscriber::fmt().json().init();
// get the number of shards given `sharded` and `num_shard`
let num_shard = if let Some(sharded) = sharded {
// sharded is set
match sharded {
// sharded is set and true
true => {
match num_shard {
None => {
// try to default to the number of available GPUs
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
let n_devices = num_cuda_devices()
.expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
if n_devices <= 1 {
panic!("`sharded` is true but only found {n_devices} CUDA devices");
}
n_devices
}
Some(num_shard) => {
// we can't have only one shard while sharded
if num_shard <= 1 {
panic!("`sharded` is true but `num_shard` <= 1");
}
num_shard
}
}
}
// sharded is set and false
false => {
let num_shard = num_shard.unwrap_or(1);
// we can't have more than one shard while not sharded
if num_shard != 1 {
panic!("`sharded` is false but `num_shard` != 1");
}
num_shard
}
}
} else {
tracing_subscriber::fmt().compact().init();
match num_shard {
// get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard
None => num_cuda_devices().unwrap_or(1),
Some(num_shard) => num_shard,
}
};
if num_shard < 1 {
panic!("`num_shard` cannot be < 1");
}
// By default we only have one master shard
let num_shard = num_shard.unwrap_or(1);
if num_shard > 1 {
tracing::info!("Sharding model on {num_shard} processes");
}
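Read as plain logic, the shard-count resolution above boils down to the following rough Python sketch, under the assumption that `CUDA_VISIBLE_DEVICES` is a comma-separated device list (mirroring `num_cuda_devices` further down); it is an illustration, not the launcher itself.

```python
import os
from typing import Optional


def resolve_num_shard(sharded: Optional[bool], num_shard: Optional[int]) -> int:
    """Rough Python mirror of the launcher's shard-count resolution."""

    def num_cuda_devices() -> Optional[int]:
        devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        return len(devices.split(",")) if devices else None

    if sharded is True:
        n = num_shard if num_shard is not None else num_cuda_devices()
        if n is None:
            raise RuntimeError("--num-shard and CUDA_VISIBLE_DEVICES are not set")
        if n <= 1:
            raise RuntimeError("`sharded` is true but only one shard/device was found")
        return n
    if sharded is False:
        n = num_shard if num_shard is not None else 1
        if n != 1:
            raise RuntimeError("`sharded` is false but `num_shard` != 1")
        return n
    # `sharded` not set: use CUDA_VISIBLE_DEVICES or default to a single shard
    return num_shard if num_shard is not None else (num_cuda_devices() or 1)
```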
// Signal handler
let running = Arc::new(AtomicBool::new(true));
@ -84,6 +167,121 @@ fn main() -> ExitCode {
})
.expect("Error setting Ctrl-C handler");
// Check if model_id is a local model
let local_path = Path::new(&model_id);
let is_local_model = local_path.exists() && local_path.is_dir();
// Download weights for sharded models
if !is_local_model && weights_cache_override.is_none() && num_shard > 1 {
let mut download_argv = vec![
"text-generation-server".to_string(),
"download-weights".to_string(),
model_id.clone(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
// Model optional revision
if let Some(ref revision) = revision {
download_argv.push("--revision".to_string());
download_argv.push(revision.to_string())
}
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// If huggingface_hub_cache is set, pass it to the shard
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// Enable hf transfer for insane download speeds
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
// Start process
tracing::info!("Starting download process.");
let mut download_process = match Popen::create(
&download_argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
env: Some(env),
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
}
}
return ExitCode::FAILURE;
}
};
// Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap();
thread::spawn(move || {
// Enter download tracing span
let stdout = BufReader::new(download_stdout);
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
if let Some(text) = value.get("text") {
// Format escaped newlines
tracing::info!("{}", text.to_string().replace("\\n", ""));
}
}
}
});
loop {
if let Some(status) = download_process.poll() {
match status {
ExitStatus::Exited(exit_code) => {
if exit_code == 0 {
tracing::info!("Successfully downloaded weights.");
break;
} else {
let mut err = String::new();
download_process
.stderr
.take()
.unwrap()
.read_to_string(&mut err)
.unwrap();
tracing::error!("Download encountered an error: {err}");
return ExitCode::FAILURE;
}
}
_ => {
tracing::error!("Download process exited with an unknown status.");
return ExitCode::FAILURE;
}
}
}
if !running.load(Ordering::SeqCst) {
download_process.terminate().unwrap();
tracing::info!("Waiting for download process to gracefully shutdown");
download_process
.wait_timeout(Duration::from_secs(90))
.unwrap();
tracing::info!("Download process terminated");
return ExitCode::SUCCESS;
}
sleep(Duration::from_millis(100));
}
}
// Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false));
// Shared shutdown channel
@ -99,6 +297,8 @@ fn main() -> ExitCode {
let revision = revision.clone();
let uds_path = shard_uds_path.clone();
let master_addr = master_addr.clone();
let huggingface_hub_cache = huggingface_hub_cache.clone();
let weights_cache_override = weights_cache_override.clone();
let status_sender = status_sender.clone();
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
@ -113,6 +313,11 @@ fn main() -> ExitCode {
num_shard,
master_addr,
master_port,
huggingface_hub_cache,
weights_cache_override,
disable_custom_kernels,
watermark_gamma,
watermark_delta,
otlp_endpoint,
status_sender,
shutdown,
@ -161,8 +366,14 @@ fn main() -> ExitCode {
"text-generation-router".to_string(),
"--max-concurrent-requests".to_string(),
max_concurrent_requests.to_string(),
"--max-best-of".to_string(),
max_best_of.to_string(),
"--max-stop-sequences".to_string(),
max_stop_sequences.to_string(),
"--max-input-length".to_string(),
max_input_length.to_string(),
"--max-total-tokens".to_string(),
max_total_tokens.to_string(),
"--max-batch-size".to_string(),
max_batch_size.to_string(),
"--max-waiting-tokens".to_string(),
@ -185,6 +396,12 @@ fn main() -> ExitCode {
argv.push(otlp_endpoint);
}
// CORS origins
for origin in cors_allow_origin.into_iter() {
argv.push("--cors-allow-origin".to_string());
argv.push(origin);
}
let mut webserver = match Popen::create(
&argv,
PopenConfig {
@ -232,7 +449,7 @@ fn main() -> ExitCode {
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
tracing::error!("Shard {} failed:\n{}", rank, err);
tracing::error!("Shard {rank} failed:\n{err}");
exit_code = ExitCode::FAILURE;
break;
};
@ -275,6 +492,11 @@ fn shard_manager(
world_size: usize,
master_addr: String,
master_port: usize,
huggingface_hub_cache: Option<String>,
weights_cache_override: Option<String>,
disable_custom_kernels: bool,
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>,
@ -319,43 +541,54 @@ fn shard_manager(
shard_argv.push(otlp_endpoint);
}
let mut env = vec![
("RANK".into(), rank.to_string().into()),
("WORLD_SIZE".into(), world_size.to_string().into()),
("MASTER_ADDR".into(), master_addr.into()),
("MASTER_PORT".into(), master_port.to_string().into()),
("SAFETENSORS_FAST_GPU".into(), "1".into()),
("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()),
];
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
// Torch Distributed Env vars
env.push(("RANK".into(), rank.to_string().into()));
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
env.push(("MASTER_ADDR".into(), master_addr.into()));
env.push(("MASTER_PORT".into(), master_port.to_string().into()));
env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
// Safetensors load fast
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// Enable hf transfer for insane download speeds
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
// If weights_cache_override is some, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
if let Some(weights_cache_override) = weights_cache_override {
env.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
));
};
// If the NCCL_SHM_DISABLE env var is set, pass it to the shard
// needed when running NCCL inside a docker container and when you can't increase shm size
if let Ok(nccl_shm_disable) = env::var("NCCL_SHM_DISABLE") {
env.push(("NCCL_SHM_DISABLE".into(), nccl_shm_disable.into()));
};
// If disable_custom_kernels is true, pass it to the shard as an env var
if disable_custom_kernels {
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
}
// If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));
};
// Watermark Gamma
if let Some(watermark_gamma) = watermark_gamma {
env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
}
// Watermark Delta
if let Some(watermark_delta) = watermark_delta {
env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
}
// Start process
tracing::info!("Starting shard {}", rank);
tracing::info!("Starting shard {rank}");
let mut p = match Popen::create(
&shard_argv,
PopenConfig {
@ -419,17 +652,17 @@ fn shard_manager(
if *shutdown.lock().unwrap() {
p.terminate().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {} terminated", rank);
tracing::info!("Shard {rank} terminated");
return;
}
// Shard is ready
if uds.exists() && !ready {
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap();
ready = true;
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {} to be ready...", rank);
tracing::info!("Waiting for shard {rank} to be ready...");
wait_time = Instant::now();
}
sleep(Duration::from_millis(100));
@ -449,3 +682,11 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
// This will block till all shutdown_sender are dropped
let _ = shutdown_receiver.recv();
}
fn num_cuda_devices() -> Option<usize> {
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
let n_devices = cuda_visible_devices.split(',').count();
return Some(n_devices);
}
None
}

View File

@ -1,122 +1,142 @@
{
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"seed": null,
"prefill": [
{
"id": 10264,
"logprob": null,
"text": "Test"
"text": "Test",
"logprob": null
},
{
"id": 8821,
"logprob": -11.894989,
"text": " request"
"text": " request",
"logprob": -11.894989
}
],
"seed": null,
"tokens": [
{
"id": 17,
"text": ".",
"logprob": -1.8267672,
"text": "."
"special": false
},
{
"id": 1587,
"text": "get",
"logprob": -2.4674969,
"text": "get"
"special": false
},
{
"id": 11,
"text": "(",
"logprob": -1.906001,
"text": "("
"special": false
},
{
"id": 5,
"text": "\"",
"logprob": -1.2279545,
"text": "\""
"special": false
},
{
"id": 4899,
"text": "action",
"logprob": -4.170299,
"text": "action"
"special": false
},
{
"id": 5,
"text": "\"",
"logprob": -0.32478866,
"text": "\""
"special": false
},
{
"id": 12,
"text": ")",
"logprob": -1.0773665,
"text": ")"
"special": false
},
{
"id": 30,
"text": ";",
"logprob": -0.27640742,
"text": ";"
"special": false
},
{
"id": 837,
"text": "\n ",
"logprob": -1.6970354,
"text": "\n "
"special": false
},
{
"id": 1320,
"text": " if",
"logprob": -1.4495516,
"text": " if"
"special": false
},
{
"id": 375,
"text": " (",
"logprob": -0.23609057,
"text": " ("
"special": false
},
{
"id": 4899,
"text": "action",
"logprob": -1.1916996,
"text": "action"
"special": false
},
{
"id": 3535,
"text": " ==",
"logprob": -0.8918753,
"text": " =="
"special": false
},
{
"id": 5109,
"text": " null",
"logprob": -0.3933342,
"text": " null"
"special": false
},
{
"id": 12,
"text": ")",
"logprob": -0.43212673,
"text": ")"
"special": false
},
{
"id": 731,
"text": " {",
"logprob": -0.17702064,
"text": " {"
"special": false
},
{
"id": 1260,
"text": "\n ",
"logprob": -0.07027565,
"text": "\n "
"special": false
},
{
"id": 10519,
"text": " throw",
"logprob": -1.3915029,
"text": " throw"
"special": false
},
{
"id": 2084,
"text": " new",
"logprob": -0.04201372,
"text": " new"
"special": false
},
{
"id": 150858,
"text": " RuntimeException",
"logprob": -1.7329919,
"text": " RuntimeException"
"special": false
}
]
},
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
}
}

View File

@ -14,6 +14,7 @@ pub struct Token {
id: u32,
text: String,
logprob: Option<f32>,
special: bool,
}
#[derive(Deserialize)]
@ -136,6 +137,7 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
{
assert_eq!(token.id, expected_token.id);
assert_eq!(token.text, expected_token.text);
assert_eq!(token.special, expected_token.special);
if let Some(logprob) = token.logprob {
let expected_logprob = expected_token.logprob.unwrap();
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);

View File

@ -1,117 +1,137 @@
{
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"seed": null,
"prefill": [
{
"id": 0,
"logprob": null,
"text": "<pad>"
"text": "<pad>",
"logprob": null
}
],
"seed": null,
"tokens": [
{
"id": 259,
"text": " ",
"logprob": -1.3656927,
"text": ""
"special": false
},
{
"id": 215100,
"text": "\"\"\"",
"logprob": -2.6551573,
"text": "\"\"\""
"special": false
},
{
"id": 46138,
"text": "Test",
"logprob": -1.8059857,
"text": "Test"
"special": false
},
{
"id": 287,
"text": " the",
"logprob": -1.2102449,
"text": "the"
"special": false
},
{
"id": 259,
"text": " ",
"logprob": -1.6057279,
"text": ""
"special": false
},
{
"id": 49076,
"text": "contents",
"logprob": -3.6060903,
"text": "contents"
"special": false
},
{
"id": 304,
"text": " of",
"logprob": -0.5270343,
"text": "of"
"special": false
},
{
"id": 287,
"text": " the",
"logprob": -0.62522805,
"text": "the"
"special": false
},
{
"id": 259,
"text": " ",
"logprob": -1.4069618,
"text": ""
"special": false
},
{
"id": 49076,
"text": "contents",
"logprob": -2.621994,
"text": "contents"
"special": false
},
{
"id": 304,
"text": " of",
"logprob": -1.3172221,
"text": "of"
"special": false
},
{
"id": 287,
"text": " the",
"logprob": -0.3501925,
"text": "the"
"special": false
},
{
"id": 259,
"text": " ",
"logprob": -0.7219573,
"text": ""
"special": false
},
{
"id": 49076,
"text": "contents",
"logprob": -1.0494149,
"text": "contents"
"special": false
},
{
"id": 260,
"text": ".",
"logprob": -1.0803378,
"text": "."
"special": false
},
{
"id": 259,
"text": " ",
"logprob": -0.32933083,
"text": ""
"special": false
},
{
"id": 215100,
"text": "\"\"\"",
"logprob": -0.11268901,
"text": "\"\"\""
"special": false
},
{
"id": 2978,
"text": " test",
"logprob": -1.5846587,
"text": "test"
"special": false
},
{
"id": 290,
"text": "_",
"logprob": -0.49796978,
"text": "_"
"special": false
},
{
"id": 4125,
"text": "test",
"logprob": -2.0026445,
"text": "test"
"special": false
}
]
},
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
}
}

View File

@ -34,12 +34,16 @@ message NextTokenChooserParameters {
uint32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float typical_p = 4;
/// apply sampling on the logits
bool do_sample = 4;
bool do_sample = 5;
/// random seed for sampling
uint64 seed = 5;
uint64 seed = 6;
/// repetition penalty
float repetition_penalty = 6;
float repetition_penalty = 7;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
}
message StoppingCriteriaParameters {
@ -54,12 +58,10 @@ message Request {
uint64 id = 1;
/// The generation context
string inputs = 2;
/// The number of tokens inside inputs
uint32 input_length = 3;
/// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4;
NextTokenChooserParameters parameters = 3;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5;
StoppingCriteriaParameters stopping_parameters = 4;
}
message Batch {
@ -108,8 +110,10 @@ message Generation {
float token_logprob = 4;
/// Text
string token_text = 5;
/// Is it a special token
bool token_is_special = 6;
/// Complete generated text
GeneratedText generated_text = 6;
GeneratedText generated_text = 7;
}
message PrefillRequest {

View File

@ -1,6 +1,6 @@
[package]
name = "text-generation-router"
version = "0.2.1"
version = "0.4.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Webserver"
@ -19,17 +19,21 @@ axum-tracing-opentelemetry = "0.9.0"
text-generation-client = { path = "client" }
clap = { version = "4.1.4", features = ["derive", "env"] }
futures = "0.3.26"
metrics = "0.20.1"
metrics-exporter-prometheus = { version = "0.11.0", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
parking_lot = "0.12.1"
rand = "0.8.5"
reqwest = { version = "0.11.14", features = [] }
serde = "1.0.152"
serde_json = "1.0.93"
thiserror = "1.0.38"
tokenizers = "0.13.2"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11"
tower-http = { version = "0.3.5", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }

View File

@ -1,6 +1,6 @@
[package]
name = "text-generation-client"
version = "0.2.1"
version = "0.4.0"
edition = "2021"
[dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "grpc-metadata"
version = "0.1.0"
version = "0.4.0"
edition = "2021"
[dependencies]

View File

@ -1,9 +1,9 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::GenerateRequest;
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
use nohash_hasher::IntMap;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
@ -81,6 +81,7 @@ impl Infer {
.limit_concurrent_requests
.try_acquire_owned()
.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
tracing::error!("{err}");
err
})?;
@ -138,7 +139,7 @@ impl Infer {
.into_iter()
.zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| Token { id, text, logprob })
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
}
// Push last token
@ -172,10 +173,48 @@ impl Infer {
})
} else {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
tracing::error!("{err}");
Err(err)
}
}
/// Add best_of new requests to the queue and return an InferResponse of the sequence with
/// the highest log probability per token
#[instrument(skip(self))]
pub(crate) async fn generate_best_of(
&self,
request: GenerateRequest,
best_of: usize,
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
// validate best_of parameter separately
let best_of = self.validation.validate_best_of(best_of)?;
// create multiple generate requests
let mut infer_responses: Vec<InferResponse> =
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
// get the sequence with the highest log probability per token
let mut max_index = 0;
let mut max_logprob: f32 = f32::MIN;
for (i, response) in infer_responses.iter().enumerate() {
// mean logprobs of the generated tokens
let sequence_logprob = response
.tokens
.iter()
.map(|token| token.logprob)
.sum::<f32>()
/ response.tokens.len() as f32;
// set best sequence
if sequence_logprob > max_logprob {
max_index = i;
max_logprob = sequence_logprob;
}
}
let best_response = infer_responses.remove(max_index);
Ok((best_response, infer_responses))
}
}
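The selection rule above, pick the candidate with the highest mean per-token log probability, is easy to state outside Rust. A small illustrative sketch with hypothetical token logprobs:

```python
from typing import Sequence, Tuple


def pick_best_of(logprobs_per_response: Sequence[Sequence[float]]) -> Tuple[int, float]:
    """Return the index and mean per-token logprob of the best candidate."""
    best_index, best_mean = 0, float("-inf")
    for i, token_logprobs in enumerate(logprobs_per_response):
        mean = sum(token_logprobs) / len(token_logprobs)
        if mean > best_mean:
            best_index, best_mean = i, mean
    return best_index, best_mean


# Hypothetical token logprobs for three candidate generations.
print(pick_best_of([[-1.0, -0.5], [-0.25, -0.25], [-2.0, -0.125]]))  # -> (1, -0.25)
```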
/// Batching logic
@ -190,7 +229,11 @@ async fn batching_task(
shared: Arc<Shared>,
) {
// Minimum batch size after which we try to add more requests
let limit_min_batch_size = (max_batch_size / 2) as u32;
let limit_min_batch_size = if max_batch_size > 1 {
(max_batch_size / 2) as u32
} else {
0
};
// Infinite loop
loop {
@ -201,7 +244,7 @@ async fn batching_task(
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
let mut cached_batch = prefill(&mut client, batch, &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
@ -212,6 +255,7 @@ async fn batching_task(
// Get current batch info
let batch_size = batch.size;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
// If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size {
@ -234,17 +278,17 @@ async fn batching_task(
// because a new batch is being computed
let entry_waiting_span =
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
// Add relationship
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
wrap_future(client.prefill(new_batch), &mut new_entries)
.instrument(span)
.await;
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
@ -262,35 +306,66 @@ async fn batching_task(
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationship
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
});
cached_batch = wrap_future(client.decode(batches), &mut entries)
cached_batch = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size", 0.0);
}
}
}
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
#[instrument(skip_all)]
async fn wrap_future(
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
match future.await {
let start_time = Instant::now();
match client.prefill(batch).await {
Ok((generations, next_batch)) => {
send_generations(generations, entries);
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
None
}
}
}
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<Batch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
let start_time = Instant::now();
match client.decode(batches).await {
Ok((generations, next_batch)) => {
send_generations(generations, entries);
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
None
}
}
@ -303,6 +378,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
@ -340,6 +416,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
special: generation.token_is_special,
};
if let Some(generated_text) = generation.generated_text {
@ -388,7 +465,7 @@ pub(crate) enum InferStreamResponse {
#[derive(Debug)]
pub(crate) struct InferResponse {
pub(crate) prefill: Vec<Token>,
pub(crate) prefill: Vec<PrefillToken>,
pub(crate) tokens: Vec<Token>,
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
@ -406,3 +483,14 @@ pub enum InferError {
#[error("Incomplete generation")]
IncompleteGeneration,
}
impl InferError {
pub(crate) fn error_type(&self) -> &str {
match self {
InferError::GenerationError(_) => "generation",
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
}
}
}

View File

@ -12,6 +12,9 @@ use validation::Validation;
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
pub best_of: Option<usize>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
@ -40,39 +43,64 @@ pub(crate) struct GenerateParameters {
example = 0.95
)]
pub top_p: Option<f32>,
#[serde(default = "default_do_sample")]
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
maximum = 1.0,
nullable = true,
default = "null",
example = 0.95
)]
pub typical_p: Option<f32>,
#[serde(default)]
#[schema(default = "false", example = true)]
pub do_sample: bool,
#[serde(default = "default_max_new_tokens")]
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_new_tokens: u32,
#[serde(default)]
#[schema(inline, max_items = 4, example = json!(["photographer"]))]
#[schema(nullable = true, default = "null", example = false)]
pub return_full_text: Option<bool>,
#[serde(default)]
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
pub stop: Vec<String>,
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub truncate: Option<usize>,
#[serde(default)]
#[schema(default = "false", example = true)]
pub watermark: bool,
#[serde(default)]
#[schema(default = "true")]
pub details: bool,
#[serde(default)]
#[schema(
exclusive_minimum = 0,
nullable = true,
default = "null",
example = "null"
)]
pub seed: Option<u64>,
}
fn default_do_sample() -> bool {
false
}
fn default_max_new_tokens() -> u32 {
20
}
fn default_parameters() -> GenerateParameters {
GenerateParameters {
best_of: None,
temperature: None,
repetition_penalty: None,
top_k: None,
top_p: None,
do_sample: default_do_sample(),
typical_p: None,
do_sample: false,
max_new_tokens: default_max_new_tokens(),
stop: vec![],
return_full_text: None,
stop: Vec::new(),
truncate: None,
watermark: false,
details: false,
seed: None,
}
@ -86,14 +114,46 @@ pub(crate) struct GenerateRequest {
pub parameters: GenerateParameters,
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct CompatGenerateRequest {
#[schema(example = "My name is Olivier and I")]
pub inputs: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
#[serde(default)]
#[allow(dead_code)]
pub stream: bool,
}
impl From<CompatGenerateRequest> for GenerateRequest {
fn from(req: CompatGenerateRequest) -> Self {
Self {
inputs: req.inputs,
parameters: req.parameters,
}
}
}
#[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken {
#[schema(example = 0)]
id: u32,
#[schema(example = "test")]
text: String,
#[schema(nullable = true, example = - 0.34)]
logprob: f32,
}
#[derive(Debug, Serialize, ToSchema)]
pub struct Token {
#[schema(example = 0)]
id: u32,
#[schema(example = "test")]
text: String,
#[schema(nullable = true, example = -0.34)]
#[schema(nullable = true, example = - 0.34)]
logprob: f32,
#[schema(example = "false")]
special: bool,
}
#[derive(Serialize, ToSchema)]
@ -108,16 +168,32 @@ pub(crate) enum FinishReason {
StopSequence,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct BestOfSequence {
#[schema(example = "test")]
pub generated_text: String,
#[schema(example = "length")]
pub finish_reason: FinishReason,
#[schema(example = 1)]
pub generated_tokens: u32,
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct Details {
#[schema(example = "length")]
pub finish_reason: FinishReason,
#[schema(example = 1)]
pub generated_tokens: u32,
#[schema(example = 42)]
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
pub prefill: Option<Vec<Token>>,
pub tokens: Option<Vec<Token>>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of_sequences: Option<Vec<BestOfSequence>>,
}
#[derive(Serialize, ToSchema)]
@ -134,7 +210,7 @@ pub(crate) struct StreamDetails {
pub finish_reason: FinishReason,
#[schema(example = 1)]
pub generated_tokens: u32,
#[schema(example = 42)]
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
}
@ -149,6 +225,6 @@ pub(crate) struct StreamResponse {
#[derive(Serialize, ToSchema)]
pub(crate) struct ErrorResponse {
#[schema(inline)]
pub error: String,
pub error_type: String,
}

View File

@ -1,4 +1,5 @@
/// Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
@ -7,9 +8,11 @@ use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use text_generation_client::ShardedClient;
use text_generation_router::server;
use tokenizers::Tokenizer;
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
@ -20,8 +23,14 @@ use tracing_subscriber::{EnvFilter, Layer};
struct Args {
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "1000", long, env)]
max_input_length: usize,
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(default_value = "20", long, env)]
@ -38,6 +47,8 @@ struct Args {
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
}
fn main() -> Result<(), std::io::Error> {
@ -46,7 +57,10 @@ fn main() -> Result<(), std::io::Error> {
// Pattern match configuration
let Args {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
max_waiting_tokens,
port,
@ -55,17 +69,37 @@ fn main() -> Result<(), std::io::Error> {
validation_workers,
json_output,
otlp_endpoint,
cors_allow_origin,
} = args;
if validation_workers == 0 {
panic!("validation_workers must be > 0");
}
// Download and instantiate tokenizer
// CORS allowed origins
// Map inside the Option, then parse each String into a HeaderValue
// Finally, convert to AllowOrigin
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
AllowOrigin::list(
cors_allow_origin
.iter()
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
)
});
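As a standalone sketch (same axum and tower-http crates; `cors_from_args` and the https://example.com origin are made up for illustration), the same mapping with the server's later fall-back to AllowOrigin::any() looks like this:

use axum::http::HeaderValue;
use tower_http::cors::AllowOrigin;

fn cors_from_args(cors_allow_origin: Option<Vec<String>>) -> AllowOrigin {
    cors_allow_origin
        .map(|origins| {
            // Parse every configured origin into a HeaderValue
            AllowOrigin::list(
                origins
                    .iter()
                    .map(|origin| origin.parse::<HeaderValue>().expect("invalid origin")),
            )
        })
        // No --cors-allow-origin given: allow every origin
        .unwrap_or(AllowOrigin::any())
}

fn main() {
    let _allow = cors_from_args(Some(vec!["https://example.com".to_string()]));
}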
// Tokenizer instance
// This will only be used to validate payloads
//
// We need to download it outside of the Tokio runtime
let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
let local_path = Path::new(&tokenizer_name);
let tokenizer =
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
{
// Load local tokenizer
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
} else {
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
};
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread()
@ -75,6 +109,27 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async {
init_logging(otlp_endpoint, json_output);
// Get pipeline tag
let model_info = reqwest::get(format!(
"https://huggingface.co/api/models/{tokenizer_name}"
))
.await
.expect("Could not connect to hf.co")
.text()
.await
.expect("error when retrieving model info from hf.co");
let model_info: serde_json::Value =
serde_json::from_str(&model_info).expect("unable to parse model info");
// if pipeline-tag == text-generation we default to return_full_text = true
let compat_return_full_text = match model_info.get("pipeline_tag") {
None => {
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
false
}
Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"),
};
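A rough standalone sketch of this lookup (assuming reqwest, serde_json and tokio as dependencies; "gpt2" is only an illustrative model id):

use serde_json::Value;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let model_id = "gpt2";
    let body = reqwest::get(format!("https://huggingface.co/api/models/{model_id}"))
        .await?
        .text()
        .await?;

    let info: Value = serde_json::from_str(&body).expect("unable to parse model info");
    // Only a text-generation pipeline tag flips the compat default to return_full_text = true
    let compat_return_full_text =
        info.get("pipeline_tag").and_then(Value::as_str) == Some("text-generation");
    println!("return_full_text default: {compat_return_full_text}");
    Ok(())
}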
// Instantiate sharded client from the master unix socket
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
@ -91,14 +146,19 @@ fn main() -> Result<(), std::io::Error> {
// Run server
server::run(
compat_return_full_text,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
max_waiting_tokens,
sharded_client,
tokenizer,
validation_workers,
addr,
cors_allow_origin,
)
.await;
Ok(())

View File

@ -132,6 +132,7 @@ impl State {
// Push entry in the queue
self.entries.push((self.next_id, entry));
self.next_id += 1;
metrics::increment_gauge!("tgi_queue_size", 1.0);
}
// Get the next batch
@ -164,7 +165,8 @@ impl State {
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationship
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
@ -172,7 +174,6 @@ impl State {
batch_requests.push(Request {
id,
inputs: entry.request.inputs.clone(),
input_length: entry.request.input_length,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
});
@ -190,6 +191,8 @@ impl State {
// Increment batch id
self.next_batch_id += 1;
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
Some((batch_entries, batch, next_batch_span))
}
}
@ -223,14 +226,15 @@ mod tests {
Entry {
request: ValidGenerateRequest {
inputs: "".to_string(),
input_length: 0,
parameters: NextTokenChooserParameters {
temperature: 0.0,
top_k: 0,
top_p: 0.0,
typical_p: 0.0,
do_sample: false,
seed: 0,
repetition_penalty: 0.0,
watermark: false,
},
stopping_parameters: StoppingCriteriaParameters {
max_new_tokens: 0,

View File

@ -1,17 +1,20 @@
/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
Infer, StreamDetails, StreamResponse, Token, Validation,
BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{Json, Router};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible;
use std::net::SocketAddr;
use text_generation_client::ShardedClient;
@ -19,29 +22,61 @@ use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
/// Compatibility route with api-inference and AzureML
#[instrument(skip(infer))]
async fn compat_generate(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
req: Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let mut req = req.0;
// default return_full_text given the pipeline_tag
if req.parameters.return_full_text.is_none() {
req.parameters.return_full_text = Some(default_return_full_text.0)
}
// switch on stream
if req.stream {
Ok(generate_stream(infer, Json(req.into()))
.await
.into_response())
} else {
let (headers, generation) = generate(infer, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation.0])).into_response())
}
}
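For illustration, a client-side sketch of calling this compatibility route (assuming a router listening on 127.0.0.1:3000 and reqwest with its "json" feature); omitting "stream" takes the non-streaming branch, so the body comes back as a single-element array:

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let res = reqwest::Client::new()
        .post("http://127.0.0.1:3000/")
        .json(&json!({
            "inputs": "My name is Olivier and I",
            "parameters": { "max_new_tokens": 20 }
        }))
        .send()
        .await?;

    // Expected shape: [{"generated_text": "...", ...}]
    println!("{}", res.text().await?);
    Ok(())
}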
/// Health check method
#[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
// be a bit too slow for a health check.
// What we should do instead if check if the gRPC channels are still healthy.
// What we should do instead is check if the gRPC channels are still healthy.
// Send a small inference request
infer
.generate(GenerateRequest {
inputs: "liveness".to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: None,
repetition_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
do_sample: false,
max_new_tokens: 1,
return_full_text: None,
stop: Vec::new(),
truncate: None,
watermark: false,
details: false,
seed: None,
},
@ -57,15 +92,15 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
path = "/generate",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = [GenerateResponse]),
(status = 424, description = "Generation Error", body = [ErrorResponse],
example = json!({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
example = json!({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = [ErrorResponse],
example = json!({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
example = json!({"error": "Incomplete generation"})),
(status = 200, description = "Generated Text", body = GenerateResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(
@ -82,23 +117,64 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
async fn generate(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
let start_time = Instant::now();
// Inference
let compute_characters = req.0.inputs.chars().count();
let mut add_prompt = None;
if req.0.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.0.inputs.clone());
}
let details = req.0.parameters.details;
let response = infer.generate(req.0).await?;
// Inference
let (response, best_of_responses) = match req.0.parameters.best_of {
Some(best_of) if best_of > 1 => {
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
(response, Some(best_of_responses))
}
_ => (infer.generate(req.0).await?, None),
};
// Token details
let details = match details {
true => Some(Details {
finish_reason: FinishReason::from(response.generated_text.finish_reason),
generated_tokens: response.generated_text.generated_tokens,
prefill: Some(response.prefill),
tokens: Some(response.tokens),
seed: response.generated_text.seed,
}),
true => {
// convert best_of_responses
let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
responses
.into_iter()
.map(|response: InferResponse| {
// Add prompt if return_full_text
let mut output_text = response.generated_text.text;
if let Some(prompt) = &add_prompt {
output_text = prompt.clone() + &output_text;
}
BestOfSequence {
generated_text: output_text,
finish_reason: FinishReason::from(
response.generated_text.finish_reason,
),
generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill,
tokens: response.tokens,
seed: response.generated_text.seed,
}
})
.collect()
});
Some(Details {
finish_reason: FinishReason::from(response.generated_text.finish_reason),
generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill,
tokens: response.tokens,
seed: response.generated_text.seed,
best_of_sequences,
})
}
false => None,
};
@ -111,6 +187,15 @@ async fn generate(
// Headers
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
headers.insert(
"x-compute-time",
total_time.as_millis().to_string().parse().unwrap(),
);
headers.insert(
"x-compute-characters",
compute_characters.to_string().parse().unwrap(),
);
headers.insert(
"x-total-time",
total_time.as_millis().to_string().parse().unwrap(),
@ -141,9 +226,26 @@ async fn generate(
span.record("seed", format!("{:?}", response.generated_text.seed));
tracing::info!("Output: {}", response.generated_text.text);
// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::histogram!("tgi_request_duration", total_time);
metrics::histogram!("tgi_request_validation_duration", validation_time);
metrics::histogram!("tgi_request_queue_duration", queue_time);
metrics::histogram!("tgi_request_inference_duration", inference_time);
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
metrics::histogram!(
"tgi_request_generated_tokens",
response.generated_text.generated_tokens as f64
);
// Send response
let mut output_text = response.generated_text.text;
if let Some(prompt) = add_prompt {
output_text = prompt + &output_text;
}
let response = GenerateResponse {
generated_text: response.generated_text.text,
generated_text: output_text,
details,
};
Ok((headers, Json(response)))
@ -156,20 +258,20 @@ async fn generate(
path = "/generate_stream",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = [StreamResponse],
content_type="text/event-stream "),
(status = 424, description = "Generation Error", body = [ErrorResponse],
example = json!({"error": "Request failed during generation"}),
content_type="text/event-stream "),
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
example = json!({"error": "Model is overloaded"}),
content_type="text/event-stream "),
(status = 422, description = "Input validation error", body = [ErrorResponse],
example = json!({"error": "Input validation error"}),
content_type="text/event-stream "),
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
example = json!({"error": "Incomplete generation"}),
content_type="text/event-stream "),
(status = 200, description = "Generated Text", body = StreamResponse,
content_type = "text/event-stream"),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"}),
content_type = "text/event-stream"),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"}),
content_type = "text/event-stream"),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"}),
content_type = "text/event-stream"),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"}),
content_type = "text/event-stream"),
)
)]
#[instrument(
@ -186,118 +288,177 @@ async fn generate(
async fn generate_stream(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
) -> (
HeaderMap,
Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
let span = tracing::Span::current();
let start_time = Instant::now();
let compute_characters = req.0.inputs.chars().count();
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
headers.insert(
"x-compute-characters",
compute_characters.to_string().parse().unwrap(),
);
let stream = async_stream::stream! {
// Inference
let mut end_reached = false;
let mut error = false;
let mut add_prompt = None;
if req.0.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.0.inputs.clone());
}
let details = req.0.parameters.details;
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
Ok(mut response_stream) => {
// Server-Sent Event stream
while let Some(response) = response_stream.next().await {
match response {
Ok(response) => {
match response {
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
InferStreamResponse::Token(token) => {
// StreamResponse
let stream_token = StreamResponse {
let best_of = req.0.parameters.best_of.unwrap_or(1);
if best_of == 1 {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
Ok(mut response_stream) => {
// Server-Sent Event stream
while let Some(response) = response_stream.next().await {
match response {
Ok(response) => {
match response {
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
InferStreamResponse::Token(token) => {
// StreamResponse
let stream_token = StreamResponse {
token,
generated_text: None,
details: None,
};
yield Ok(Event::default().json_data(stream_token).unwrap())
}
// Yield event for last token and compute timings
InferStreamResponse::End {
token,
generated_text: None,
details: None,
};
generated_text,
start,
queued,
} => {
// Token details
let details = match details {
true => Some(StreamDetails {
finish_reason: FinishReason::from(generated_text.finish_reason),
generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed,
}),
false => None,
};
yield Ok(Event::default().json_data(stream_token).unwrap())
}
// Yield event for last token and compute timings
InferStreamResponse::End {
token,
generated_text,
start,
queued,
} => {
// Token details
let details = match details {
true => Some(StreamDetails {
finish_reason: FinishReason::from(generated_text.finish_reason),
generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed,
}),
false => None,
};
// Timings
let total_time = start_time.elapsed();
let validation_time = queued - start_time;
let queue_time = start - queued;
let inference_time = Instant::now() - start;
let time_per_token = inference_time / generated_text.generated_tokens;
// Timings
let total_time = start_time.elapsed();
let validation_time = queued - start_time;
let queue_time = start - queued;
let inference_time = Instant::now() - start;
let time_per_token = inference_time / generated_text.generated_tokens;
// Tracing metadata
span.record("total_time", format!("{total_time:?}"));
span.record("validation_time", format!("{validation_time:?}"));
span.record("queue_time", format!("{queue_time:?}"));
span.record("inference_time", format!("{inference_time:?}"));
span.record("time_per_token", format!("{time_per_token:?}"));
span.record("seed", format!("{:?}", generated_text.seed));
tracing::info!(parent: &span, "Output: {}", generated_text.text);
// Tracing metadata
span.record("total_time", format!("{:?}", total_time));
span.record("validation_time", format!("{:?}", validation_time));
span.record("queue_time", format!("{:?}", queue_time));
span.record("inference_time", format!("{:?}", inference_time));
span.record("time_per_token", format!("{:?}", time_per_token));
span.record("seed", format!("{:?}", generated_text.seed));
tracing::info!(parent: &span, "Output: {}", generated_text.text);
// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::histogram!("tgi_request_duration", total_time);
metrics::histogram!("tgi_request_validation_duration", validation_time);
metrics::histogram!("tgi_request_queue_duration", queue_time);
metrics::histogram!("tgi_request_inference_duration", inference_time);
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
// StreamResponse
end_reached = true;
let stream_token = StreamResponse {
token,
generated_text: Some(generated_text.text),
details
};
// StreamResponse
end_reached = true;
yield Ok(Event::default().json_data(stream_token).unwrap())
let mut output_text = generated_text.text;
if let Some(prompt) = add_prompt {
output_text = prompt + &output_text;
}
let stream_token = StreamResponse {
token,
generated_text: Some(output_text),
details
};
yield Ok(Event::default().json_data(stream_token).unwrap());
break;
}
}
}
}
// yield error
Err(err) => {
error = true;
yield Ok(Event::from(err))
// yield error
Err(err) => {
error = true;
yield Ok(Event::from(err));
break;
}
}
}
},
// yield error
Err(err) => {
error = true;
yield Ok(Event::from(err));
}
},
// yield error
Err(err) => {
error = true;
yield Ok(Event::from(err))
}
}
// Check if generation reached the end
// Skip if we already sent an error
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
// Check if generation reached the end
// Skip if we already sent an error
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
tracing::error!("{err}");
yield Ok(Event::from(err));
}
} else {
let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err))
yield Ok(Event::from(err));
}
};
Sse::new(stream).keep_alive(KeepAlive::default())
(headers, Sse::new(stream).keep_alive(KeepAlive::default()))
}
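And a rough client sketch for the streaming route (same assumed local router on port 3000; it just prints raw SSE chunks instead of using a proper event-source parser, and needs reqwest's "json" and "stream" features plus futures-util):

use futures_util::StreamExt;
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let res = reqwest::Client::new()
        .post("http://127.0.0.1:3000/generate_stream")
        .json(&json!({
            "inputs": "My name is Olivier and I",
            "parameters": { "max_new_tokens": 20 }
        }))
        .send()
        .await?;

    // Each chunk carries one or more `data: {...StreamResponse...}` lines
    let mut stream = res.bytes_stream();
    while let Some(chunk) = stream.next().await {
        print!("{}", String::from_utf8_lossy(&chunk?));
    }
    Ok(())
}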
/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render()
}
/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
compat_return_full_text: bool,
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
max_batch_size: usize,
max_waiting_tokens: usize,
client: ShardedClient,
tokenizer: Tokenizer,
validation_workers: usize,
addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
) {
// OpenAPI documentation
#[derive(OpenApi)]
@ -305,13 +466,16 @@ pub async fn run(
paths(
generate,
generate_stream,
metrics,
),
components(
schemas(
GenerateRequest,
GenerateParameters,
PrefillToken,
Token,
GenerateResponse,
BestOfSequence,
Details,
FinishReason,
StreamResponse,
@ -333,7 +497,14 @@ pub async fn run(
struct ApiDoc;
// Create state
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
let validation = Validation::new(
validation_workers,
tokenizer,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
);
let infer = Infer::new(
client,
validation,
@ -342,16 +513,33 @@ pub async fn run(
max_concurrent_requests,
);
// Prometheus handler
let builder = PrometheusBuilder::new();
let prom_handle = builder
.install_recorder()
.expect("failed to install metrics recorder");
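A small sketch of what this recorder provides (same metrics and metrics-exporter-prometheus crates used above); the rendered string is exactly what the new /metrics route serves:

use metrics_exporter_prometheus::PrometheusBuilder;

fn main() {
    // Install a global recorder and keep the handle to render scrapes later
    let prom_handle = PrometheusBuilder::new()
        .install_recorder()
        .expect("failed to install metrics recorder");

    // Any metrics macro call in the process is now captured by this recorder
    metrics::increment_counter!("tgi_request_success");
    metrics::histogram!("tgi_request_generated_tokens", 20.0);

    // Prometheus text exposition format, served verbatim by /metrics
    println!("{}", prom_handle.render());
}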
// CORS layer
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
let cors_layer = CorsLayer::new()
.allow_methods([Method::GET, Method::POST])
.allow_headers([http::header::CONTENT_TYPE])
.allow_origin(allow_origin);
// Create router
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
.route("/", post(generate))
.route("/", post(compat_generate))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/", get(health))
.route("/health", get(health))
.route("/metrics", get(metrics))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
.layer(opentelemetry_tracing_layer());
.layer(Extension(prom_handle))
.layer(opentelemetry_tracing_layer())
.layer(cors_layer);
// Run server
axum::Server::bind(&addr)
@ -415,6 +603,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
status_code,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
)
}
@ -425,6 +614,7 @@ impl From<InferError> for Event {
Event::default()
.json_data(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
})
.unwrap()
}

View File

@ -1,3 +1,4 @@
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
/// Payload validation logic
use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng;
@ -5,33 +6,44 @@ use rand::Rng;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;
use tokio::sync::{mpsc, oneshot};
use tracing::{instrument, Span};
const MAX_MAX_NEW_TOKENS: u32 = 512;
const MAX_STOP_SEQUENCES: usize = 4;
/// Validation
#[derive(Debug, Clone)]
pub struct Validation {
/// maximum value for the best_of parameter
#[allow(dead_code)]
max_best_of: usize,
/// Channel to communicate with the background validation task
sender: mpsc::Sender<ValidationRequest>,
sender: mpsc::UnboundedSender<ValidationRequest>,
}
impl Validation {
pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
pub(crate) fn new(
workers: usize,
tokenizer: Tokenizer,
max_best_of: usize,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
) -> Self {
// Create channel
let (validation_sender, validation_receiver) = mpsc::channel(128);
let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
// Launch background validation task
tokio::spawn(validation_task(
workers,
tokenizer,
max_stop_sequences,
max_input_length,
max_total_tokens,
validation_receiver,
));
Self {
max_best_of,
sender: validation_sender,
}
}
@ -48,12 +60,25 @@ impl Validation {
// Unwrap is safe here
self.sender
.send((request, sender, Span::current()))
.await
.unwrap();
// Await on response channel
// Unwrap is safe here
receiver.await.unwrap()
}
/// Validate the best_of parameter
#[instrument(skip_all)]
pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
if self.max_best_of == 1 && best_of != 1 {
return Err(ValidationError::BestOfDisabled);
}
if best_of > self.max_best_of {
return Err(ValidationError::BestOf(self.max_best_of, best_of));
}
Ok(best_of)
}
}
/// Validation task
@ -61,8 +86,10 @@ impl Validation {
async fn validation_task(
workers: usize,
tokenizer: Tokenizer,
max_stop_sequences: usize,
max_input_length: usize,
mut receiver: mpsc::Receiver<ValidationRequest>,
max_total_tokens: usize,
mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
) {
let mut workers_senders = Vec::with_capacity(workers);
@ -75,7 +102,13 @@ async fn validation_task(
// Spawn worker
tokio::task::spawn_blocking(move || {
validation_worker(tokenizer_clone, max_input_length, worker_receiver)
validation_worker(
tokenizer_clone,
max_stop_sequences,
max_input_length,
max_total_tokens,
worker_receiver,
)
});
}
@ -95,7 +128,9 @@ async fn validation_task(
/// the tokenizer
fn validation_worker(
tokenizer: Tokenizer,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
mut receiver: mpsc::Receiver<ValidationRequest>,
) {
// Seed rng
@ -106,7 +141,16 @@ fn validation_worker(
parent_span.in_scope(|| {
response_tx
.send(
validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| {
validate(
request,
&tokenizer,
max_stop_sequences,
max_input_length,
max_total_tokens,
&mut rng,
)
.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
err
}),
@ -119,21 +163,39 @@ fn validation_worker(
fn validate(
request: GenerateRequest,
tokenizer: &Tokenizer,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
rng: &mut ThreadRng,
) -> Result<ValidGenerateRequest, ValidationError> {
let GenerateParameters {
best_of,
temperature,
repetition_penalty,
top_k,
top_p,
typical_p,
do_sample,
max_new_tokens,
stop: stop_sequences,
truncate,
seed,
watermark,
..
} = request.parameters;
// sampling must be true when best_of > 1
let best_of = best_of.unwrap_or(1);
let sampling = do_sample
|| temperature.is_some()
|| top_k.is_some()
|| top_p.is_some()
|| typical_p.is_some();
if best_of > 1 && !sampling {
return Err(BestOfSampling);
}
let temperature = temperature.unwrap_or(1.0);
if temperature <= 0.0 {
return Err(ValidationError::Temperature);
@ -144,30 +206,42 @@ fn validate(
return Err(ValidationError::RepetitionPenalty);
}
let top_p = top_p.unwrap_or(1.0);
if top_p <= 0.0 || top_p > 1.0 {
return Err(ValidationError::TopP);
}
// Different because the proto default value is 0 while it is not a valid value
// Different because the proto default value is not a valid value
// for the user
let top_k: u32 = match top_k {
None => Ok(0),
Some(top_k) => {
if top_k <= 0 {
let top_p = top_p
.map(|value| {
if value <= 0.0 || value >= 1.0 {
return Err(ValidationError::TopP);
}
Ok(value)
})
.unwrap_or(Ok(1.0))?;
let typical_p = typical_p
.map(|value| {
if value <= 0.0 || value >= 1.0 {
return Err(ValidationError::TypicalP);
}
Ok(value)
})
.unwrap_or(Ok(1.0))?;
let top_k: u32 = top_k
.map(|value| {
if value <= 0 {
return Err(ValidationError::TopK);
}
Ok(top_k as u32)
}
}?;
Ok(value as u32)
})
.unwrap_or(Ok(0))?;
if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
if max_new_tokens == 0 {
return Err(ValidationError::MaxNewTokens);
}
if stop_sequences.len() > MAX_STOP_SEQUENCES {
if stop_sequences.len() > max_stop_sequences {
return Err(ValidationError::StopSequence(
MAX_STOP_SEQUENCES,
max_stop_sequences,
stop_sequences.len(),
));
}
@ -175,41 +249,82 @@ fn validate(
// If seed is None, assign a random one
let seed = match seed {
None => rng.gen(),
Some(seed) => seed,
Some(seed) => {
if best_of > 1 {
return Err(BestOfSeed);
}
seed
}
};
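Taken together with validate_best_of above, the best_of constraints boil down to a few rules; a hypothetical helper (a sketch, not the router's code) makes them explicit:

fn check_best_of(
    best_of: usize,
    max_best_of: usize,
    sampling: bool,
    seed: Option<u64>,
) -> Result<usize, String> {
    if max_best_of == 1 && best_of != 1 {
        return Err("`best_of` != 1 is not allowed for this endpoint".into());
    }
    if best_of > max_best_of {
        return Err(format!(
            "`best_of` must be > 0 and <= {max_best_of}. Given: {best_of}"
        ));
    }
    if best_of > 1 && !sampling {
        return Err("you must use sampling when `best_of` is > 1".into());
    }
    if best_of > 1 && seed.is_some() {
        return Err("`seed` must not be set when `best_of` > 1".into());
    }
    Ok(best_of)
}

fn main() {
    assert!(check_best_of(2, 2, true, None).is_ok());
    assert!(check_best_of(2, 2, false, None).is_err()); // sampling required
    assert!(check_best_of(2, 2, true, Some(42)).is_err()); // no fixed seed
}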
// Get the number of tokens in the input
match tokenizer.encode(request.inputs.clone(), true) {
Ok(encoding) => {
let input_length = encoding.len();
if input_length > max_input_length {
Err(ValidationError::InputLength(input_length, max_input_length))
} else {
// Return ValidGenerateRequest
let parameters = NextTokenChooserParameters {
temperature,
repetition_penalty,
top_k,
top_p,
do_sample,
seed,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
stop_sequences,
};
Ok(ValidGenerateRequest {
inputs: request.inputs,
input_length: input_length as u32,
parameters,
stopping_parameters,
})
}
}
Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
// Check if inputs is empty
if request.inputs.is_empty() {
return Err(EmptyInput);
}
// Check if truncate is strictly positive and less than max_input_length
let truncate = truncate
.map(|value| {
if value == 0 || value > max_input_length {
return Err(ValidationError::Truncate(max_input_length, value));
}
Ok(Some(value))
})
.unwrap_or(Ok(None))?;
// Get the number of tokens in the input
let mut encoding = tokenizer
.encode(request.inputs.clone(), true)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
let (inputs, input_length) = if let Some(truncate) = truncate {
// truncate encoding and decode new inputs
encoding.truncate(truncate, 0, TruncationDirection::Left);
let inputs = tokenizer
.decode(Vec::from(encoding.get_ids()), false)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
(inputs, encoding.len())
} else {
(request.inputs, encoding.len())
};
if input_length > max_input_length {
return Err(ValidationError::InputLength(max_input_length, input_length));
}
let total_tokens = input_length + max_new_tokens as usize;
if total_tokens > max_total_tokens {
return Err(ValidationError::MaxTotalTokens(
max_total_tokens,
input_length,
max_new_tokens,
));
}
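A sketch of the same truncation and token-budget checks in isolation (it fetches the "gpt2" tokenizer purely for illustration; the limits mirror the router defaults above, max_input_length 1000 and max_total_tokens 1512):

use tokenizers::{Tokenizer, TruncationDirection};

fn main() {
    let tokenizer = Tokenizer::from_pretrained("gpt2", None).expect("could not fetch tokenizer");
    let (max_input_length, max_total_tokens, max_new_tokens) = (1000usize, 1512usize, 20u32);

    let mut encoding = tokenizer
        .encode("a fairly long prompt that we only want the tail of", true)
        .unwrap();
    // `truncate = 8` keeps only the last 8 tokens of the prompt
    encoding.truncate(8, 0, TruncationDirection::Left);
    let input_length = encoding.len();

    // The two budgets the router enforces before queueing the request
    assert!(input_length <= max_input_length);
    assert!(input_length + max_new_tokens as usize <= max_total_tokens);
}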
// Return ValidGenerateRequest
let parameters = NextTokenChooserParameters {
temperature,
repetition_penalty,
top_k,
top_p,
typical_p,
do_sample,
seed,
watermark,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
stop_sequences,
};
metrics::histogram!("tgi_request_input_length", input_length as f64);
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
Ok(ValidGenerateRequest {
inputs,
parameters,
stopping_parameters,
})
}
type ValidationRequest = (
@ -221,26 +336,43 @@ type ValidationRequest = (
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
pub inputs: String,
pub input_length: u32,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
}
#[derive(Error, Debug)]
pub enum ValidationError {
#[error("temperature must be strictly positive")]
#[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
BestOf(usize, usize),
#[error("`best_of` != 1 is not allowed for this endpoint")]
BestOfDisabled,
#[error("you must use sampling when `best_of` is > 1")]
BestOfSampling,
#[error("`seed` must not be set when `best_of` > 1")]
BestOfSeed,
#[error("`best_of` != 1 is not supported when streaming tokens")]
BestOfStream,
#[error("`temperature` must be strictly positive")]
Temperature,
#[error("repetition_penalty must be strictly positive")]
#[error("`repetition_penalty` must be strictly positive")]
RepetitionPenalty,
#[error("top_p must be > 0.0 and <= 1.0")]
#[error("`top_p` must be > 0.0 and < 1.0")]
TopP,
#[error("top_k must be strictly positive")]
#[error("`top_k` must be strictly positive")]
TopK,
#[error("max_new_tokens must be strictly positive and <= {0}")]
MaxNewTokens(u32),
#[error("inputs must have less than {1} tokens. Given: {0}")]
#[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
Truncate(usize, usize),
#[error("`typical_p` must be > 0.0 and < 1.0")]
TypicalP,
#[error("`max_new_tokens` must be strictly positive")]
MaxNewTokens,
#[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
MaxTotalTokens(usize, usize, u32),
#[error("`inputs` must have less than {0} tokens. Given: {1}")]
InputLength(usize, usize),
#[error("stop supports up to {0} stop sequences. Given: {1}")]
#[error("`inputs` cannot be empty")]
EmptyInput,
#[error("`stop` supports up to {0} stop sequences. Given: {1}")]
StopSequence(usize, usize),
#[error("tokenizer error {0}")]
Tokenizer(String),

4
server/.gitignore vendored
View File

@ -1,7 +1,7 @@
# Byte-compiled / optimized / DLL files
__pycache__/
text_generation/__pycache__/
text_generation/pb/__pycache__/
text_generation_server/__pycache__/
text_generation_server/pb/__pycache__/
*.py[cod]
*$py.class

View File

@ -1,20 +1,22 @@
transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
gen-server:
# Compile protos
pip install grpcio-tools==1.51.1 --no-cache-dir
mkdir text_generation/pb || true
python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation/pb/__init__.py
mkdir text_generation_server/pb || true
python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py
install-transformers:
# Install specific version of transformers with custom cuda kernels
pip uninstall transformers -y || true
rm -rf transformers || true
rm -rf transformers-text_generation_inference || true
curl -L -O https://github.com/OlivierDehaene/transformers/archive/refs/heads/text_generation_inference.zip
unzip text_generation_inference.zip
rm text_generation_inference.zip
mv transformers-text_generation_inference transformers
rm -rf transformers-$(transformers_commit) || true
curl -L -O https://github.com/OlivierDehaene/transformers/archive/$(transformers_commit).zip
unzip $(transformers_commit).zip
rm $(transformers_commit).zip
mv transformers-$(transformers_commit) transformers
cd transformers && python setup.py install
install-torch:
@ -26,4 +28,4 @@ install: gen-server install-torch install-transformers
pip install -e . --no-cache-dir
run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

417
server/poetry.lock generated
View File

@ -145,30 +145,30 @@ testing = ["protobuf (>=3.6.0)"]
[[package]]
name = "grpcio"
version = "1.51.1"
version = "1.51.3"
description = "HTTP/2-based RPC framework"
category = "main"
optional = false
python-versions = ">=3.7"
[package.extras]
protobuf = ["grpcio-tools (>=1.51.1)"]
protobuf = ["grpcio-tools (>=1.51.3)"]
[[package]]
name = "grpcio-reflection"
version = "1.51.1"
version = "1.51.3"
description = "Standard Protobuf Reflection Service for gRPC"
category = "main"
optional = false
python-versions = ">=3.6"
[package.dependencies]
grpcio = ">=1.51.1"
grpcio = ">=1.51.3"
protobuf = ">=4.21.6"
[[package]]
name = "grpcio-status"
version = "1.51.1"
version = "1.51.3"
description = "Status proto mapping for gRPC"
category = "main"
optional = false
@ -176,22 +176,30 @@ python-versions = ">=3.6"
[package.dependencies]
googleapis-common-protos = ">=1.5.5"
grpcio = ">=1.51.1"
grpcio = ">=1.51.3"
protobuf = ">=4.21.6"
[[package]]
name = "grpcio-tools"
version = "1.51.1"
version = "1.51.3"
description = "Protobuf code generator for gRPC"
category = "dev"
optional = false
python-versions = ">=3.7"
[package.dependencies]
grpcio = ">=1.51.1"
grpcio = ">=1.51.3"
protobuf = ">=4.21.6,<5.0dev"
setuptools = "*"
[[package]]
name = "hf-transfer"
version = "0.1.2"
description = ""
category = "main"
optional = false
python-versions = ">=3.7"
[[package]]
name = "idna"
version = "3.4"
@ -428,7 +436,7 @@ testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "protobuf"
version = "4.21.12"
version = "4.22.0"
description = ""
category = "main"
optional = false
@ -511,7 +519,7 @@ torch = ["torch"]
[[package]]
name = "setuptools"
version = "67.2.0"
version = "67.4.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
category = "main"
optional = false
@ -567,7 +575,7 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.
[[package]]
name = "typing-extensions"
version = "4.4.0"
version = "4.5.0"
description = "Backported and Experimental Type Hints for Python 3.7+"
category = "main"
optional = false
@ -610,7 +618,7 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
[[package]]
name = "wrapt"
version = "1.14.1"
version = "1.15.0"
description = "Module for decorators, wrappers and monkey patching."
category = "main"
optional = false
@ -622,7 +630,7 @@ bnb = ["bitsandbytes"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "f3cab6881b52045770a90ec9be7415a0ee499d9e980892d544f68073700cf321"
content-hash = "521dc9f3c283dc56f7d2e2f96759919ff27ab49ffd3ae7cd26317b209e7fa98d"
[metadata.files]
accelerate = [
@ -760,106 +768,127 @@ grpc-interceptor = [
{file = "grpc_interceptor-0.15.0-py3-none-any.whl", hash = "sha256:63e390162e64df96c39c40508eb697def76a7cafac32a7eaf9272093eec1109e"},
]
grpcio = [
{file = "grpcio-1.51.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:cc2bece1737b44d878cc1510ea04469a8073dbbcdd762175168937ae4742dfb3"},
{file = "grpcio-1.51.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:e223a9793522680beae44671b9ed8f6d25bbe5ddf8887e66aebad5e0686049ef"},
{file = "grpcio-1.51.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:24ac1154c4b2ab4a0c5326a76161547e70664cd2c39ba75f00fc8a2170964ea2"},
{file = "grpcio-1.51.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4ef09f8997c4be5f3504cefa6b5c6cc3cf648274ce3cede84d4342a35d76db6"},
{file = "grpcio-1.51.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8a0b77e992c64880e6efbe0086fe54dfc0bbd56f72a92d9e48264dcd2a3db98"},
{file = "grpcio-1.51.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:eacad297ea60c72dd280d3353d93fb1dcca952ec11de6bb3c49d12a572ba31dd"},
{file = "grpcio-1.51.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:16c71740640ba3a882f50b01bf58154681d44b51f09a5728180a8fdc66c67bd5"},
{file = "grpcio-1.51.1-cp310-cp310-win32.whl", hash = "sha256:29cb97d41a4ead83b7bcad23bdb25bdd170b1e2cba16db6d3acbb090bc2de43c"},
{file = "grpcio-1.51.1-cp310-cp310-win_amd64.whl", hash = "sha256:9ff42c5620b4e4530609e11afefa4a62ca91fa0abb045a8957e509ef84e54d30"},
{file = "grpcio-1.51.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:bc59f7ba87972ab236f8669d8ca7400f02a0eadf273ca00e02af64d588046f02"},
{file = "grpcio-1.51.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:3c2b3842dcf870912da31a503454a33a697392f60c5e2697c91d133130c2c85d"},
{file = "grpcio-1.51.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22b011674090594f1f3245960ced7386f6af35485a38901f8afee8ad01541dbd"},
{file = "grpcio-1.51.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49d680356a975d9c66a678eb2dde192d5dc427a7994fb977363634e781614f7c"},
{file = "grpcio-1.51.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:094e64236253590d9d4075665c77b329d707b6fca864dd62b144255e199b4f87"},
{file = "grpcio-1.51.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:257478300735ce3c98d65a930bbda3db172bd4e00968ba743e6a1154ea6edf10"},
{file = "grpcio-1.51.1-cp311-cp311-win32.whl", hash = "sha256:5a6ebcdef0ef12005d56d38be30f5156d1cb3373b52e96f147f4a24b0ddb3a9d"},
{file = "grpcio-1.51.1-cp311-cp311-win_amd64.whl", hash = "sha256:3f9b0023c2c92bebd1be72cdfca23004ea748be1813a66d684d49d67d836adde"},
{file = "grpcio-1.51.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:cd3baccea2bc5c38aeb14e5b00167bd4e2373a373a5e4d8d850bd193edad150c"},
{file = "grpcio-1.51.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:17ec9b13cec4a286b9e606b48191e560ca2f3bbdf3986f91e480a95d1582e1a7"},
{file = "grpcio-1.51.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:fbdbe9a849854fe484c00823f45b7baab159bdd4a46075302281998cb8719df5"},
{file = "grpcio-1.51.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:31bb6bc7ff145e2771c9baf612f4b9ebbc9605ccdc5f3ff3d5553de7fc0e0d79"},
{file = "grpcio-1.51.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e473525c28251558337b5c1ad3fa969511e42304524a4e404065e165b084c9e4"},
{file = "grpcio-1.51.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:6f0b89967ee11f2b654c23b27086d88ad7bf08c0b3c2a280362f28c3698b2896"},
{file = "grpcio-1.51.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:7942b32a291421460d6a07883033e392167d30724aa84987e6956cd15f1a21b9"},
{file = "grpcio-1.51.1-cp37-cp37m-win32.whl", hash = "sha256:f96ace1540223f26fbe7c4ebbf8a98e3929a6aa0290c8033d12526847b291c0f"},
{file = "grpcio-1.51.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f1fec3abaf274cdb85bf3878167cfde5ad4a4d97c68421afda95174de85ba813"},
{file = "grpcio-1.51.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:0e1a9e1b4a23808f1132aa35f968cd8e659f60af3ffd6fb00bcf9a65e7db279f"},
{file = "grpcio-1.51.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:6df3b63538c362312bc5fa95fb965069c65c3ea91d7ce78ad9c47cab57226f54"},
{file = "grpcio-1.51.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:172405ca6bdfedd6054c74c62085946e45ad4d9cec9f3c42b4c9a02546c4c7e9"},
{file = "grpcio-1.51.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:506b9b7a4cede87d7219bfb31014d7b471cfc77157da9e820a737ec1ea4b0663"},
{file = "grpcio-1.51.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fb93051331acbb75b49a2a0fd9239c6ba9528f6bdc1dd400ad1cb66cf864292"},
{file = "grpcio-1.51.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5dca372268c6ab6372d37d6b9f9343e7e5b4bc09779f819f9470cd88b2ece3c3"},
{file = "grpcio-1.51.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:471d39d3370ca923a316d49c8aac66356cea708a11e647e3bdc3d0b5de4f0a40"},
{file = "grpcio-1.51.1-cp38-cp38-win32.whl", hash = "sha256:75e29a90dc319f0ad4d87ba6d20083615a00d8276b51512e04ad7452b5c23b04"},
{file = "grpcio-1.51.1-cp38-cp38-win_amd64.whl", hash = "sha256:f1158bccbb919da42544a4d3af5d9296a3358539ffa01018307337365a9a0c64"},
{file = "grpcio-1.51.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:59dffade859f157bcc55243714d57b286da6ae16469bf1ac0614d281b5f49b67"},
{file = "grpcio-1.51.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:dad6533411d033b77f5369eafe87af8583178efd4039c41d7515d3336c53b4f1"},
{file = "grpcio-1.51.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:4c4423ea38a7825b8fed8934d6d9aeebdf646c97e3c608c3b0bcf23616f33877"},
{file = "grpcio-1.51.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0dc5354e38e5adf2498312f7241b14c7ce3484eefa0082db4297189dcbe272e6"},
{file = "grpcio-1.51.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97d67983189e2e45550eac194d6234fc38b8c3b5396c153821f2d906ed46e0ce"},
{file = "grpcio-1.51.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:538d981818e49b6ed1e9c8d5e5adf29f71c4e334e7d459bf47e9b7abb3c30e09"},
{file = "grpcio-1.51.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9235dcd5144a83f9ca6f431bd0eccc46b90e2c22fe27b7f7d77cabb2fb515595"},
{file = "grpcio-1.51.1-cp39-cp39-win32.whl", hash = "sha256:aacb54f7789ede5cbf1d007637f792d3e87f1c9841f57dd51abf89337d1b8472"},
{file = "grpcio-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:2b170eaf51518275c9b6b22ccb59450537c5a8555326fd96ff7391b5dd75303c"},
{file = "grpcio-1.51.1.tar.gz", hash = "sha256:e6dfc2b6567b1c261739b43d9c59d201c1b89e017afd9e684d85aa7a186c9f7a"},
{file = "grpcio-1.51.3-cp310-cp310-linux_armv7l.whl", hash = "sha256:f601aaeae18dab81930fb8d4f916b0da21e89bb4b5f7367ef793f46b4a76b7b0"},
{file = "grpcio-1.51.3-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:eef0450a4b5ed11feab639bf3eb1b6e23d0efa9b911bf7b06fb60e14f5f8a585"},
{file = "grpcio-1.51.3-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:82b0ad8ac825d4bb31bff9f638557c045f4a6d824d84b21e893968286f88246b"},
{file = "grpcio-1.51.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3667c06e37d6cd461afdd51cefe6537702f3d1dc5ff4cac07e88d8b4795dc16f"},
{file = "grpcio-1.51.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3709048fe0aa23dda09b3e69849a12055790171dab9e399a72ea8f9dfbf9ac80"},
{file = "grpcio-1.51.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:200d69857f9910f7458b39b9bcf83ee4a180591b40146ba9e49314e3a7419313"},
{file = "grpcio-1.51.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:cd9a5e68e79c5f031500e67793048a90209711e0854a9ddee8a3ce51728de4e5"},
{file = "grpcio-1.51.3-cp310-cp310-win32.whl", hash = "sha256:6604f614016127ae10969176bbf12eb0e03d2fb3d643f050b3b69e160d144fb4"},
{file = "grpcio-1.51.3-cp310-cp310-win_amd64.whl", hash = "sha256:e95c7ccd4c5807adef1602005513bf7c7d14e5a41daebcf9d8d30d8bf51b8f81"},
{file = "grpcio-1.51.3-cp311-cp311-linux_armv7l.whl", hash = "sha256:5e77ee138100f0bb55cbd147840f87ee6241dbd25f09ea7cd8afe7efff323449"},
{file = "grpcio-1.51.3-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:68a7514b754e38e8de9075f7bb4dee919919515ec68628c43a894027e40ddec4"},
{file = "grpcio-1.51.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c1b9f8afa62ff265d86a4747a2990ec5a96e4efce5d5888f245a682d66eca47"},
{file = "grpcio-1.51.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8de30f0b417744288cec65ec8cf84b8a57995cf7f1e84ccad2704d93f05d0aae"},
{file = "grpcio-1.51.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b69c7adc7ed60da1cb1b502853db61f453fc745f940cbcc25eb97c99965d8f41"},
{file = "grpcio-1.51.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d81528ffe0e973dc840ec73a4132fd18b8203ad129d7410155d951a0a7e4f5d0"},
{file = "grpcio-1.51.3-cp311-cp311-win32.whl", hash = "sha256:040eb421613b57c696063abde405916dd830203c184c9000fc8c3b3b3c950325"},
{file = "grpcio-1.51.3-cp311-cp311-win_amd64.whl", hash = "sha256:2a8e17286c4240137d933b8ca506465472248b4ce0fe46f3404459e708b65b68"},
{file = "grpcio-1.51.3-cp37-cp37m-linux_armv7l.whl", hash = "sha256:d5cd1389669a847555df54177b911d9ff6f17345b2a6f19388707b7a9f724c88"},
{file = "grpcio-1.51.3-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:be1bf35ce82cdbcac14e39d5102d8de4079a1c1a6a06b68e41fcd9ef64f9dd28"},
{file = "grpcio-1.51.3-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:5eed34994c095e2bf7194ffac7381c6068b057ef1e69f8f08db77771350a7566"},
{file = "grpcio-1.51.3-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f9a7d88082b2a17ae7bd3c2354d13bab0453899e0851733f6afa6918373f476"},
{file = "grpcio-1.51.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c8abbc5f837111e7bd619612eedc223c290b0903b952ce0c7b00840ea70f14"},
{file = "grpcio-1.51.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:165b05af77e6aecb4210ae7663e25acf234ba78a7c1c157fa5f2efeb0d6ec53c"},
{file = "grpcio-1.51.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:54e36c2ee304ff15f2bfbdc43d2b56c63331c52d818c364e5b5214e5bc2ad9f6"},
{file = "grpcio-1.51.3-cp37-cp37m-win32.whl", hash = "sha256:cd0daac21d9ef5e033a5100c1d3aa055bbed28bfcf070b12d8058045c4e821b1"},
{file = "grpcio-1.51.3-cp37-cp37m-win_amd64.whl", hash = "sha256:2fdd6333ce96435408565a9dbbd446212cd5d62e4d26f6a3c0feb1e3c35f1cc8"},
{file = "grpcio-1.51.3-cp38-cp38-linux_armv7l.whl", hash = "sha256:54b0c29bdd9a3b1e1b61443ab152f060fc719f1c083127ab08d03fac5efd51be"},
{file = "grpcio-1.51.3-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:ffaaf7e93fcb437356b5a4b23bf36e8a3d0221399ff77fd057e4bc77776a24be"},
{file = "grpcio-1.51.3-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:eafbe7501a3268d05f2e450e1ddaffb950d842a8620c13ec328b501d25d2e2c3"},
{file = "grpcio-1.51.3-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:881ecb34feabf31c6b3b9bbbddd1a5b57e69f805041e5a2c6c562a28574f71c4"},
{file = "grpcio-1.51.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e860a3222139b41d430939bbec2ec9c3f6c740938bf7a04471a9a8caaa965a2e"},
{file = "grpcio-1.51.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:49ede0528e9dac7e8a9fe30b16c73b630ddd9a576bf4b675eb6b0c53ee5ca00f"},
{file = "grpcio-1.51.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6972b009638b40a448d10e1bc18e2223143b8a7aa20d7def0d78dd4af4126d12"},
{file = "grpcio-1.51.3-cp38-cp38-win32.whl", hash = "sha256:5694448256e3cdfe5bd358f1574a3f2f51afa20cc834713c4b9788d60b7cc646"},
{file = "grpcio-1.51.3-cp38-cp38-win_amd64.whl", hash = "sha256:3ea4341efe603b049e8c9a5f13c696ca37fcdf8a23ca35f650428ad3606381d9"},
{file = "grpcio-1.51.3-cp39-cp39-linux_armv7l.whl", hash = "sha256:6c677581ce129f5fa228b8f418cee10bd28dd449f3a544ea73c8ba590ee49d0b"},
{file = "grpcio-1.51.3-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:30e09b5e0531685e176f49679b6a3b190762cc225f4565e55a899f5e14b3aa62"},
{file = "grpcio-1.51.3-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:c831f31336e81243f85b6daff3e5e8a123302ce0ea1f2726ad752fd7a59f3aee"},
{file = "grpcio-1.51.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2cd2e4cefb724cab1ba2df4b7535a9980531b9ec51b4dbb5f137a1f3a3754ef0"},
{file = "grpcio-1.51.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7a0d0bf44438869d307f85a54f25a896ad6b4b0ca12370f76892ad732928d87"},
{file = "grpcio-1.51.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c02abd55409bfb293371554adf6a4401197ec2133dd97727c01180889014ba4d"},
{file = "grpcio-1.51.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2f8ff75e61e1227ba7a3f16b2eadbcc11d0a54096d52ab75a6b88cfbe56f55d1"},
{file = "grpcio-1.51.3-cp39-cp39-win32.whl", hash = "sha256:6c99a73a6260bdf844b2e5ddad02dcd530310f80e1fa72c300fa19c1c7496962"},
{file = "grpcio-1.51.3-cp39-cp39-win_amd64.whl", hash = "sha256:22bdfac4f7f27acdd4da359b5e7e1973dc74bf1ed406729b07d0759fde2f064b"},
{file = "grpcio-1.51.3.tar.gz", hash = "sha256:be7b2265b7527bb12109a7727581e274170766d5b3c9258d4e466f4872522d7a"},
]
grpcio-reflection = [
{file = "grpcio-reflection-1.51.1.tar.gz", hash = "sha256:c07a93c0c36ef88fe475744289863b4787005eff4de0cc04213ecad718b01aae"},
{file = "grpcio_reflection-1.51.1-py3-none-any.whl", hash = "sha256:b70af764a83e42a44f65df1edb232e972ab69e72bc7fbbad481e66c29a9d8cb8"},
{file = "grpcio-reflection-1.51.3.tar.gz", hash = "sha256:5adca16f0a6cd403efa3b5f8f8a493eea6a37dee9473b178fad0a60efa68bc67"},
{file = "grpcio_reflection-1.51.3-py3-none-any.whl", hash = "sha256:52b037f831908468afc89c60e591d0a2bbce24a393d908c44a6d53091e90fc41"},
]
grpcio-status = [
{file = "grpcio-status-1.51.1.tar.gz", hash = "sha256:ac2617a3095935ebd785e2228958f24b10a0d527a0c9eb5a0863c784f648a816"},
{file = "grpcio_status-1.51.1-py3-none-any.whl", hash = "sha256:a52cbdc4b18f325bfc13d319ae7c7ae7a0fee07f3d9a005504d6097896d7a495"},
{file = "grpcio-status-1.51.3.tar.gz", hash = "sha256:71792c550356ba94e162c70818719ae6d67d960bdd03a9db5ff68faba2927f6c"},
{file = "grpcio_status-1.51.3-py3-none-any.whl", hash = "sha256:d68d0956c16b6ea466f13c27075f126ef2cd8f0f97527d70056c64b0084357e3"},
]
grpcio-tools = [
{file = "grpcio-tools-1.51.1.tar.gz", hash = "sha256:8e62d23d3fed9d4f81738f98dd193dbd2e21aed4a8f0dd715e75b5439e649727"},
{file = "grpcio_tools-1.51.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:ecf1494cb695afead36995534f787761ee33fb9e116b23030113a37fe6057a83"},
{file = "grpcio_tools-1.51.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:16b8b915625dc6eb2ea7efdfb06f1fae44a9066c9016453a2ca120c034f33090"},
{file = "grpcio_tools-1.51.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:d5e033c04b416afcddd5231b3ff94a34fb5d26fba2416eb940e69b05f22cfd25"},
{file = "grpcio_tools-1.51.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a218f64e667f3332b74080bdc5440aaf0fa6700ae07a0b54ecf085aaef2aa9f"},
{file = "grpcio_tools-1.51.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7b186183515ad6b8584ffe4bd820b72b00f6e7d121fb1c36294edeea9092313"},
{file = "grpcio_tools-1.51.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ccd37165d7a3e93f460096a2eb62b7a9c1ebe5c424eaee42d8e92740d0c8f6bc"},
{file = "grpcio_tools-1.51.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:531586c5598a99658249f3c5e92826d6d2bb117abd6ffc88527d1e1d9eaef924"},
{file = "grpcio_tools-1.51.1-cp310-cp310-win32.whl", hash = "sha256:392ad4cd004f7b843cf7d916d9a15b2d6585965bfef235be1c88d8f8649777e5"},
{file = "grpcio_tools-1.51.1-cp310-cp310-win_amd64.whl", hash = "sha256:14e82c2b3ee7e300611c2c729d411b3b911e4cca5f4ec14787457a2fb72ff9d4"},
{file = "grpcio_tools-1.51.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:2281180490c475d09b7aa05dabafa5e09de9902176931e7295113f636c2b5360"},
{file = "grpcio_tools-1.51.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:c4649af7f5d9553975ee66b6bfae20a84be779f13e163fa835e782961895e63c"},
{file = "grpcio_tools-1.51.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f06bb0753b7cecbff154b523cfb8f45dee2c31b0a4c72bed7da44c57f1cba113"},
{file = "grpcio_tools-1.51.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a671466158ed74c07ee070fb940ed783acf59ba6e6e53cb4de8fd63819c6c7f"},
{file = "grpcio_tools-1.51.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:048793747339f327ea091d8f022c6756d89713d8080dffde5ce7380cc348ea8e"},
{file = "grpcio_tools-1.51.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f6caf36e7752728329a28f93afec7c4ec9015fc1c6e4460bd1eb0f3737e1c55a"},
{file = "grpcio_tools-1.51.1-cp311-cp311-win32.whl", hash = "sha256:67b304282cad38642587ebae68617e450e1ad4fa1c0c8b19e9e30274dbb32716"},
{file = "grpcio_tools-1.51.1-cp311-cp311-win_amd64.whl", hash = "sha256:674b340f2f7bb2adbc3f15144bd37ce5ea83239f78b68dbbd0ea3cba00107e2b"},
{file = "grpcio_tools-1.51.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:055819992ddd30c642a7fd6f344a03747be3afa95cb910f8a2e5efaabd41cde5"},
{file = "grpcio_tools-1.51.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:4e3249a2ec435b3b972610c66c8a714c188844500d564c910f57a2771dc61978"},
{file = "grpcio_tools-1.51.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:794f26a09b70f4f101df5cf54c6c12dc1b65747ab1dee5bda02c2991389ade56"},
{file = "grpcio_tools-1.51.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4957f1ffa16598aa5379505fcbaeb47d65693a46b0817f4ee61db76707092aeb"},
{file = "grpcio_tools-1.51.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9906fb6bf6d9c30c23d85153f12d130f44325afe8f9ebe58aa7a6c82ecade9d8"},
{file = "grpcio_tools-1.51.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87bc5f3e3698c65907d397003c64d25c3ea84e3d6aa46dac133bd98bf66835ee"},
{file = "grpcio_tools-1.51.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a66b3a5d18a7615f0f828b72e2d2935751459c89cc4725e56bdfb3d2cd93281f"},
{file = "grpcio_tools-1.51.1-cp37-cp37m-win32.whl", hash = "sha256:566809d9942e78821b279af70f3cf159a328127f9f3d5fee8d83ad8b2d27b2fe"},
{file = "grpcio_tools-1.51.1-cp37-cp37m-win_amd64.whl", hash = "sha256:aab24a342642329de38139cb26f8492882ca0d8551bb87f6530bcc613945a0d0"},
{file = "grpcio_tools-1.51.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:6b83d7fc2597c6d392c225177d1fbbcff74900f8cc40b33236987fd1ff841330"},
{file = "grpcio_tools-1.51.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:79c06d2577cb4d977922bbf01234de3b20f73d1784d3cbe3179deee1bdb9a60b"},
{file = "grpcio_tools-1.51.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:e9abc03d67793b1bf33dc766caa69a3333f9db029869ba6e8fc6cd9c251c0080"},
{file = "grpcio_tools-1.51.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64d8ad369417759f5fdb8ffb7cbd6374fecc06ab51c9a226dee9bbd7d311c3b5"},
{file = "grpcio_tools-1.51.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de51a0a71845b854f6a5967756c893c96bd03e37f39e5dce87b4f409dac36ee2"},
{file = "grpcio_tools-1.51.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:9dfe6c12b0e2c07f6a4a91a9912ef4e5bd007672533891a44e6f433ffbf7c3b1"},
{file = "grpcio_tools-1.51.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:27113b354f7587684eb55125733e6e5be1f489458abfe12344dabd918d8dcc54"},
{file = "grpcio_tools-1.51.1-cp38-cp38-win32.whl", hash = "sha256:98777b5031f1b3c58b688815ffa83435c103b2152c26eb144f80f4a4bb34addb"},
{file = "grpcio_tools-1.51.1-cp38-cp38-win_amd64.whl", hash = "sha256:1c44b57a6770b78a1eafe355878ff1ec59a2fa07455a2cbd522c071eedae04d4"},
{file = "grpcio_tools-1.51.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:49624394805568acd7d767dea5a00d970fca5ad8f395fe0161eeea0de5133eba"},
{file = "grpcio_tools-1.51.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:6d6626a6e4dbe843df96dc8c08dd244d2191a75324f54bfa4ebaa3e76b0b1958"},
{file = "grpcio_tools-1.51.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:b4fb8ed6d29f2d6cf03ef99ffaad635bbc132a59be77013691392fe557e67144"},
{file = "grpcio_tools-1.51.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8cc862a1ad30f94528d66cc6f95fb9e659005e568313e54a23550535b649573"},
{file = "grpcio_tools-1.51.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e72a30be1746ea0749a8486d0ca0120c0b2757fe84fc246a5144b1ef66d7b89"},
{file = "grpcio_tools-1.51.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:331a897306adeec3c67470431ea8d8b4972b689d32966f94506d91f4dac20952"},
{file = "grpcio_tools-1.51.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f336ad9be661d92fa45940e74e8ff3d78e67ebe9b4f7ea8774b2d680c17aeb6c"},
{file = "grpcio_tools-1.51.1-cp39-cp39-win32.whl", hash = "sha256:40ef70e8c5d0310dedff9af502b520b4c7e215bce94094527fb959150a0c594a"},
{file = "grpcio_tools-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b8acf4eaa0ebe37e2f69108de49efd935b7abe9c7e58ba737490b99906aa76"},
{file = "grpcio-tools-1.51.3.tar.gz", hash = "sha256:4fea28e3dd31871579a57058796a78093c75b74b74e9de2f2b7a7fd9a443d403"},
{file = "grpcio_tools-1.51.3-cp310-cp310-linux_armv7l.whl", hash = "sha256:779ac1ad2258b8debaa45595bfb3814806ed8880e3ea7f194e551d76a6255969"},
{file = "grpcio_tools-1.51.3-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:83bf605fe2b3591d3c8a78646f37c72c5832c4dd84b5f92405c17cb10b136be6"},
{file = "grpcio_tools-1.51.3-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:35f885c5afd8e6a77d320f5a9624b439a93f9be2b87fa7b7948c1ad7b2ba0894"},
{file = "grpcio_tools-1.51.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:253b639fb79a4d28ce494ae40e5695bf1e2cb4a05f205fc433c46b2049ab4d99"},
{file = "grpcio_tools-1.51.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c6b145587d6062e2335f0b3286501dd6853a1ea50bd466a913351b7c48e5f20"},
{file = "grpcio_tools-1.51.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:046c0b1e372d4acf552aa0c8f5e830f019d67b75f25aeb0968d15fbdd3eaabd3"},
{file = "grpcio_tools-1.51.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:efc90b0287908c46281eb61933acaa1b96a575d0160fc98b5c64b9dec46f60d1"},
{file = "grpcio_tools-1.51.3-cp310-cp310-win32.whl", hash = "sha256:8e9df40db7a0edd403b539cc142d6114270e35debf723a5b4a7a93d5c30fffc0"},
{file = "grpcio_tools-1.51.3-cp310-cp310-win_amd64.whl", hash = "sha256:077adaee431c2b040dd77923964577087c32e828908e8fa2e53f8e003ad408c9"},
{file = "grpcio_tools-1.51.3-cp311-cp311-linux_armv7l.whl", hash = "sha256:b50f9b8a6482a90c1a41e731a879a130f7dea267065d0a06f47c9160ce5d01c3"},
{file = "grpcio_tools-1.51.3-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:89a68adcb4238aba69f3a364ac02c9a46e55b9e3fd8af1c6f384079abfa9347c"},
{file = "grpcio_tools-1.51.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:78d177da43e7f6fde6715df4a3015ae13158166bc2845ac7f9cfb526eafb41b8"},
{file = "grpcio_tools-1.51.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:793f9edef82f600a3324f8a3d8cd8318a8d02f28fb54f8236cbb35ce0928d186"},
{file = "grpcio_tools-1.51.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:f7583735542ced7d30baec6cc21bffeaffcec1523bf807e8f8f0047113b6d30a"},
{file = "grpcio_tools-1.51.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f2df233a3e7db23d9b516cb5e2bfe029465f40a72978bee0584e44e7860ea73f"},
{file = "grpcio_tools-1.51.3-cp311-cp311-win32.whl", hash = "sha256:7427939455735fbf2ea88c37f1585c9c8b809eec7b447642f34465eb4d26020b"},
{file = "grpcio_tools-1.51.3-cp311-cp311-win_amd64.whl", hash = "sha256:ba76d15fd149b575170fa32a1f6a9ff2b38ff9db223229a8ad6f53450a452688"},
{file = "grpcio_tools-1.51.3-cp37-cp37m-linux_armv7l.whl", hash = "sha256:d2212c682529263b3c9e903092d0ccbb9fc6afba820e4c2fa52c2c27720cdcae"},
{file = "grpcio_tools-1.51.3-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:405656b3cf9639427e6c30a795570cba4a7c06b88a3145866f7d2c05b7e048b4"},
{file = "grpcio_tools-1.51.3-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:3c445a064b2ef3d3475e26e2add8ddb4ac2933741ecddf71d5b071a3ad078db4"},
{file = "grpcio_tools-1.51.3-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7b3374f4a6579c58d16a5fab2e6b4e9bb8625a034a7f4cd6024f4d1cc12f2a0"},
{file = "grpcio_tools-1.51.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46e8df08b65f9379c3f103147b29542b0141ca84e77d0eee9114ca5f9b3f0d23"},
{file = "grpcio_tools-1.51.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:2fade12de08923b350475ca16d0d0bd68578c30fce89147aa0f94ef5759bc5a9"},
{file = "grpcio_tools-1.51.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d4ffb6325ed489065dbdca764cf37c3a29376bc657874116c9af788d7a0d2ee4"},
{file = "grpcio_tools-1.51.3-cp37-cp37m-win32.whl", hash = "sha256:f8d17271fc58ed3503dd571c79917e126deca51f85f093770a9606e806aac9dc"},
{file = "grpcio_tools-1.51.3-cp37-cp37m-win_amd64.whl", hash = "sha256:ef849687c7f2bd7f3277edc7c7cafc7042823d0fb078e3c01c861eb0c96ed181"},
{file = "grpcio_tools-1.51.3-cp38-cp38-linux_armv7l.whl", hash = "sha256:7fd18d8d211fbfd337fc12e5bdd57e62368f636addf901d290e68a39f1dfea38"},
{file = "grpcio_tools-1.51.3-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:233fc56f054424232e2086f444004413e33c699174ce6ee0e279c25227243fec"},
{file = "grpcio_tools-1.51.3-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:867fa1973fa8b0772077c15425f122f672a18b1c53709a8a2bff9d056db4c20e"},
{file = "grpcio_tools-1.51.3-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b486a99bdf2722e68a9d59769389e2fb86878b6f293be5111f7678e364a0c359"},
{file = "grpcio_tools-1.51.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8bbf412c357999f88d87f421fd48b4b114fc037fec7bbaed0cb7620c24a5e44"},
{file = "grpcio_tools-1.51.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1166744c40821bb0aa605d2af2287fac367756f858a3d18f4c3d25bc0b92757b"},
{file = "grpcio_tools-1.51.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:781896c488e07b9463196045e6725e52d018cd7d0e1062d4ab1eee2647ca9170"},
{file = "grpcio_tools-1.51.3-cp38-cp38-win32.whl", hash = "sha256:35c1ee7c766eb586f04ba41fa7711eb847767eb277a1737998374ac57768f1f0"},
{file = "grpcio_tools-1.51.3-cp38-cp38-win_amd64.whl", hash = "sha256:584b201fb39307dcb1affcf2647656a0e6244423ef1659cc6caa3ff85c5ae5c1"},
{file = "grpcio_tools-1.51.3-cp39-cp39-linux_armv7l.whl", hash = "sha256:e02231e21029f716a1d23a0b5e664fa243d147da33a3f55088a9529b860aa4ac"},
{file = "grpcio_tools-1.51.3-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:fbb742e10bd548031b8d80f7c28eb70c7c3a9850f8e99c98cd496f19a05f9fee"},
{file = "grpcio_tools-1.51.3-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:a836a72c657f751244cdb358c3461a89627e6d02654079d2450cfe361800428c"},
{file = "grpcio_tools-1.51.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bb554408e0ec5ff5201013f268726d9eef8e5bd1fd4b4e09c46c0b4a9de8b64c"},
{file = "grpcio_tools-1.51.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:158c5bfe7e157fd9a944bde9f7dfe3b468416666e4fade77cd17caa3edc8bd81"},
{file = "grpcio_tools-1.51.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:715c792679224171c0584e9f235b921d76f8990deb38b0d1215d0469301d9cd9"},
{file = "grpcio_tools-1.51.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ece44f42b10e0bceb49235be1e361e1ee69afee7f938c82fb656a601a4a720e3"},
{file = "grpcio_tools-1.51.3-cp39-cp39-win32.whl", hash = "sha256:980e632710ba05e04364c6f276e905d5d367437f1ce2265ce7b96b5c1eac5693"},
{file = "grpcio_tools-1.51.3-cp39-cp39-win_amd64.whl", hash = "sha256:5f4c47b14e66f80365cd5667ecc2f7fb0eb91e02c4e54362041b758feaa00511"},
]
hf-transfer = [
{file = "hf_transfer-0.1.2-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:2b9189a4a460646ee135ee771f39c0f695d3d5bf08b7ff1dcfe374227520e994"},
{file = "hf_transfer-0.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:654fcaba4e7084caa1e97430982ea968935a72916ee0f4afc60e356f89774099"},
{file = "hf_transfer-0.1.2-cp310-none-win_amd64.whl", hash = "sha256:eb29e7b3707b5cac02e689c89111685ebcdaa3cebba02eb7ac1b0f076357da72"},
{file = "hf_transfer-0.1.2-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0bfca9bd84e925e978a0f157df488704c17a0b9ad240b2859262faba0c74cd40"},
{file = "hf_transfer-0.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d00c5473b35227b2f113fd43ff13cbac9539f2e6779fa0680a887b0aac31c389"},
{file = "hf_transfer-0.1.2-cp311-none-win_amd64.whl", hash = "sha256:1aaf5937aa433b7d09ce5bf60967ec22b7d3982957b00516a8dc2aaa66384372"},
{file = "hf_transfer-0.1.2-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:b0aa760a55995ad59ea17e395babafdc56c4e664be0c2d2055664199dd913da1"},
{file = "hf_transfer-0.1.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:889dd15e8472daf66e266eb056e31a485af3c35f95a483bb43489a0f6e44c359"},
{file = "hf_transfer-0.1.2-cp37-none-win_amd64.whl", hash = "sha256:30df586e18ec8a8e67e3201b9038210d94cb3c03c1cbd97673b9c78ede227178"},
{file = "hf_transfer-0.1.2-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:cc97eb97f929f96bed896cd3af9bbdf121c15ac6d63524b9fc9312fd2929099a"},
{file = "hf_transfer-0.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:583c2c80210a60dafed9a81ba50c389878aee6c34b2dd375cd84522658f29ad8"},
{file = "hf_transfer-0.1.2-cp38-none-win_amd64.whl", hash = "sha256:6dff58f50d1435b0346f31a32f1f9e2301986521c1d0b51e47a3c82b96d02156"},
{file = "hf_transfer-0.1.2-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:d6db1a8f539133f7a893bb32721916fe72b4d2aa3eb7604581ba1f03b8167c90"},
{file = "hf_transfer-0.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f284e3f775d215c9a8d3d1c6f6b1001b1e7990d73ae5fd9aea6c9bce9ea79285"},
{file = "hf_transfer-0.1.2-cp39-none-win_amd64.whl", hash = "sha256:8625beabebc582eafc4141a5ecb9f1183b728d4f63767f01fdcf1e2fbafe6d43"},
{file = "hf_transfer-0.1.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:947dd1b8b22ac10723b2887ed4b5ef929f7d4dd850b0d66c0c6954a9a85afb06"},
{file = "hf_transfer-0.1.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90a020f41dfae4629186c284888cd5adbebe402e2497a88351416ab93c7df9a8"},
{file = "hf_transfer-0.1.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5eb89698746a29805bfc60126b9a008e6ba08a82ef9bb122a6544e84f748e8a4"},
{file = "hf_transfer-0.1.2.tar.gz", hash = "sha256:6bf847f4c19c7d8d9f9bbb8a7ed52e1271bbf0c1bd920357db0c274ccc69f21d"},
]
idna = [
{file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
@@ -965,20 +994,19 @@ pluggy = [
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
]
protobuf = [
{file = "protobuf-4.21.12-cp310-abi3-win32.whl", hash = "sha256:b135410244ebe777db80298297a97fbb4c862c881b4403b71bac9d4107d61fd1"},
{file = "protobuf-4.21.12-cp310-abi3-win_amd64.whl", hash = "sha256:89f9149e4a0169cddfc44c74f230d7743002e3aa0b9472d8c28f0388102fc4c2"},
{file = "protobuf-4.21.12-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:299ea899484ee6f44604deb71f424234f654606b983cb496ea2a53e3c63ab791"},
{file = "protobuf-4.21.12-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:d1736130bce8cf131ac7957fa26880ca19227d4ad68b4888b3be0dea1f95df97"},
{file = "protobuf-4.21.12-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:78a28c9fa223998472886c77042e9b9afb6fe4242bd2a2a5aced88e3f4422aa7"},
{file = "protobuf-4.21.12-cp37-cp37m-win32.whl", hash = "sha256:3d164928ff0727d97022957c2b849250ca0e64777ee31efd7d6de2e07c494717"},
{file = "protobuf-4.21.12-cp37-cp37m-win_amd64.whl", hash = "sha256:f45460f9ee70a0ec1b6694c6e4e348ad2019275680bd68a1d9314b8c7e01e574"},
{file = "protobuf-4.21.12-cp38-cp38-win32.whl", hash = "sha256:6ab80df09e3208f742c98443b6166bcb70d65f52cfeb67357d52032ea1ae9bec"},
{file = "protobuf-4.21.12-cp38-cp38-win_amd64.whl", hash = "sha256:1f22ac0ca65bb70a876060d96d914dae09ac98d114294f77584b0d2644fa9c30"},
{file = "protobuf-4.21.12-cp39-cp39-win32.whl", hash = "sha256:27f4d15021da6d2b706ddc3860fac0a5ddaba34ab679dc182b60a8bb4e1121cc"},
{file = "protobuf-4.21.12-cp39-cp39-win_amd64.whl", hash = "sha256:237216c3326d46808a9f7c26fd1bd4b20015fb6867dc5d263a493ef9a539293b"},
{file = "protobuf-4.21.12-py2.py3-none-any.whl", hash = "sha256:a53fd3f03e578553623272dc46ac2f189de23862e68565e83dde203d41b76fc5"},
{file = "protobuf-4.21.12-py3-none-any.whl", hash = "sha256:b98d0148f84e3a3c569e19f52103ca1feacdac0d2df8d6533cf983d1fda28462"},
{file = "protobuf-4.21.12.tar.gz", hash = "sha256:7cd532c4566d0e6feafecc1059d04c7915aec8e182d1cf7adee8b24ef1e2e6ab"},
{file = "protobuf-4.22.0-cp310-abi3-win32.whl", hash = "sha256:b2fea9dc8e3c0f32c38124790ef16cba2ee0628fe2022a52e435e1117bfef9b1"},
{file = "protobuf-4.22.0-cp310-abi3-win_amd64.whl", hash = "sha256:a33a273d21852f911b8bda47f39f4383fe7c061eb1814db2c76c9875c89c2491"},
{file = "protobuf-4.22.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:e894e9ae603e963f0842498c4cd5d39c6a60f0d7e4c103df50ee939564298658"},
{file = "protobuf-4.22.0-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:7c535d126e7dcc714105ab20b418c4fedbd28f8b8afc42b7350b1e317bbbcc71"},
{file = "protobuf-4.22.0-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:86c3d20428b007537ba6792b475c0853bba7f66b1f60e610d913b77d94b486e4"},
{file = "protobuf-4.22.0-cp37-cp37m-win32.whl", hash = "sha256:1669cb7524221a8e2d9008d0842453dbefdd0fcdd64d67672f657244867635fb"},
{file = "protobuf-4.22.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ab4d043865dd04e6b09386981fe8f80b39a1e46139fb4a3c206229d6b9f36ff6"},
{file = "protobuf-4.22.0-cp38-cp38-win32.whl", hash = "sha256:29288813aacaa302afa2381db1d6e0482165737b0afdf2811df5fa99185c457b"},
{file = "protobuf-4.22.0-cp38-cp38-win_amd64.whl", hash = "sha256:e474b63bab0a2ea32a7b26a4d8eec59e33e709321e5e16fb66e766b61b82a95e"},
{file = "protobuf-4.22.0-cp39-cp39-win32.whl", hash = "sha256:47d31bdf58222dd296976aa1646c68c6ee80b96d22e0a3c336c9174e253fd35e"},
{file = "protobuf-4.22.0-cp39-cp39-win_amd64.whl", hash = "sha256:c27f371f0159feb70e6ea52ed7e768b3f3a4c5676c1900a7e51a24740381650e"},
{file = "protobuf-4.22.0-py3-none-any.whl", hash = "sha256:c3325803095fb4c2a48649c321d2fbde59f8fbfcb9bfc7a86df27d112831c571"},
{file = "protobuf-4.22.0.tar.gz", hash = "sha256:652d8dfece122a24d98eebfef30e31e455d300efa41999d1182e015984ac5930"},
]
psutil = [
{file = "psutil-5.9.4-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:c1ca331af862803a42677c120aff8a814a804e09832f166f226bfd22b56feee8"},
@@ -1089,8 +1117,8 @@ safetensors = [
{file = "safetensors-0.2.8.tar.gz", hash = "sha256:2720b20a6a38c799dca79bd76caeeac2f7df585a9d4f7d59fa7e28eff9ccb27f"},
]
setuptools = [
{file = "setuptools-67.2.0-py3-none-any.whl", hash = "sha256:16ccf598aab3b506593c17378473978908a2734d7336755a8769b480906bec1c"},
{file = "setuptools-67.2.0.tar.gz", hash = "sha256:b440ee5f7e607bb8c9de15259dba2583dd41a38879a7abc1d43a71c59524da48"},
{file = "setuptools-67.4.0-py3-none-any.whl", hash = "sha256:f106dee1b506dee5102cc3f3e9e68137bbad6d47b616be7991714b0c62204251"},
{file = "setuptools-67.4.0.tar.gz", hash = "sha256:e5fd0a713141a4a105412233c63dc4e17ba0090c8e8334594ac790ec97792330"},
]
tomli = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
@@ -1124,8 +1152,8 @@ typer = [
{file = "typer-0.6.1.tar.gz", hash = "sha256:2d5720a5e63f73eaf31edaa15f6ab87f35f0690f8ca233017d7d23d743a91d73"},
]
typing-extensions = [
{file = "typing_extensions-4.4.0-py3-none-any.whl", hash = "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e"},
{file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"},
{file = "typing_extensions-4.5.0-py3-none-any.whl", hash = "sha256:fb33085c39dd998ac16d1431ebc293a8b3eedd00fd4a32de0ff79002c19511b4"},
{file = "typing_extensions-4.5.0.tar.gz", hash = "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb"},
]
urllib3 = [
{file = "urllib3-1.26.14-py2.py3-none-any.whl", hash = "sha256:75edcdc2f7d85b137124a6c3c9fc3933cdeaa12ecb9a6a959f22797a0feca7e1"},
@@ -1140,68 +1168,79 @@ win32-setctime = [
{file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
]
wrapt = [
{file = "wrapt-1.14.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3"},
{file = "wrapt-1.14.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef"},
{file = "wrapt-1.14.1-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28"},
{file = "wrapt-1.14.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59"},
{file = "wrapt-1.14.1-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87"},
{file = "wrapt-1.14.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1"},
{file = "wrapt-1.14.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b"},
{file = "wrapt-1.14.1-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462"},
{file = "wrapt-1.14.1-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1"},
{file = "wrapt-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320"},
{file = "wrapt-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2"},
{file = "wrapt-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4"},
{file = "wrapt-1.14.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069"},
{file = "wrapt-1.14.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310"},
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f"},
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656"},
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"},
{file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"},
{file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d"},
{file = "wrapt-1.14.1-cp35-cp35m-win32.whl", hash = "sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7"},
{file = "wrapt-1.14.1-cp35-cp35m-win_amd64.whl", hash = "sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00"},
{file = "wrapt-1.14.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4"},
{file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1"},
{file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1"},
{file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff"},
{file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d"},
{file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1"},
{file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569"},
{file = "wrapt-1.14.1-cp36-cp36m-win32.whl", hash = "sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed"},
{file = "wrapt-1.14.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471"},
{file = "wrapt-1.14.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248"},
{file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68"},
{file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d"},
{file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77"},
{file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7"},
{file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015"},
{file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a"},
{file = "wrapt-1.14.1-cp37-cp37m-win32.whl", hash = "sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853"},
{file = "wrapt-1.14.1-cp37-cp37m-win_amd64.whl", hash = "sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c"},
{file = "wrapt-1.14.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456"},
{file = "wrapt-1.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f"},
{file = "wrapt-1.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc"},
{file = "wrapt-1.14.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1"},
{file = "wrapt-1.14.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af"},
{file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b"},
{file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0"},
{file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57"},
{file = "wrapt-1.14.1-cp38-cp38-win32.whl", hash = "sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5"},
{file = "wrapt-1.14.1-cp38-cp38-win_amd64.whl", hash = "sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d"},
{file = "wrapt-1.14.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383"},
{file = "wrapt-1.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7"},
{file = "wrapt-1.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86"},
{file = "wrapt-1.14.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735"},
{file = "wrapt-1.14.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b"},
{file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3"},
{file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3"},
{file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe"},
{file = "wrapt-1.14.1-cp39-cp39-win32.whl", hash = "sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5"},
{file = "wrapt-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb"},
{file = "wrapt-1.14.1.tar.gz", hash = "sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d"},
{file = "wrapt-1.15.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ca1cccf838cd28d5a0883b342474c630ac48cac5df0ee6eacc9c7290f76b11c1"},
{file = "wrapt-1.15.0-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:e826aadda3cae59295b95343db8f3d965fb31059da7de01ee8d1c40a60398b29"},
{file = "wrapt-1.15.0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:5fc8e02f5984a55d2c653f5fea93531e9836abbd84342c1d1e17abc4a15084c2"},
{file = "wrapt-1.15.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:96e25c8603a155559231c19c0349245eeb4ac0096fe3c1d0be5c47e075bd4f46"},
{file = "wrapt-1.15.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:40737a081d7497efea35ab9304b829b857f21558acfc7b3272f908d33b0d9d4c"},
{file = "wrapt-1.15.0-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:f87ec75864c37c4c6cb908d282e1969e79763e0d9becdfe9fe5473b7bb1e5f09"},
{file = "wrapt-1.15.0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:1286eb30261894e4c70d124d44b7fd07825340869945c79d05bda53a40caa079"},
{file = "wrapt-1.15.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:493d389a2b63c88ad56cdc35d0fa5752daac56ca755805b1b0c530f785767d5e"},
{file = "wrapt-1.15.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:58d7a75d731e8c63614222bcb21dd992b4ab01a399f1f09dd82af17bbfc2368a"},
{file = "wrapt-1.15.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:21f6d9a0d5b3a207cdf7acf8e58d7d13d463e639f0c7e01d82cdb671e6cb7923"},
{file = "wrapt-1.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ce42618f67741d4697684e501ef02f29e758a123aa2d669e2d964ff734ee00ee"},
{file = "wrapt-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41d07d029dd4157ae27beab04d22b8e261eddfc6ecd64ff7000b10dc8b3a5727"},
{file = "wrapt-1.15.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54accd4b8bc202966bafafd16e69da9d5640ff92389d33d28555c5fd4f25ccb7"},
{file = "wrapt-1.15.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fbfbca668dd15b744418265a9607baa970c347eefd0db6a518aaf0cfbd153c0"},
{file = "wrapt-1.15.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:76e9c727a874b4856d11a32fb0b389afc61ce8aaf281ada613713ddeadd1cfec"},
{file = "wrapt-1.15.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e20076a211cd6f9b44a6be58f7eeafa7ab5720eb796975d0c03f05b47d89eb90"},
{file = "wrapt-1.15.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a74d56552ddbde46c246b5b89199cb3fd182f9c346c784e1a93e4dc3f5ec9975"},
{file = "wrapt-1.15.0-cp310-cp310-win32.whl", hash = "sha256:26458da5653aa5b3d8dc8b24192f574a58984c749401f98fff994d41d3f08da1"},
{file = "wrapt-1.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:75760a47c06b5974aa5e01949bf7e66d2af4d08cb8c1d6516af5e39595397f5e"},
{file = "wrapt-1.15.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ba1711cda2d30634a7e452fc79eabcadaffedf241ff206db2ee93dd2c89a60e7"},
{file = "wrapt-1.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:56374914b132c702aa9aa9959c550004b8847148f95e1b824772d453ac204a72"},
{file = "wrapt-1.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a89ce3fd220ff144bd9d54da333ec0de0399b52c9ac3d2ce34b569cf1a5748fb"},
{file = "wrapt-1.15.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bbe623731d03b186b3d6b0d6f51865bf598587c38d6f7b0be2e27414f7f214e"},
{file = "wrapt-1.15.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3abbe948c3cbde2689370a262a8d04e32ec2dd4f27103669a45c6929bcdbfe7c"},
{file = "wrapt-1.15.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b67b819628e3b748fd3c2192c15fb951f549d0f47c0449af0764d7647302fda3"},
{file = "wrapt-1.15.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7eebcdbe3677e58dd4c0e03b4f2cfa346ed4049687d839adad68cc38bb559c92"},
{file = "wrapt-1.15.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:74934ebd71950e3db69960a7da29204f89624dde411afbfb3b4858c1409b1e98"},
{file = "wrapt-1.15.0-cp311-cp311-win32.whl", hash = "sha256:bd84395aab8e4d36263cd1b9308cd504f6cf713b7d6d3ce25ea55670baec5416"},
{file = "wrapt-1.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:a487f72a25904e2b4bbc0817ce7a8de94363bd7e79890510174da9d901c38705"},
{file = "wrapt-1.15.0-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:4ff0d20f2e670800d3ed2b220d40984162089a6e2c9646fdb09b85e6f9a8fc29"},
{file = "wrapt-1.15.0-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:9ed6aa0726b9b60911f4aed8ec5b8dd7bf3491476015819f56473ffaef8959bd"},
{file = "wrapt-1.15.0-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:896689fddba4f23ef7c718279e42f8834041a21342d95e56922e1c10c0cc7afb"},
{file = "wrapt-1.15.0-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:75669d77bb2c071333417617a235324a1618dba66f82a750362eccbe5b61d248"},
{file = "wrapt-1.15.0-cp35-cp35m-win32.whl", hash = "sha256:fbec11614dba0424ca72f4e8ba3c420dba07b4a7c206c8c8e4e73f2e98f4c559"},
{file = "wrapt-1.15.0-cp35-cp35m-win_amd64.whl", hash = "sha256:fd69666217b62fa5d7c6aa88e507493a34dec4fa20c5bd925e4bc12fce586639"},
{file = "wrapt-1.15.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b0724f05c396b0a4c36a3226c31648385deb6a65d8992644c12a4963c70326ba"},
{file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbeccb1aa40ab88cd29e6c7d8585582c99548f55f9b2581dfc5ba68c59a85752"},
{file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:38adf7198f8f154502883242f9fe7333ab05a5b02de7d83aa2d88ea621f13364"},
{file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:578383d740457fa790fdf85e6d346fda1416a40549fe8db08e5e9bd281c6a475"},
{file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:a4cbb9ff5795cd66f0066bdf5947f170f5d63a9274f99bdbca02fd973adcf2a8"},
{file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:af5bd9ccb188f6a5fdda9f1f09d9f4c86cc8a539bd48a0bfdc97723970348418"},
{file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:b56d5519e470d3f2fe4aa7585f0632b060d532d0696c5bdfb5e8319e1d0f69a2"},
{file = "wrapt-1.15.0-cp36-cp36m-win32.whl", hash = "sha256:77d4c1b881076c3ba173484dfa53d3582c1c8ff1f914c6461ab70c8428b796c1"},
{file = "wrapt-1.15.0-cp36-cp36m-win_amd64.whl", hash = "sha256:077ff0d1f9d9e4ce6476c1a924a3332452c1406e59d90a2cf24aeb29eeac9420"},
{file = "wrapt-1.15.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5c5aa28df055697d7c37d2099a7bc09f559d5053c3349b1ad0c39000e611d317"},
{file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a8564f283394634a7a7054b7983e47dbf39c07712d7b177b37e03f2467a024e"},
{file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780c82a41dc493b62fc5884fb1d3a3b81106642c5c5c78d6a0d4cbe96d62ba7e"},
{file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e169e957c33576f47e21864cf3fc9ff47c223a4ebca8960079b8bd36cb014fd0"},
{file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b02f21c1e2074943312d03d243ac4388319f2456576b2c6023041c4d57cd7019"},
{file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f2e69b3ed24544b0d3dbe2c5c0ba5153ce50dcebb576fdc4696d52aa22db6034"},
{file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d787272ed958a05b2c86311d3a4135d3c2aeea4fc655705f074130aa57d71653"},
{file = "wrapt-1.15.0-cp37-cp37m-win32.whl", hash = "sha256:02fce1852f755f44f95af51f69d22e45080102e9d00258053b79367d07af39c0"},
{file = "wrapt-1.15.0-cp37-cp37m-win_amd64.whl", hash = "sha256:abd52a09d03adf9c763d706df707c343293d5d106aea53483e0ec8d9e310ad5e"},
{file = "wrapt-1.15.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cdb4f085756c96a3af04e6eca7f08b1345e94b53af8921b25c72f096e704e145"},
{file = "wrapt-1.15.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:230ae493696a371f1dbffaad3dafbb742a4d27a0afd2b1aecebe52b740167e7f"},
{file = "wrapt-1.15.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63424c681923b9f3bfbc5e3205aafe790904053d42ddcc08542181a30a7a51bd"},
{file = "wrapt-1.15.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6bcbfc99f55655c3d93feb7ef3800bd5bbe963a755687cbf1f490a71fb7794b"},
{file = "wrapt-1.15.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c99f4309f5145b93eca6e35ac1a988f0dc0a7ccf9ccdcd78d3c0adf57224e62f"},
{file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b130fe77361d6771ecf5a219d8e0817d61b236b7d8b37cc045172e574ed219e6"},
{file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:96177eb5645b1c6985f5c11d03fc2dbda9ad24ec0f3a46dcce91445747e15094"},
{file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5fe3e099cf07d0fb5a1e23d399e5d4d1ca3e6dfcbe5c8570ccff3e9208274f7"},
{file = "wrapt-1.15.0-cp38-cp38-win32.whl", hash = "sha256:abd8f36c99512755b8456047b7be10372fca271bf1467a1caa88db991e7c421b"},
{file = "wrapt-1.15.0-cp38-cp38-win_amd64.whl", hash = "sha256:b06fa97478a5f478fb05e1980980a7cdf2712015493b44d0c87606c1513ed5b1"},
{file = "wrapt-1.15.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2e51de54d4fb8fb50d6ee8327f9828306a959ae394d3e01a1ba8b2f937747d86"},
{file = "wrapt-1.15.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0970ddb69bba00670e58955f8019bec4a42d1785db3faa043c33d81de2bf843c"},
{file = "wrapt-1.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76407ab327158c510f44ded207e2f76b657303e17cb7a572ffe2f5a8a48aa04d"},
{file = "wrapt-1.15.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd525e0e52a5ff16653a3fc9e3dd827981917d34996600bbc34c05d048ca35cc"},
{file = "wrapt-1.15.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d37ac69edc5614b90516807de32d08cb8e7b12260a285ee330955604ed9dd29"},
{file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:078e2a1a86544e644a68422f881c48b84fef6d18f8c7a957ffd3f2e0a74a0d4a"},
{file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:2cf56d0e237280baed46f0b5316661da892565ff58309d4d2ed7dba763d984b8"},
{file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7dc0713bf81287a00516ef43137273b23ee414fe41a3c14be10dd95ed98a2df9"},
{file = "wrapt-1.15.0-cp39-cp39-win32.whl", hash = "sha256:46ed616d5fb42f98630ed70c3529541408166c22cdfd4540b88d5f21006b0eff"},
{file = "wrapt-1.15.0-cp39-cp39-win_amd64.whl", hash = "sha256:eef4d64c650f33347c1f9266fa5ae001440b232ad9b98f1f43dfe7a79435c0a6"},
{file = "wrapt-1.15.0-py3-none-any.whl", hash = "sha256:64b1df0f83706b4ef4cfb4fb0e4c2669100fd7ecacfb59e091fad300d4e04640"},
{file = "wrapt-1.15.0.tar.gz", hash = "sha256:d06730c6aed78cee4126234cf2d071e01b44b915e725a6cb439a879ec9754a3a"},
]

View File

@@ -1,11 +1,11 @@
[tool.poetry]
name = "text-generation"
version = "0.2.1"
name = "text-generation-server"
version = "0.4.0"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
[tool.poetry.scripts]
text-generation-server = 'text_generation.cli:app'
text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = "^3.9"
@@ -22,6 +22,7 @@ loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2"
[tool.poetry.extras]
bnb = ["bitsandbytes"]

View File

@@ -1,6 +1,6 @@
import pytest
from text_generation.pb import generate_pb2
from text_generation_server.pb import generate_pb2
@pytest.fixture
@@ -10,6 +10,7 @@ def default_pb_parameters():
repetition_penalty=1.0,
top_k=0,
top_p=1.0,
typical_p=1.0,
do_sample=False,
)

View File

@@ -4,9 +4,9 @@ import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
@pytest.fixture(scope="session")
@@ -24,7 +24,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=1,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@@ -65,8 +64,8 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
assert batch.input_ids[0][-1] == 10264
assert torch.all(batch.input_ids[0][:-1] == 3)
assert batch.attention_mask[0][-1] == 1
assert torch.all(batch.attention_mask[0][:-1] == 0)
assert batch.attention_mask[0][0] == 1
assert torch.all(batch.attention_mask[0][1:] == 0)
assert batch.past_key_values is None
@@ -77,7 +76,7 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_sequence_length == batch.input_lengths[0]
assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_bloom_batch):
@@ -98,22 +97,19 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert not next_batch.keys_head_dim_last
assert len(next_batch.all_input_ids) == next_batch.size
assert (
len(next_batch.all_input_ids[0])
== len(next_batch.attention_mask[0])
== sequence_length + 1
)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
assert len(next_batch.attention_mask[0]) == 11
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
assert torch.all(next_batch.attention_mask[0][-2:] == 1)
assert torch.all(next_batch.attention_mask[0][:-2] == 0)
assert torch.all(next_batch.attention_mask[0][:2] == 1)
assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids[0, 0] == 10264
assert next_batch.input_lengths == [2]
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all(
@@ -213,15 +209,19 @@ def test_batch_concatenate(
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(next_batch.attention_mask[0] == 1)
assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
assert torch.all(
next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
)
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids == 10264)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_sequence_length == 3
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests

View File

@@ -4,8 +4,8 @@ import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLM, CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session")
@@ -25,7 +25,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=1,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@@ -62,8 +61,8 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
assert batch.input_ids[0][-1] == 14402
assert torch.all(batch.input_ids[0][:-1] == 50256)
assert batch.attention_mask[0][-1] == 1
assert torch.all(batch.attention_mask[0][:-1] == 0)
assert batch.attention_mask[0, 0] == 1
assert torch.all(batch.attention_mask[0, 1:] == 0)
assert batch.past_key_values is None
@@ -74,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_sequence_length == batch.input_lengths[0]
assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
@@ -94,23 +93,20 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == next_batch.size
assert (
len(next_batch.all_input_ids[0])
== len(next_batch.attention_mask[0])
== sequence_length + 1
)
assert len(next_batch.all_input_ids[0]) == sequence_length + 1
assert len(next_batch.attention_mask[0]) == 11
assert next_batch.all_input_ids[0][-1] == 13
assert next_batch.all_input_ids[0][-2] == 14402
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
assert torch.all(next_batch.attention_mask[0][-2:] == 1)
assert torch.all(next_batch.attention_mask[0][:-2] == 0)
assert torch.all(next_batch.attention_mask[0][0:2] == 1)
assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids[0, 0] == 13
assert next_batch.input_lengths == [2]
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all(
@@ -210,16 +206,20 @@ def test_batch_concatenate(
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(next_batch.attention_mask[0] == 1)
assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
assert torch.all(
next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
)
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0
assert next_batch.input_ids[0, 0] == 12355
assert torch.all(next_batch.input_ids[1:] == 13)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_sequence_length == 3
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests

View File

@@ -1,8 +1,8 @@
import pytest
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.santacoder import SantaCoder
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.santacoder import SantaCoder
@pytest.fixture(scope="session")
@@ -15,7 +15,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="def",
input_length=1,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@@ -31,7 +30,6 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
input_length=5,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)

View File

@@ -5,8 +5,8 @@ from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
@pytest.fixture(scope="session")
@@ -28,7 +28,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=2,
parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters,
)
@@ -106,7 +105,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
assert len(generations) == len(next_batch)
assert isinstance(next_batch, Seq2SeqLMBatch)
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
assert next_batch.input_ids is None
assert torch.equal(
next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
)
@@ -148,7 +147,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 259 for generation in generations])
assert all([generation.token_text == "" for generation in generations])
assert all([generation.token_text == " " for generation in generations])
assert generations[0].request_id == 0
@@ -220,11 +219,6 @@ def test_batch_concatenate(
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids[:, 0] == 4268)
assert torch.all(next_batch.input_ids[:, 1] == 1)
assert torch.all(next_batch.attention_mask == 1)
assert torch.equal(
next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
)
@@ -233,9 +227,10 @@
next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
)
assert torch.all(next_batch.decoder_attention_mask[0] == 1)
assert torch.all(next_batch.decoder_attention_mask[0, :3] == 1)
assert torch.all(next_batch.decoder_attention_mask[0, 3:] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1)
assert torch.all(next_batch.decoder_attention_mask[1:, 1:3] == 1)
assert torch.equal(
next_batch.encoder_last_hidden_state[0],

View File

@@ -0,0 +1,21 @@
from text_generation_server.utils.hub import (
download_weights,
weight_hub_files,
weight_files,
)
from text_generation_server.utils.convert import convert_files
def test_convert_files():
model_id = "bigscience/bloom-560m"
pt_filenames = weight_hub_files(model_id, extension=".bin")
local_pt_files = download_weights(pt_filenames, model_id)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
]
convert_files(local_pt_files, local_st_files)
found_st_files = weight_files(model_id)
assert all([p in found_st_files for p in local_st_files])

View File

@@ -0,0 +1,40 @@
import pytest
from text_generation_server.utils.hub import (
weight_hub_files,
download_weights,
weight_files,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
with pytest.raises(EntryNotFoundError):
weight_hub_files("bigscience/bloom", extension=".errors")
def test_download_weights():
model_id = "bigscience/bloom-560m"
filenames = weight_hub_files(model_id)
files = download_weights(filenames, model_id)
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@@ -1,14 +1,6 @@
import pytest
from huggingface_hub.utils import RevisionNotFoundError
from text_generation.utils import (
weight_hub_files,
download_weights,
weight_files,
from text_generation_server.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
LocalEntryNotFoundError,
FinishReason,
)
@@ -41,31 +33,3 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
filenames = weight_hub_files("bigscience/bloom", extension=".errors")
assert filenames == []
def test_download_weights():
files = download_weights("bigscience/bloom-560m")
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@@ -1,68 +0,0 @@
import os
import sys
import typer
from pathlib import Path
from loguru import logger
from typing import Optional
from text_generation import server, utils
from text_generation.tracing import setup_tracing
app = typer.Typer()
@app.command()
def serve(
model_id: str,
revision: Optional[str] = None,
sharded: bool = False,
quantize: bool = False,
uds_path: Path = "/tmp/text-generation",
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
):
if sharded:
assert (
os.getenv("RANK", None) is not None
), "RANK must be set when sharded is True"
assert (
os.getenv("WORLD_SIZE", None) is not None
), "WORLD_SIZE must be set when sharded is True"
assert (
os.getenv("MASTER_ADDR", None) is not None
), "MASTER_ADDR must be set when sharded is True"
assert (
os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True"
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
# Setup OpenTelemetry distributed tracing
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
server.serve(model_id, revision, sharded, quantize, uds_path)
@app.command()
def download_weights(
model_id: str,
revision: Optional[str] = None,
extension: str = ".safetensors",
):
utils.download_weights(model_id, revision, extension)
if __name__ == "__main__":
app()

View File

@@ -1,24 +0,0 @@
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase
from text_generation.models.types import Batch, GeneratedText
B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
self.tokenizer = tokenizer
self.device = device
@property
@abstractmethod
def batch_type(self) -> Type[B]:
raise NotImplementedError
@abstractmethod
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError

View File

@@ -1,283 +0,0 @@
import concurrent
import os
import re
import torch
import torch.distributed
from datetime import timedelta
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopPLogitsWarper,
TopKLogitsWarper,
)
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens
class Greedy:
def __call__(self, logits):
return logits.argmax()
class NextTokenChooser:
def __init__(
self,
temperature=1.0,
repetition_penalty=1.0,
top_k=None,
top_p=None,
do_sample=False,
seed=0,
device="cpu",
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
sampling = do_sample
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
warpers.append(TemperatureLogitsWarper(temperature))
sampling = True
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k))
sampling = True
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p))
sampling = True
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
self.warpers = warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
next_id = self.choice(scores[-1])
return next_id.view(1, 1), logprobs
@classmethod
def from_pb(
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
) -> "NextTokenChooser":
return NextTokenChooser(
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
top_k=pb.top_k,
top_p=pb.top_p,
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
)
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
self.regex = re.compile(f".*{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
)
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size
def weight_hub_files(model_id, revision=None, extension=".safetensors"):
"""Get the safetensors filenames on the hub"""
api = HfApi()
info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
return filenames
def try_to_load_from_cache(model_id, revision, filename):
"""Try to load a file from the Hugging Face cache"""
if revision is None:
revision = "main"
object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir():
# No cache for this model
return None
refs_dir = repo_cache / "refs"
snapshots_dir = repo_cache / "snapshots"
no_exist_dir = repo_cache / ".no_exist"
# Resolve refs (for instance to convert main to the associated commit sha)
if refs_dir.is_dir():
revision_file = refs_dir / revision
if revision_file.exists():
with revision_file.open() as f:
revision = f.read()
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return _CACHED_NO_EXIST
# Check if revision folder exists
if not snapshots_dir.exists():
return None
cached_shas = os.listdir(snapshots_dir)
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
# Check if file exists in cache
cached_file = snapshots_dir / revision / filename
return str(cached_file) if cached_file.is_file() else None
def weight_files(model_id, revision=None, extension=".safetensors"):
"""Get the local safetensors filenames"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_id, revision, extension)
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(
model_id, revision=revision, filename=filename
)
if cache_file is None:
raise LocalEntryNotFoundError(
f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `text-generation-server download-weights {model_id}` first."
)
files.append(cache_file)
return files
def download_weights(model_id, revision=None, extension=".safetensors"):
"""Download the safetensors files from the hub"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_id, revision, extension)
download_function = partial(
hf_hub_download,
repo_id=model_id,
local_files_only=False,
)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_function, filename=filename, revision=revision)
for filename in filenames
]
files = [
future.result()
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
]
return files

View File

@ -1,6 +1,6 @@
from typing import Dict, Optional, TypeVar
from text_generation.models.types import Batch
from text_generation_server.models.types import Batch
B = TypeVar("B", bound=Batch)

View File

@ -0,0 +1,115 @@
import os
import sys
import typer
from pathlib import Path
from loguru import logger
from typing import Optional
from text_generation_server import server, utils
from text_generation_server.tracing import setup_tracing
app = typer.Typer()
@app.command()
def serve(
model_id: str,
revision: Optional[str] = None,
sharded: bool = False,
quantize: bool = False,
uds_path: Path = "/tmp/text-generation",
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
):
if sharded:
assert (
os.getenv("RANK", None) is not None
), "RANK must be set when sharded is True"
assert (
os.getenv("WORLD_SIZE", None) is not None
), "WORLD_SIZE must be set when sharded is True"
assert (
os.getenv("MASTER_ADDR", None) is not None
), "MASTER_ADDR must be set when sharded is True"
assert (
os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True"
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
# Setup OpenTelemetry distributed tracing
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
server.serve(model_id, revision, sharded, quantize, uds_path)
@app.command()
def download_weights(
model_id: str,
revision: Optional[str] = None,
extension: str = ".safetensors",
logger_level: str = "INFO",
json_output: bool = False,
):
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
# Test if files were already downloaded
try:
utils.weight_files(model_id, revision, extension)
logger.info(
"Files are already present in the local cache. " "Skipping download."
)
return
# Local files not found
except utils.LocalEntryNotFoundError:
pass
# Download weights directly
try:
filenames = utils.weight_hub_files(model_id, revision, extension)
utils.download_weights(filenames, model_id, revision)
except utils.EntryNotFoundError as e:
if not extension == ".safetensors":
raise e
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights instead."
)
# Try to see if there are pytorch weights
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
# Download pytorch weights
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
for p in local_pt_files
]
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files)
if __name__ == "__main__":
app()
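Assuming this module lives at text_generation_server/cli.py (the import paths above point to the text_generation_server package), the Typer app can be sanity-checked without touching the Hub by invoking its help text; the download_weights function shows up as the download-weights command:

from typer.testing import CliRunner
from text_generation_server.cli import app  # assumed module path

runner = CliRunner()
print(runner.invoke(app, ["--help"]).output)                      # lists `serve` and `download-weights`
print(runner.invoke(app, ["download-weights", "--help"]).output)  # MODEL_ID plus --revision/--extension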

View File

@ -3,14 +3,14 @@ import torch
from transformers import AutoConfig
from typing import Optional
from text_generation.models.model import Model
from text_generation.models.causal_lm import CausalLM
from text_generation.models.bloom import BLOOM, BLOOMSharded
from text_generation.models.seq2seq_lm import Seq2SeqLM
from text_generation.models.galactica import Galactica, GalacticaSharded
from text_generation.models.santacoder import SantaCoder
from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
from text_generation.models.t5 import T5Sharded
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded
__all__ = [
"Model",
@ -19,7 +19,6 @@ __all__ = [
"CausalLM",
"Galactica",
"GalacticaSharded",
"GPTNeox",
"GPTNeoxSharded",
"Seq2SeqLM",
"SantaCoder",
@ -41,6 +40,15 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model:
if "facebook/galactica" in model_id:
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
else:
return Galactica(model_id, revision, quantize=quantize)
if "santacoder" in model_id:
return SantaCoder(model_id, revision, quantize)
config = AutoConfig.from_pretrained(model_id, revision=revision)
if config.model_type == "bloom":
@ -48,27 +56,22 @@ def get_model(
return BLOOMSharded(model_id, revision, quantize=quantize)
else:
return BLOOM(model_id, revision, quantize=quantize)
elif config.model_type == "gpt_neox":
if config.model_type == "gpt_neox":
if sharded:
return GPTNeoxSharded(model_id, revision, quantize=quantize)
else:
return GPTNeox(model_id, revision, quantize=quantize)
elif config.model_type == "t5":
return CausalLM(model_id, revision, quantize=quantize)
if config.model_type == "t5":
if sharded:
return T5Sharded(model_id, revision, quantize=quantize)
else:
return Seq2SeqLM(model_id, revision, quantize=quantize)
elif model_id.startswith("facebook/galactica"):
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
else:
return Galactica(model_id, revision, quantize=quantize)
elif "santacoder" in model_id:
return SantaCoder(model_id, revision, quantize)
else:
if sharded:
raise ValueError("sharded is not supported for AutoModel")
try:
return CausalLM(model_id, revision, quantize=quantize)
except Exception:
return Seq2SeqLM(model_id, revision, quantize=quantize)
if sharded:
raise ValueError("sharded is not supported for AutoModel")
try:
return CausalLM(model_id, revision, quantize=quantize)
except Exception:
return Seq2SeqLM(model_id, revision, quantize=quantize)
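A standalone sketch of the dispatch order implemented above, which only inspects the model id and config.model_type and never loads weights (the helper name here is made up for illustration):

from transformers import AutoConfig

def describe_dispatch(model_id: str, sharded: bool = False) -> str:
    # Mirrors get_model() above: string checks first, then config.model_type.
    if "facebook/galactica" in model_id:
        return "GalacticaSharded" if sharded else "Galactica"
    if "santacoder" in model_id:
        return "SantaCoder"
    model_type = AutoConfig.from_pretrained(model_id).model_type
    if model_type == "bloom":
        return "BLOOMSharded" if sharded else "BLOOM"
    if model_type == "gpt_neox":
        return "GPTNeoxSharded" if sharded else "CausalLM"
    if model_type == "t5":
        return "T5Sharded" if sharded else "Seq2SeqLM"
    return "CausalLM, falling back to Seq2SeqLM (sharded unsupported)"

print(describe_dispatch("bigscience/bloom-560m"))          # BLOOM
print(describe_dispatch("EleutherAI/gpt-neox-20b", True))  # GPTNeoxSharded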

View File

@ -17,13 +17,12 @@ from transformers.models.bloom.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import CausalLM
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.pb import generate_pb2
from text_generation.utils import (
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -59,9 +58,6 @@ class BLOOMSharded(BLOOM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
if not model_id.startswith("bigscience/bloom"):
raise ValueError(f"Model {model_id} is not supported")
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
@ -80,14 +76,8 @@ class BLOOMSharded(BLOOM):
)
config.pad_token_id = 3
# Only download weights for small models
if self.master and model_id == "bigscience/bloom-560m":
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
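The hunk above drops the hard-coded model check and the conditional download; weight resolution is delegated to weight_files, while init_empty_weights keeps the from_config construction allocation-free. A minimal sketch of that meta-device pattern (the model id is illustrative):

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("bigscience/bloom-560m")
with init_empty_weights():
    # Parameters are created on the "meta" device: no real memory is allocated yet.
    model = AutoModelForCausalLM.from_config(config)

assert all(p.device.type == "meta" for p in model.parameters())
# Real tensors are then loaded shard by shard from the .safetensors files,
# as the rest of this class does.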

View File

@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__)
@ -36,7 +41,8 @@ class CausalLMBatch(Batch):
# Metadata used for padding
size: int
max_sequence_length: int
max_input_length: int
padding_right_offset: int
# Past metadata
keys_head_dim_last: bool = True
@ -61,22 +67,36 @@ class CausalLMBatch(Batch):
input_lengths = []
# Parse batch
padding_right_offset = 0
for r in pb.requests:
inputs.append(r.inputs)
input_lengths.append(r.input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
padding=True,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False,
).to(device)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_input_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
@ -84,24 +104,30 @@ class CausalLMBatch(Batch):
return cls(
batch_id=pb.id,
requests=pb.requests,
input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"],
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
input_lengths=input_lengths.tolist(),
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_sequence_length=max(input_lengths),
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
max_sequence_length = max(batch.max_sequence_length for batch in batches)
total_batch_size = 0
max_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += batch.size
max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
@ -144,13 +170,24 @@ class CausalLMBatch(Batch):
# Create padded tensor
if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_sequence_length),
(total_batch_size, max_input_length + padding_right_offset),
)
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
batch_left_offset = (
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
)
attention_mask[
start_index:end_index, -batch.max_sequence_length :
] = batch.attention_mask[:, -batch.max_sequence_length :]
start_index:end_index,
left_offset:-padding_right_offset,
] = batch.attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
@ -172,7 +209,7 @@ class CausalLMBatch(Batch):
padded_past_values_shape = (
total_batch_size,
num_heads,
max_sequence_length - 1,
max_input_length - 1,
head_dim,
)
@ -184,7 +221,7 @@ class CausalLMBatch(Batch):
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
max_input_length - 1,
)
# This will run only once per layer
@ -198,20 +235,20 @@ class CausalLMBatch(Batch):
past_key_values[j][0][
start_index:end_index,
:,
-(batch.max_sequence_length - 1) :,
-(batch.max_input_length - 1) :,
:,
] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
else:
past_key_values[j][0][
start_index:end_index,
:,
:,
-(batch.max_sequence_length - 1) :,
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
-(batch.max_input_length - 1) :,
] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
past_key_values[j][1][
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
start_index:end_index, :, -(batch.max_input_length - 1) :, :
] = past_values[:, :, -(batch.max_input_length - 1) :, :]
start_index += batch.size
@ -227,7 +264,8 @@ class CausalLMBatch(Batch):
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_sequence_length=max_sequence_length,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
)
@ -294,9 +332,12 @@ class CausalLM(Model):
def generate_token(
self, batch: CausalLMBatch
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
logits, past = self.forward(
batch.input_ids,
batch.attention_mask,
attention_mask,
batch.position_ids,
batch.past_key_values,
)
@ -311,7 +352,7 @@ class CausalLM(Model):
# Metadata
next_batch_size = 0
next_batch_max_sequence_length = 0
next_batch_max_input_length = 0
# Results
generations: List[Generation] = []
@ -347,10 +388,8 @@ class CausalLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(
next_token_text = self.decode_token(
next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
# Evaluate stopping criteria
@ -381,8 +420,8 @@ class CausalLM(Model):
next_batch_all_input_ids.append(all_input_ids)
next_batch_size += 1
next_batch_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
next_batch_max_input_length = max(
next_batch_max_input_length, new_input_length
)
# Prefill
@ -409,6 +448,7 @@ class CausalLM(Model):
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
)
@ -448,14 +488,8 @@ class CausalLM(Model):
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
next_batch_attention_mask = torch.cat(
[
next_batch_attention_mask,
next_batch_attention_mask.new_ones(next_batch_size, 1),
],
dim=1,
)
# Update attention_mask as we added a new token to input_ids
next_batch_attention_mask[:, -batch.padding_right_offset] = 1
# Update position_ids
next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
@ -472,7 +506,8 @@ class CausalLM(Model):
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length,
max_input_length=next_batch_max_input_length,
padding_right_offset=batch.padding_right_offset - 1,
keys_head_dim_last=batch.keys_head_dim_last,
)
return generations, next_batch
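The batching changes above replace per-step torch.cat on the attention mask with a mask allocated once at max_input_length + padding_right_offset columns, sliced before each forward pass, and updated by flipping a single position. A self-contained sketch of that bookkeeping with illustrative sizes:

import torch

batch_size, max_input_length, max_new_tokens = 2, 5, 3    # illustrative sizes

# Allocate once: prompt columns plus one column per future generated token.
attention_mask = torch.zeros(batch_size, max_input_length + max_new_tokens, dtype=torch.long)
attention_mask[:, :max_input_length] = 1                  # prompt tokens (left padding ignored here)

padding_right_offset = max_new_tokens
while padding_right_offset > 0:
    step_mask = attention_mask[:, :-padding_right_offset]  # what the model sees this step
    print(step_mask.shape)                                  # grows by one column per step
    # ... the model forward with step_mask would go here ...
    attention_mask[:, -padding_right_offset] = 1            # newly generated token becomes visible
    padding_right_offset -= 1                               # matches padding_right_offset - 1 above

assert int(attention_mask.sum()) == batch_size * (max_input_length + max_new_tokens)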

View File

@ -2,7 +2,7 @@ import re
import torch
import torch.distributed
from typing import List, Optional, Type
from typing import List, Optional, Type, Tuple
from accelerate import init_empty_weights
from safetensors import safe_open
@ -18,15 +18,14 @@ from transformers.models.opt.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import CausalLM
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.utils import (
from text_generation_server.models import CausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -97,24 +96,37 @@ class GalacticaCausalLMBatch(CausalLMBatch):
input_lengths = []
# Parse batch
max_sequence_length = 0
padding_right_offset = 0
for r in pb.requests:
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
input_lengths.append(r.input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_sequence_length = max(max_sequence_length, r.input_length)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
# Tokenize batch
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
padding=True,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False,
).to(device)
input_ids = tokenized_inputs["input_ids"]
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_sequence_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
@ -122,8 +134,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
return cls(
batch_id=pb.id,
requests=pb.requests,
input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"],
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=None,
all_input_ids=all_input_ids,
@ -131,7 +143,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_sequence_length=max(input_lengths),
max_sequence_length=max_sequence_length,
padding_right_offset=padding_right_offset,
)
@ -146,14 +159,25 @@ class Galactica(CausalLM):
generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
class GalacticaSharded(Galactica):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
if not model_id.startswith("facebook/galactica"):
raise ValueError(f"Model {model_id} is not supported")
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
@ -172,14 +196,8 @@ class GalacticaSharded(Galactica):
)
tokenizer.pad_token_id = config.pad_token_id
# Only download weights for small models
if self.master and model_id == "facebook/galactica-125m":
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
@ -329,7 +347,6 @@ class GalacticaSharded(Galactica):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)

View File

@ -1,7 +1,7 @@
import torch
import torch.distributed
from typing import List, Optional, Tuple
from typing import List, Optional
from accelerate import init_empty_weights
from safetensors import safe_open
@ -16,11 +16,10 @@ from transformers.models.gpt_neox.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import CausalLM
from text_generation.utils import (
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -31,23 +30,7 @@ except Exception as e:
HAS_BITS_AND_BYTES = False
class GPTNeox(CausalLM):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
class GPTNeoxSharded(GPTNeox):
class GPTNeoxSharded(CausalLM):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
@ -69,14 +52,8 @@ class GPTNeoxSharded(GPTNeox):
model_id, revision=revision, tp_parallel=True
)
# Only master download weights
if self.master:
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
@ -231,6 +208,7 @@ class GPTNeoxSharded(GPTNeox):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)

View File

@ -0,0 +1,43 @@
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase
from text_generation_server.models.types import Batch, GeneratedText
B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids)
self.device = device
# see `decode_token` method
self.tokenizer.add_special_tokens(
{"additional_special_tokens": ["<decode-token>"]}
)
self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids(
"<decode-token>"
)
self.special_decode_token_length = len("<decode-token>")
@property
@abstractmethod
def batch_type(self) -> Type[B]:
raise NotImplementedError
@abstractmethod
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError
def decode_token(self, token_id: int) -> str:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
# append token to special decode token and decode both
result = self.tokenizer.decode(
[self.special_decode_token_id, token_id], skip_special_tokens=False
)
# slice to remove special decode token
return result[self.special_decode_token_length :]
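The sentinel trick above decodes the new token together with a known special token and slices the sentinel's text off, so tokenizers that mangle a lone token (for example by dropping a leading space) still yield a usable incremental string. A hedged sketch with a stock tokenizer (gpt2 is an illustrative choice; exact surface forms vary by tokenizer):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative
tokenizer.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})
sentinel_id = tokenizer.convert_tokens_to_ids("<decode-token>")
sentinel_len = len("<decode-token>")

token_id = tokenizer(" world")["input_ids"][0]     # a token whose text starts with a space

alone = tokenizer.decode([token_id])               # some tokenizers drop the leading space here
paired = tokenizer.decode([sentinel_id, token_id], skip_special_tokens=False)
print(repr(alone), repr(paired[sentinel_len:]))    # the sliced form keeps the token's surface text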

View File

@ -1,10 +1,10 @@
import torch
import torch.distributed
from typing import Optional, List, Tuple
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation.models import CausalLM
from text_generation_server.models import CausalLM
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"

View File

@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.models import Model
from text_generation_server.models.types import (
GeneratedText,
Batch,
Generation,
PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__)
@ -42,9 +47,10 @@ class Seq2SeqLMBatch(Batch):
size: int
max_input_length: int
max_decoder_input_length: int
padding_right_offset: int
def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
@ -58,36 +64,41 @@ class Seq2SeqLMBatch(Batch):
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "Seq2SeqLMBatch":
"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
"""Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
inputs = []
next_token_choosers = []
stopping_criterias = []
input_lengths = []
decoder_input_ids = []
decoder_input_lengths = []
# Parse batch
padding_right_offset = 0
for r in pb.requests:
inputs.append(r.inputs)
input_lengths.append(r.input_length)
# Decoder sequence only contains the bos_token
decoder_input_ids.append(tokenizer.bos_token_id)
decoder_input_lengths.append(1)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
# Tokenize batch
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
padding=True,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False,
).to(device)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
@ -100,13 +111,14 @@ class Seq2SeqLMBatch(Batch):
decoder_attention_mask=None,
encoder_last_hidden_state=None,
past_key_values=None,
input_lengths=input_lengths,
input_lengths=input_lengths.tolist(),
decoder_input_lengths=decoder_input_lengths,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=len(pb.requests),
max_input_length=max(input_lengths),
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
)
@classmethod
@ -115,11 +127,17 @@ class Seq2SeqLMBatch(Batch):
"""Concatenate multiple batches together by padding internal torch tensors"""
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
max_input_length = max(batch.max_input_length for batch in batches)
max_decoder_input_length = max(
batch.max_decoder_input_length for batch in batches
)
total_batch_size = 0
max_input_length = 0
max_decoder_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += batch.size
max_input_length = max(max_input_length, batch.max_input_length)
max_decoder_input_length = max(
max_decoder_input_length, batch.max_decoder_input_length
)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
@ -129,7 +147,6 @@ class Seq2SeqLMBatch(Batch):
stopping_criterias = []
# Batch tensors
input_ids = None
attention_mask = None
decoder_input_ids = None
decoder_attention_mask = None
@ -155,16 +172,6 @@ class Seq2SeqLMBatch(Batch):
if batch.encoder_last_hidden_state is None:
raise ValueError("Batch encoder_last_hidden_state cannot be None")
# Create padded tensor
if input_ids is None:
input_ids = batch.input_ids.new_zeros(
(total_batch_size, max_input_length),
)
# Copy to correct indices
input_ids[
start_index:end_index, -batch.max_input_length :
] = batch.input_ids[:, -batch.max_input_length :]
# Create padded tensor
if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros(
@ -189,19 +196,30 @@ class Seq2SeqLMBatch(Batch):
if decoder_attention_mask is None:
# As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
decoder_attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_decoder_input_length),
(total_batch_size, max_decoder_input_length + padding_right_offset),
)
# If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
# this batch. All generations are of length `batch.max_decoder_input_length`.
left_offset = max_decoder_input_length - batch.max_decoder_input_length
if batch.decoder_attention_mask is None:
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length :
start_index:end_index,
left_offset:-padding_right_offset,
] = 1
# If it exists, we need to index
else:
batch_left_offset = (
batch.decoder_attention_mask.shape[1]
- batch.max_decoder_input_length
- batch.padding_right_offset
)
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
start_index:end_index,
left_offset:-padding_right_offset,
] = batch.decoder_attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
# Create padded tensor
if encoder_last_hidden_state is None:
@ -273,7 +291,7 @@ class Seq2SeqLMBatch(Batch):
return cls(
batch_id=batches[0].batch_id,
requests=requests,
input_ids=input_ids,
input_ids=None,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
@ -286,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
size=total_batch_size,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
)
def __len__(self):
@ -326,7 +345,9 @@ class Seq2SeqLM(Model):
return Seq2SeqLMBatch
def decode(self, decoder_ids: List[int]) -> str:
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
return self.tokenizer.decode(
decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward(
self,
@ -342,14 +363,6 @@ class Seq2SeqLM(Model):
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]:
# Model Forward
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally...
if encoder_last_hidden_state is not None:
encoder_last_hidden_state = [encoder_last_hidden_state]
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
@ -369,12 +382,34 @@ class Seq2SeqLM(Model):
def generate_token(
self, batch: Seq2SeqLMBatch
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
if batch.decoder_attention_mask is not None:
# slice to the correct shape
decoder_attention_mask = batch.decoder_attention_mask[
:, : -batch.padding_right_offset
]
else:
decoder_attention_mask = None
# check if first forward or not
if batch.past_key_values is not None:
# Only take the last token
decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
else:
decoder_input_ids = batch.decoder_input_ids
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally...
if batch.encoder_last_hidden_state is not None:
encoder_last_hidden_state = [batch.encoder_last_hidden_state]
else:
encoder_last_hidden_state = batch.encoder_last_hidden_state
logits, encoder_last_hidden_state, past = self.forward(
batch.input_ids,
batch.attention_mask,
batch.decoder_input_ids,
batch.decoder_attention_mask,
batch.encoder_last_hidden_state,
decoder_input_ids,
decoder_attention_mask,
encoder_last_hidden_state,
batch.past_key_values,
)
@ -402,7 +437,6 @@ class Seq2SeqLM(Model):
logits,
batch.next_token_choosers,
batch.stopping_criterias,
batch.input_ids,
batch.decoder_input_ids,
)
@ -414,7 +448,6 @@ class Seq2SeqLM(Model):
logits,
next_token_chooser,
stopping_criteria,
input_tokens,
decoder_input_ids,
) in enumerate(iterator):
# Select next token
@ -429,10 +462,8 @@ class Seq2SeqLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode(
next_token_text = self.decode_token(
next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
# Evaluate stopping criteria
@ -469,14 +500,10 @@ class Seq2SeqLM(Model):
# Prefill
if stopping_criteria.current_tokens == 1:
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, [float("nan")], prefill_texts
[self.tokenizer.bos_token_id],
[float("nan")],
[self.tokenizer.bos_token],
)
else:
prefill_tokens = None
@ -487,6 +514,7 @@ class Seq2SeqLM(Model):
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
)
@ -500,10 +528,8 @@ class Seq2SeqLM(Model):
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
# Apply indices to decoder attention mask, past key values and other items that need to be cached
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
if batch.decoder_attention_mask is not None:
next_batch_decoder_attention_mask = batch.decoder_attention_mask[
next_batch_keep_indices
@ -526,7 +552,6 @@ class Seq2SeqLM(Model):
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_input_ids = batch.input_ids
next_batch_attention_mask = batch.attention_mask
next_batch_decoder_attention_mask = batch.decoder_attention_mask
next_batch_encoder_last_hidden_state = encoder_last_hidden_state
@ -536,20 +561,14 @@ class Seq2SeqLM(Model):
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update decoder_attention_mask with padding as we added a new token to input_ids
# Update decoder_attention_mask as we added a new token to input_ids
if next_batch_decoder_attention_mask is not None:
next_batch_decoder_attention_mask = torch.cat(
[
next_batch_decoder_attention_mask,
next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
],
dim=1,
)
next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1
next_batch = Seq2SeqLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=next_batch_input_ids,
input_ids=None,
attention_mask=next_batch_attention_mask,
decoder_input_ids=next_batch_decoder_input_ids,
decoder_attention_mask=next_batch_decoder_attention_mask,
@ -562,5 +581,6 @@ class Seq2SeqLM(Model):
size=next_batch_size,
max_input_length=next_batch_max_input_length,
max_decoder_input_length=next_batch_max_decoder_input_length,
padding_right_offset=batch.padding_right_offset - 1,
)
return generations, next_batch
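generate_token now feeds only the newest decoder token once a KV cache exists and re-wraps encoder_last_hidden_state itself instead of doing it inside forward. A hedged sketch of that incremental decoding loop with a stock encoder-decoder model (t5-small is illustrative and needs sentencepiece installed):

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").eval()

inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
past_key_values, encoder_outputs = None, None

with torch.no_grad():
    for _ in range(5):
        # After the first step the KV cache holds everything but the newest token.
        step_ids = decoder_input_ids if past_key_values is None else decoder_input_ids[:, -1:]
        outputs = model(
            input_ids=inputs["input_ids"] if encoder_outputs is None else None,
            attention_mask=inputs["attention_mask"],
            decoder_input_ids=step_ids,
            encoder_outputs=encoder_outputs,   # a tuple, which Transformers indexes with [0]
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        encoder_outputs = (outputs.encoder_last_hidden_state,)
        next_id = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        decoder_input_ids = torch.cat([decoder_input_ids, next_id], dim=-1)

print(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))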

View File

@ -16,11 +16,10 @@ from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import Seq2SeqLM
from text_generation.utils import (
from text_generation_server.models import Seq2SeqLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -53,14 +52,8 @@ class T5Sharded(Seq2SeqLM):
)
tokenizer.bos_token_id = config.decoder_start_token_id
# Only master download weights
if self.master:
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForSeq2SeqLM.from_config(config)
@ -228,14 +221,6 @@ class T5Sharded(Seq2SeqLM):
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]:
# Model Forward
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally...
if encoder_last_hidden_state is not None:
encoder_last_hidden_state = [encoder_last_hidden_state]
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,

View File

@ -6,8 +6,8 @@ from typing import List, Optional
from transformers import PreTrainedTokenizerBase
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
class Batch(ABC):
@ -73,6 +73,7 @@ class Generation:
token_id: int
token_logprob: float
token_text: str
token_is_special: bool
generated_text: Optional[GeneratedText]
def to_pb(self) -> generate_pb2.Generation:
@ -84,6 +85,7 @@ class Generation:
token_id=self.token_id,
token_logprob=self.token_logprob,
token_text=self.token_text,
token_is_special=self.token_is_special,
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,

View File

@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional
from text_generation.cache import Cache
from text_generation.interceptor import ExceptionInterceptor
from text_generation.models import Model, get_model
from text_generation.pb import generate_pb2_grpc, generate_pb2
from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

View File

@ -0,0 +1,36 @@
from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.hub import (
weight_files,
weight_hub_files,
download_weights,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
from text_generation_server.utils.tokens import (
Greedy,
NextTokenChooser,
Sampling,
StoppingCriteria,
StopSequenceCriteria,
FinishReason,
)
__all__ = [
"convert_file",
"convert_files",
"initialize_torch_distributed",
"weight_files",
"weight_hub_files",
"download_weights",
"EntryNotFoundError",
"LocalEntryNotFoundError",
"RevisionNotFoundError",
"Greedy",
"NextTokenChooser",
"Sampling",
"StoppingCriteria",
"StopSequenceCriteria",
"FinishReason",
]

View File

@ -0,0 +1,94 @@
import concurrent
import time
import torch
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from datetime import timedelta
from loguru import logger
from pathlib import Path
from safetensors.torch import load_file, save_file
from typing import Dict, List
def check_file_size(source_file: Path, target_file: Path):
"""
Check that two files are close in size
"""
source_file_size = source_file.stat().st_size
target_file_size = target_file.stat().st_size
if (source_file_size - target_file_size) / source_file_size > 0.01:
raise RuntimeError(
f"""The file size different is more than 1%:
- {source_file}: {source_file_size}
- {target_file}: {target_file_size}
"""
)
def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
"""
For a Dict of tensors, check if two or more tensors point to the same underlying memory and
keep only the first of them
"""
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
# Iterate over all found memory addresses
for ptr, names in ptrs.items():
if len(names) > 1:
# Multiple tensors point to the same memory
# Only keep the first tensor
for name in names[1:]:
tensors.pop(name)
def convert_file(pt_file: Path, st_file: Path):
"""
Convert a pytorch file to a safetensors file
"""
logger.info(f"Convert {pt_file} to {st_file}.")
pt_state = torch.load(pt_file, map_location="cpu")
if "state_dict" in pt_state:
pt_state = pt_state["state_dict"]
remove_shared_pointers(pt_state)
# Tensors need to be contiguous
pt_state = {k: v.contiguous() for k, v in pt_state.items()}
st_file.parent.mkdir(parents=True, exist_ok=True)
save_file(pt_state, str(st_file), metadata={"format": "pt"})
# Check that both files are close in size
check_file_size(pt_file, st_file)
# Load safetensors state
st_state = load_file(str(st_file))
for k in st_state:
pt_tensor = pt_state[k]
st_tensor = st_state[k]
if not torch.equal(pt_tensor, st_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")
def convert_files(pt_files: List[Path], st_files: List[Path]):
assert len(pt_files) == len(st_files)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
for pt_file, st_file in zip(pt_files, st_files)
]
# We do this instead of using tqdm because we want to parse the logs with the launcher
start_time = time.time()
for i, future in enumerate(concurrent.futures.as_completed(futures)):
elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1)
eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")

View File

@ -0,0 +1,35 @@
import os
import torch
from datetime import timedelta
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size
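For a single process, the helper above can be initialized directly as long as the rendezvous variables torch.distributed expects are set; the values below are illustrative and the call should run only once per process:

import os

os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

from text_generation_server.utils import initialize_torch_distributed

process_group, rank, world_size = initialize_torch_distributed()
print(rank, world_size)   # 0 1 -- NCCL when CUDA is available, gloo otherwise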

View File

@ -0,0 +1,165 @@
import time
import os
from datetime import timedelta
from loguru import logger
from pathlib import Path
from typing import Optional, List
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import (
LocalEntryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError, # Import here to ease try/except in other part of the lib
)
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
def weight_hub_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
"""Get the weights filenames on the hub"""
api = HfApi()
info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
if not filenames:
raise EntryNotFoundError(
f"No {extension} weights found for model {model_id} and revision {revision}.",
None,
)
return filenames
def try_to_load_from_cache(
model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
"""Try to load a file from the Hugging Face cache"""
if revision is None:
revision = "main"
object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir():
# No cache for this model
return None
refs_dir = repo_cache / "refs"
snapshots_dir = repo_cache / "snapshots"
no_exist_dir = repo_cache / ".no_exist"
# Resolve refs (for instance to convert main to the associated commit sha)
if refs_dir.is_dir():
revision_file = refs_dir / revision
if revision_file.exists():
with revision_file.open() as f:
revision = f.read()
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return None
# Check if revision folder exists
if not snapshots_dir.exists():
return None
cached_shas = os.listdir(snapshots_dir)
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
# Check if file exists in cache
cached_file = snapshots_dir / revision / filename
return cached_file if cached_file.is_file() else None
def weight_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[Path]:
"""Get the local files"""
# Local model
if Path(model_id).exists() and Path(model_id).is_dir():
return list(Path(model_id).glob(f"*{extension}"))
try:
filenames = weight_hub_files(model_id, revision, extension)
except EntryNotFoundError as e:
if extension != ".safetensors":
raise e
# Try to see if there are pytorch weights
pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
# Change pytorch extension to safetensors extension
# It is possible that we have safetensors weights locally even though they are not on the
# hub if we converted weights locally without pushing them
filenames = [
f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
]
if WEIGHTS_CACHE_OVERRIDE is not None:
files = []
for filename in filenames:
p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
if not p.exists():
raise LocalEntryNotFoundError(
f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
)
files.append(p)
return files
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(
model_id, revision=revision, filename=filename
)
if cache_file is None:
raise LocalEntryNotFoundError(
f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `text-generation-server download-weights {model_id}` first."
)
files.append(cache_file)
return files
def download_weights(
filenames: List[str], model_id: str, revision: Optional[str] = None
) -> List[Path]:
"""Download the safetensors files from the hub"""
def download_file(filename):
local_file = try_to_load_from_cache(model_id, revision, filename)
if local_file is not None:
logger.info(f"File {filename} already present in cache.")
return Path(local_file)
logger.info(f"Download file: {filename}")
start_time = time.time()
local_file = hf_hub_download(
filename=filename,
repo_id=model_id,
revision=revision,
local_files_only=False,
)
logger.info(
f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
)
return Path(local_file)
# We do this instead of using tqdm because we want to parse the logs with the launcher
start_time = time.time()
files = []
for i, filename in enumerate(filenames):
file = download_file(filename)
elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(filenames) - (i + 1)
eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
files.append(file)
return files
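try_to_load_from_cache walks the standard Hub cache layout: refs/<revision> stores a commit sha and snapshots/<sha>/ stores the files. A self-contained sketch of that resolution on a hand-built fake cache (paths and the sha are illustrative):

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    repo_cache = Path(tmp) / "models--bigscience--bloom-560m"
    sha = "0" * 40
    (repo_cache / "refs").mkdir(parents=True)
    (repo_cache / "refs" / "main").write_text(sha)
    snapshot = repo_cache / "snapshots" / sha
    snapshot.mkdir(parents=True)
    (snapshot / "model.safetensors").write_bytes(b"fake weights")

    # Same resolution order as the helper: resolve the ref, then look in snapshots/<sha>/.
    revision = (repo_cache / "refs" / "main").read_text()
    cached_file = repo_cache / "snapshots" / revision / "model.safetensors"
    print(cached_file.is_file())   # True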

View File

@ -0,0 +1,160 @@
import re
import torch
from transformers import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor


class Sampling:
    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        # pass the dim explicitly to avoid the implicit-dim deprecation warning
        probs = torch.nn.functional.softmax(logits, -1)
        next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
        return next_tokens


class Greedy:
    def __call__(self, logits):
        return logits.argmax()


class NextTokenChooser:
    def __init__(
        self,
        watermark=False,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        sampling = do_sample

        if watermark:
            warpers.append(WatermarkLogitsProcessor(device=device))
        if repetition_penalty is not None and repetition_penalty != 1.0:
            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            warpers.append(TemperatureLogitsWarper(temperature))
            sampling = True
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True
        if typical_p is not None and typical_p < 1.0:
            warpers.append(TypicalLogitsWarper(mass=typical_p))
            sampling = True

        self.warpers = warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        # Warp logits
        if scores.shape[0] > 1:
            # only warp the last token logits
            scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
        else:
            scores = self.warpers(input_ids, scores)

        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)

        # Choose tokens
        next_id = self.choice(scores[-1])

        return next_id.view(1, 1), logprobs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        device: torch.device,
    ) -> "NextTokenChooser":
        return NextTokenChooser(
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
        )


class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        self.regex = re.compile(f".*{stop_sequence}$")

    def __call__(self, output: str) -> bool:
        if self.regex.findall(output):
            return True
        return False


class StoppingCriteria:
    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens=20,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""

    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if last_token == self.eos_token_id:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        return StoppingCriteria(
            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
        )
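A rough, hedged sketch of how these pieces compose during decoding (the module path, vocabulary size, token ids, and decoded text below are made up for illustration): the chooser warps the raw logits and picks the next token id, then the stopping criteria decides whether generation should end and with which finish reason. In the server these objects are built per request from the protobuf parameters via `from_pb`; they are constructed directly here only for brevity.

# Illustrative sketch only: module path, vocab size, and token ids are assumptions.
import torch

from text_generation_server.utils.tokens import (  # path assumed
    NextTokenChooser,
    StoppingCriteria,
    StopSequenceCriteria,
)

chooser = NextTokenChooser(temperature=0.8, top_k=50, do_sample=True, seed=42)
stopping = StoppingCriteria(
    eos_token_id=0,
    stop_sequence_criterias=[StopSequenceCriteria("\n\n")],
    max_new_tokens=8,
)

input_ids = torch.tensor([[1, 2, 3]])  # (batch=1, seq_len)
scores = torch.randn(1, 32_000)        # fake logits over a 32k vocabulary

next_id, logprobs = chooser(input_ids, scores)      # next_id has shape (1, 1)
stop, reason = stopping(next_id.item(), " token")   # fake decoded text
print(next_id.item(), stop, reason)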

View File

@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
from transformers import LogitsProcessor

# os.getenv returns a string when the variable is set, so cast explicitly.
GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))
DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))


class WatermarkLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        gamma: float = GAMMA,
        delta: float = DELTA,
        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
        device: str = "cpu",
    ):
        # watermarking parameters
        self.gamma = gamma
        self.delta = delta
        self.rng = torch.Generator(device=device)
        self.hash_key = hash_key

    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
        assert (
            input_ids.shape[-1] >= 1
        ), "requires at least a 1 token prefix sequence to seed rng"
        prev_token = input_ids[-1].item()
        self.rng.manual_seed(self.hash_key * prev_token)

    def _get_greenlist_ids(
        self, input_ids: torch.LongTensor, max_value: int
    ) -> list[int]:
        # seed the rng using the previous tokens/prefix
        self._seed_rng(input_ids)

        greenlist_size = int(max_value * self.gamma)
        vocab_permutation = torch.randperm(
            max_value, device=input_ids.device, generator=self.rng
        )
        greenlist_ids = vocab_permutation[:greenlist_size]
        return greenlist_ids

    @staticmethod
    def _calc_greenlist_mask(
        scores: torch.FloatTensor, greenlist_token_ids
    ) -> torch.BoolTensor:
        green_tokens_mask = torch.zeros_like(scores)
        green_tokens_mask[-1, greenlist_token_ids] = 1
        final_mask = green_tokens_mask.bool()
        return final_mask

    @staticmethod
    def _bias_greenlist_logits(
        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
    ) -> torch.Tensor:
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        assert len(input_ids) == 1
        greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1])
        green_tokens_mask = self._calc_greenlist_mask(
            scores=scores, greenlist_token_ids=greenlist_ids
        )

        scores = self._bias_greenlist_logits(
            scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
        )
        return scores
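A small self-contained check of the greenlist bias (the vocabulary size and prompt ids are made up): seeding the generator on the previous token makes the set of "green" tokens deterministic, and roughly gamma of the vocabulary gets its logit raised by delta.

# Illustrative sketch only: tiny vocabulary and arbitrary token ids.
import torch

from text_generation_server.utils.watermark import WatermarkLogitsProcessor

processor = WatermarkLogitsProcessor(gamma=0.5, delta=2.0, device="cpu")

input_ids = torch.tensor([[5, 11, 42]])  # (batch=1, seq_len)
scores = torch.zeros(1, 10)              # fake logits over a 10-token vocabulary

biased = processor(input_ids, scores.clone())
green_ids = (biased[0] > 0).nonzero().flatten().tolist()
print(green_ids)  # 5 of the 10 token ids get the +delta bias, the same ones every run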

9
supported_models.json Normal file
View File

@ -0,0 +1,9 @@
[
    "bigscience/bloom",
    "bigscience/bloomz",
    "EleutherAI/gpt-neox-20b",
    "google/flan-ul2",
    "google/flan-t5-xxl",
    "OpenAssistant/oasst-sft-1-pythia-12b",
    "olivierdehaene/optimized-santacoder"
]
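For completeness, a hedged client-side sketch against a server that is already serving one of the models above (the URL is a placeholder and the payload shape follows the router's `/generate` schema; treat the field names as assumptions if your version differs):

# Illustrative sketch only: placeholder URL; payload assumed to match /generate.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 20, "temperature": 0.7, "watermark": True},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])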