Merge branch 'main' into lewtun-patch-1

commit c07acd4fea
OlivierDehaene, 2023-03-23 18:03:33 +01:00, committed by GitHub
76 changed files with 5799 additions and 1414 deletions

.github/workflows/build.yaml

@@ -8,6 +8,15 @@ on:
     tags:
       - 'v*'
   pull_request:
+    paths:
+      - ".github/workflows/build.yaml"
+      - "server/**"
+      - "proto/**"
+      - "router/**"
+      - "launcher/**"
+      - "Cargo.lock"
+      - "rust-toolchain.toml"
+      - "Dockerfile"
     branches:
       - 'main'
@@ -15,6 +24,10 @@ jobs:
   build-and-push-image:
     runs-on: ubuntu-latest
     steps:
+      - name: Initialize Docker Buildx
+        uses: docker/setup-buildx-action@v2.0.0
+        with:
+          install: true
       - name: Tailscale
         uses: tailscale/github-action@v1
         with:
@@ -65,5 +78,5 @@ jobs:
          platforms: 'linux/amd64'
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
-         cache-from: type=registry,ref=ghcr.io/huggingface/text-generation-inference:latest
-         cache-to: type=inline
+         cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max
+         cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max

.github/workflows/tests.yaml

@@ -3,14 +3,23 @@ name: Server Tests
 on:
   pull_request:
     paths:
+      - ".github/workflows/tests.yaml"
       - "server/**"
       - "proto/**"
      - "router/**"
      - "launcher/**"
+      - "Cargo.lock"
+      - "rust-toolchain.toml"
 jobs:
   run_tests:
     runs-on: ubuntu-20.04
+    env:
+      SCCACHE_GHA_ENABLED: "on"
+      RUSTC_WRAPPER: /usr/local/bin/sccache
+      SCCACHE: 0.3.3
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
@@ -25,19 +34,38 @@ jobs:
           components: rustfmt, clippy
       - name: Install Protoc
         uses: arduino/setup-protoc@v1
-      - name: Loading cache.
-        uses: actions/cache@v2
-        id: model_cache
-        with:
-          path: ~/.cache/huggingface/
-          key: models
+      - name: Install sccache
+        run: |
+          curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
+          chmod +x /usr/local/bin/sccache
+      - name: configure sccache
+        uses: actions/github-script@v6
+        with:
+          script: |
+            core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
+            core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
+            core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}');
+            core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-');
+      - name: cargo registry cache
+        uses: actions/cache@v3
+        with:
+          key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }}
+          restore-keys: |
+            cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-
+            cargo-${{ runner.os }}-
+          path: |
+            ~/.cargo/registry
+            ~/.cargo/git
       - name: Install
         run: |
           make install
       - name: Run server tests
         run: |
           pip install pytest
-          pytest -sv server/tests
+          HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
       - name: Run Rust tests
         run: |
           cargo test
+      - name: sccache stats
+        run: |
+          /usr/local/bin/sccache --show-stats

Cargo.lock (generated)

@ -8,6 +8,17 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "ahash"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [
"getrandom",
"once_cell",
"version_check",
]
[[package]] [[package]]
name = "aho-corasick" name = "aho-corasick"
version = "0.7.20" version = "0.7.20"
@ -34,19 +45,20 @@ checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800"
[[package]] [[package]]
name = "async-stream" name = "async-stream"
version = "0.3.3" version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e" checksum = "ad445822218ce64be7a341abfb0b1ea43b5c23aa83902542a4542e78309d8e5e"
dependencies = [ dependencies = [
"async-stream-impl", "async-stream-impl",
"futures-core", "futures-core",
"pin-project-lite",
] ]
[[package]] [[package]]
name = "async-stream-impl" name = "async-stream-impl"
version = "0.3.3" version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27" checksum = "e4655ae1a7b0cdf149156f780c5bf3f1352bc53cbd9e0a361a7ef7b22947e965"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -83,9 +95,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.6.4" version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc" checksum = "6137c6234afb339e75e764c866e3594900f0211e1315d33779f269bbe2ec6967"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
@ -109,7 +121,7 @@ dependencies = [
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tower", "tower",
"tower-http", "tower-http 0.4.0",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
] ]
@ -142,7 +154,7 @@ dependencies = [
"http", "http",
"opentelemetry", "opentelemetry",
"tower", "tower",
"tower-http", "tower-http 0.3.5",
"tracing", "tracing",
"tracing-opentelemetry", "tracing-opentelemetry",
] ]
@ -265,9 +277,9 @@ dependencies = [
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.1.4" version = "4.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f13b9c79b5d1dd500d20ef541215a6423c75829ef43117e1b4d17fd8af0b5d76" checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"clap_derive", "clap_derive",
@ -280,9 +292,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_derive" name = "clap_derive"
version = "4.1.0" version = "4.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "684a277d672e91966334af371f1a7b5833f9aa00b07c84e92fbce95e00208ce8" checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0"
dependencies = [ dependencies = [
"heck", "heck",
"proc-macro-error", "proc-macro-error",
@ -293,9 +305,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_lex" name = "clap_lex"
version = "0.3.1" version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "783fe232adfca04f90f56201b26d79682d4cd2625e0bc7290b95123afe558ade" checksum = "350b9cf31731f9957399229e9b2adc51eeabdfbe9d71d9a0552275fd12710d09"
dependencies = [ dependencies = [
"os_str_bytes", "os_str_bytes",
] ]
@ -349,9 +361,9 @@ dependencies = [
[[package]] [[package]]
name = "crossbeam-channel" name = "crossbeam-channel"
version = "0.5.6" version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2dd04ddaf88237dc3b8d8f9a3c1004b506b54b3313403944054d23c0870c521" checksum = "cf2b3e8478797446514c91ef04bafcb59faba183e621ad488df88983cc14128c"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"crossbeam-utils", "crossbeam-utils",
@ -359,9 +371,9 @@ dependencies = [
[[package]] [[package]]
name = "crossbeam-deque" name = "crossbeam-deque"
version = "0.8.2" version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc" checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"crossbeam-epoch", "crossbeam-epoch",
@ -370,9 +382,9 @@ dependencies = [
[[package]] [[package]]
name = "crossbeam-epoch" name = "crossbeam-epoch"
version = "0.9.13" version = "0.9.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01a9af1f4c2ef74bb8aa1f7e19706bc72d03598c8a570bb5de72243c7a9d9d5a" checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"cfg-if", "cfg-if",
@ -383,9 +395,9 @@ dependencies = [
[[package]] [[package]]
name = "crossbeam-utils" name = "crossbeam-utils"
version = "0.8.14" version = "0.8.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
] ]
@ -575,9 +587,9 @@ dependencies = [
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "1.8.0" version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a407cfaa3385c4ae6b23e84623d48c2798d06e3e6a1878f7f59f17b3f86499" checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be"
dependencies = [ dependencies = [
"instant", "instant",
] ]
@ -774,7 +786,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]] [[package]]
name = "grpc-metadata" name = "grpc-metadata"
version = "0.1.0" version = "0.4.0"
dependencies = [ dependencies = [
"opentelemetry", "opentelemetry",
"tonic", "tonic",
@ -784,9 +796,9 @@ dependencies = [
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.3.15" version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4" checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d"
dependencies = [ dependencies = [
"bytes", "bytes",
"fnv", "fnv",
@ -806,6 +818,9 @@ name = "hashbrown"
version = "0.12.3" version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
dependencies = [
"ahash",
]
[[package]] [[package]]
name = "heck" name = "heck"
@ -839,9 +854,9 @@ checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
[[package]] [[package]]
name = "http" name = "http"
version = "0.2.8" version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399" checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482"
dependencies = [ dependencies = [
"bytes", "bytes",
"fnv", "fnv",
@ -1004,9 +1019,9 @@ checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146"
[[package]] [[package]]
name = "is-terminal" name = "is-terminal"
version = "0.4.3" version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22e18b0a45d56fe973d6db23972bf5bc46f988a4a2385deac9cc29572f09daef" checksum = "21b6b32576413a8e69b90e952e4a026476040d81017b80445deda5f2d3921857"
dependencies = [ dependencies = [
"hermit-abi 0.3.1", "hermit-abi 0.3.1",
"io-lifetimes", "io-lifetimes",
@ -1093,6 +1108,15 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "mach"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "macro_rules_attribute" name = "macro_rules_attribute"
version = "0.1.3" version = "0.1.3"
@ -1132,13 +1156,71 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
[[package]] [[package]]
name = "memoffset" name = "memoffset"
version = "0.7.1" version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1"
dependencies = [ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "metrics"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b9b8653cec6897f73b519a43fba5ee3d50f62fe9af80b428accdcc093b4a849"
dependencies = [
"ahash",
"metrics-macros",
"portable-atomic",
]
[[package]]
name = "metrics-exporter-prometheus"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8603921e1f54ef386189335f288441af761e0fc61bcb552168d9cedfe63ebc70"
dependencies = [
"hyper",
"indexmap",
"ipnet",
"metrics",
"metrics-util",
"parking_lot",
"portable-atomic",
"quanta",
"thiserror",
"tokio",
"tracing",
]
[[package]]
name = "metrics-macros"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "731f8ecebd9f3a4aa847dfe75455e4757a45da40a7793d2f0b1f9b6ed18b23f3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "metrics-util"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7d24dc2dbae22bff6f1f9326ffce828c9f07ef9cc1e8002e5279f845432a30a"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
"hashbrown",
"metrics",
"num_cpus",
"parking_lot",
"portable-atomic",
"quanta",
"sketches-ddsketch",
]
[[package]] [[package]]
name = "mime" name = "mime"
version = "0.3.16" version = "0.3.16"
@ -1172,14 +1254,14 @@ dependencies = [
[[package]] [[package]]
name = "mio" name = "mio"
version = "0.8.5" version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de" checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9"
dependencies = [ dependencies = [
"libc", "libc",
"log", "log",
"wasi 0.11.0+wasi-snapshot-preview1", "wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.42.0", "windows-sys 0.45.0",
] ]
[[package]] [[package]]
@ -1268,9 +1350,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.17.0" version = "1.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
[[package]] [[package]]
name = "onig" name = "onig"
@ -1514,6 +1596,12 @@ version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
[[package]]
name = "portable-atomic"
version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
version = "0.2.17" version = "0.2.17"
@ -1565,9 +1653,9 @@ dependencies = [
[[package]] [[package]]
name = "prost" name = "prost"
version = "0.11.6" version = "0.11.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21dc42e00223fc37204bd4aa177e69420c604ca4a183209a8f9de30c6d934698" checksum = "e48e50df39172a3e7eb17e14642445da64996989bc212b583015435d39a58537"
dependencies = [ dependencies = [
"bytes", "bytes",
"prost-derive", "prost-derive",
@ -1575,9 +1663,9 @@ dependencies = [
[[package]] [[package]]
name = "prost-build" name = "prost-build"
version = "0.11.6" version = "0.11.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3f8ad728fb08fe212df3c05169e940fbb6d9d16a877ddde14644a983ba2012e" checksum = "2c828f93f5ca4826f97fedcbd3f9a536c16b12cff3dbbb4a007f932bbad95b12"
dependencies = [ dependencies = [
"bytes", "bytes",
"heck", "heck",
@ -1597,9 +1685,9 @@ dependencies = [
[[package]] [[package]]
name = "prost-derive" name = "prost-derive"
version = "0.11.6" version = "0.11.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bda8c0881ea9f722eb9629376db3d0b903b462477c1aafcb0566610ac28ac5d" checksum = "4ea9b0f8cbe5e15a8a042d030bd96668db28ecb567ec37d691971ff5731d2b1b"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"itertools 0.10.5", "itertools 0.10.5",
@ -1610,14 +1698,29 @@ dependencies = [
[[package]] [[package]]
name = "prost-types" name = "prost-types"
version = "0.11.6" version = "0.11.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e0526209433e96d83d750dd81a99118edbc55739e7e61a46764fd2ad537788" checksum = "379119666929a1afd7a043aa6cf96fa67a6dce9af60c88095a4686dbce4c9c88"
dependencies = [ dependencies = [
"bytes",
"prost", "prost",
] ]
[[package]]
name = "quanta"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7e31331286705f455e56cca62e0e717158474ff02b7936c1fa596d983f4ae27"
dependencies = [
"crossbeam-utils",
"libc",
"mach",
"once_cell",
"raw-cpuid",
"wasi 0.10.2+wasi-snapshot-preview1",
"web-sys",
"winapi",
]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.23" version = "1.0.23"
@ -1657,6 +1760,15 @@ dependencies = [
"getrandom", "getrandom",
] ]
[[package]]
name = "raw-cpuid"
version = "10.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332"
dependencies = [
"bitflags",
]
[[package]] [[package]]
name = "rayon" name = "rayon"
version = "1.6.1" version = "1.6.1"
@ -1736,15 +1848,6 @@ version = "0.6.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
[[package]]
name = "remove_dir_all"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
dependencies = [
"winapi",
]
[[package]] [[package]]
name = "reqwest" name = "reqwest"
version = "0.11.14" version = "0.11.14"
@ -1973,18 +2076,24 @@ dependencies = [
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.0" version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
dependencies = [ dependencies = [
"libc", "libc",
] ]
[[package]] [[package]]
name = "slab" name = "sketches-ddsketch"
version = "0.4.7" version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" checksum = "ceb945e54128e09c43d8e4f1277851bd5044c6fc540bbaa2ad888f60b3da9ae7"
[[package]]
name = "slab"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d"
dependencies = [ dependencies = [
"autocfg", "autocfg",
] ]
@ -1997,9 +2106,9 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
[[package]] [[package]]
name = "socket2" name = "socket2"
version = "0.4.7" version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" checksum = "95a21dcece9b5991cfd1ece74654c8e3d0d5aab499d359b0395e38229c0bb5a3"
dependencies = [ dependencies = [
"libc", "libc",
"winapi", "winapi",
@ -2053,9 +2162,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "1.0.107" version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -2081,16 +2190,15 @@ dependencies = [
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.3.0" version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" checksum = "af18f7ae1acd354b992402e9ec5864359d693cd8a79dcbef59f76891701c1e95"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"fastrand", "fastrand",
"libc",
"redox_syscall", "redox_syscall",
"remove_dir_all", "rustix",
"winapi", "windows-sys 0.42.0",
] ]
[[package]] [[package]]
@ -2104,7 +2212,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "0.2.1" version = "0.4.0"
dependencies = [ dependencies = [
"futures", "futures",
"grpc-metadata", "grpc-metadata",
@ -2121,9 +2229,9 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "0.2.1" version = "0.4.0"
dependencies = [ dependencies = [
"clap 4.1.4", "clap 4.1.8",
"ctrlc", "ctrlc",
"float_eq", "float_eq",
"reqwest", "reqwest",
@ -2136,18 +2244,21 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "0.2.1" version = "0.4.0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"axum", "axum",
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"clap 4.1.4", "clap 4.1.8",
"futures", "futures",
"metrics",
"metrics-exporter-prometheus",
"nohash-hasher", "nohash-hasher",
"opentelemetry", "opentelemetry",
"opentelemetry-otlp", "opentelemetry-otlp",
"parking_lot", "parking_lot",
"rand", "rand",
"reqwest",
"serde", "serde",
"serde_json", "serde_json",
"text-generation-client", "text-generation-client",
@ -2155,6 +2266,7 @@ dependencies = [
"tokenizers", "tokenizers",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tower-http 0.3.5",
"tracing", "tracing",
"tracing-opentelemetry", "tracing-opentelemetry",
"tracing-subscriber", "tracing-subscriber",
@ -2193,9 +2305,9 @@ dependencies = [
[[package]] [[package]]
name = "thread_local" name = "thread_local"
version = "1.1.6" version = "1.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f297120ff9d4efe680df143d5631bba9c75fa371992b7fcb33eb3453cb0a07" checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"once_cell", "once_cell",
@ -2203,12 +2315,11 @@ dependencies = [
[[package]] [[package]]
name = "time" name = "time"
version = "0.1.45" version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438"
dependencies = [ dependencies = [
"libc", "libc",
"wasi 0.10.0+wasi-snapshot-preview1",
"winapi", "winapi",
] ]
@ -2264,9 +2375,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.25.0" version = "1.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"bytes", "bytes",
@ -2279,7 +2390,7 @@ dependencies = [
"signal-hook-registry", "signal-hook-registry",
"socket2", "socket2",
"tokio-macros", "tokio-macros",
"windows-sys 0.42.0", "windows-sys 0.45.0",
] ]
[[package]] [[package]]
@ -2315,9 +2426,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.11" version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce" checksum = "8fb52b74f05dbf495a8fba459fdc331812b96aa086d9eb78101fa0d4569c3313"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"pin-project-lite", "pin-project-lite",
@ -2326,9 +2437,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.6" version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6a3b08b64e6dfad376fa2432c7b1f01522e37a623c3050bc95db2d3ff21583" checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-core", "futures-core",
@ -2417,12 +2528,30 @@ dependencies = [
"http-body", "http-body",
"http-range-header", "http-range-header",
"pin-project-lite", "pin-project-lite",
"tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing", "tracing",
] ]
[[package]]
name = "tower-http"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d1d42a9b3f3ec46ba828e8d376aec14592ea199f70a06a548587ecd1c4ab658"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]] [[package]]
name = "tower-layer" name = "tower-layer"
version = "0.3.2" version = "0.3.2"
@ -2627,9 +2756,9 @@ dependencies = [
[[package]] [[package]]
name = "utoipa" name = "utoipa"
version = "3.0.1" version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3920fa753064b1be7842bea26175ffa0dfc4a8f30bcb52b8ff03fddf8889914c" checksum = "a15f6da6a2b471134ca44b7d18e8a76d73035cf8b3ed24c4dd5ca6a63aa439c5"
dependencies = [ dependencies = [
"indexmap", "indexmap",
"serde", "serde",
@ -2639,9 +2768,9 @@ dependencies = [
[[package]] [[package]]
name = "utoipa-gen" name = "utoipa-gen"
version = "3.0.1" version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "720298fac6efca20df9e457e67a1eab41a20d1c3101380b5c4dca1ca60ae0062" checksum = "6f2e33027986a4707b3f5c37ed01b33d0e5a53da30204b52ff18f80600f1d0ec"
dependencies = [ dependencies = [
"proc-macro-error", "proc-macro-error",
"proc-macro2", "proc-macro2",
@ -2712,9 +2841,9 @@ dependencies = [
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.10.0+wasi-snapshot-preview1" version = "0.10.2+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
[[package]] [[package]]
name = "wasi" name = "wasi"

Dockerfile

@@ -1,4 +1,15 @@
-FROM rust:1.67 as router-builder
+FROM lukemathwalker/cargo-chef:latest-rust-1.67 AS chef
+WORKDIR /usr/src
+FROM chef as planner
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY router router
+COPY launcher launcher
+RUN cargo chef prepare --recipe-path recipe.json
+FROM chef AS builder
 RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
@@ -6,26 +17,15 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
     rm -f $PROTOC_ZIP
-WORKDIR /usr/src
+COPY --from=planner /usr/src/recipe.json recipe.json
+RUN cargo chef cook --release --recipe-path recipe.json
+COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
 COPY proto proto
 COPY router router
-WORKDIR /usr/src/router
-RUN cargo install --path .
-FROM rust:1.67 as launcher-builder
-WORKDIR /usr/src
-COPY rust-toolchain.toml rust-toolchain.toml
 COPY launcher launcher
+RUN cargo build --release
-WORKDIR /usr/src/launcher
-RUN cargo install --path .
 FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
@@ -33,6 +33,7 @@ ENV LANG=C.UTF-8 \
     LC_ALL=C.UTF-8 \
     DEBIAN_FRONTEND=noninteractive \
     HUGGINGFACE_HUB_CACHE=/data \
+    HF_HUB_ENABLE_HF_TRANSFER=1 \
     MODEL_ID=bigscience/bloom-560m \
     QUANTIZE=false \
     NUM_SHARD=1 \
@@ -68,9 +69,9 @@ RUN cd server && \
     /opt/miniconda/envs/text-generation/bin/pip install ".[bnb]" --no-cache-dir
 # Install router
-COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]

Makefile

@@ -13,25 +13,25 @@ server-dev:
     cd server && make run-dev
 router-dev:
-    cd router && cargo run
+    cd router && cargo run -- --port 8080
 integration-tests: install-router install-launcher
     cargo test
 python-tests:
-    cd server && pytest tests
+    cd server && HF_HUB_ENABLE_HF_TRANSFER=1 pytest tests
 run-bloom-560m:
-    text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
+    text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --port 8080
 run-bloom-560m-quantize:
-    text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
+    text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize --port 8080
 download-bloom:
-    text-generation-server download-weights bigscience/bloom
+    HF_HUB_ENABLE_HF_TRANSFER=1 text-generation-server download-weights bigscience/bloom
 run-bloom:
-    text-generation-launcher --model-id bigscience/bloom --num-shard 8
+    text-generation-launcher --model-id bigscience/bloom --num-shard 8 --port 8080
 run-bloom-quantize:
-    text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
+    text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
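
The download and test targets now set `HF_HUB_ENABLE_HF_TRANSFER=1`, which tells `huggingface_hub` to fetch weights with the Rust-based `hf_transfer` backend. A minimal sketch of the same flag outside the Makefile, assuming the optional `hf_transfer` package is installed and using `bigscience/bloom-560m` purely as an example repository:

```python
import os

# The flag is read when huggingface_hub is imported, so set it first.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import snapshot_download

# Downloads the repository snapshot into the local Hugging Face cache.
snapshot_download("bigscience/bloom-560m")
```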

README.md

@@ -39,27 +39,30 @@ to power LLMs api-inference widgets.
 ## Features
+- Serve the most popular Large Language Models with a simple launcher
+- Tensor Parallelism for faster inference on multiple GPUs
 - Token streaming using Server-Sent Events (SSE)
 - [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
-- 45ms per token generation for BLOOM with 8xA100 80GB
+- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 - Logits warpers (temperature scaling, topk, repetition penalty ...)
 - Stop sequences
 - Log probabilities
-- Distributed tracing with Open Telemetry
+- Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
-## Officially supported models
+## Officially supported architectures
 - [BLOOM](https://huggingface.co/bigscience/bloom)
 - [BLOOMZ](https://huggingface.co/bigscience/bloomz)
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
-- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
+- [Galactica](https://huggingface.co/facebook/galactica-120b)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
 - [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
 - [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
+- [FLAN-UL2](https://huggingface.co/google/flan-ul2)
-Other models are supported on a best effort basis using:
+Other architectures are supported on a best effort basis using:
 `AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
@@ -80,24 +83,42 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
 ```
+**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
 You can then query the model using either the `/generate` or `/generate_stream` routes:
 ```shell
 curl 127.0.0.1:8080/generate \
     -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
     -H 'Content-Type: application/json'
 ```
 ```shell
 curl 127.0.0.1:8080/generate_stream \
     -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17}}' \
     -H 'Content-Type: application/json'
 ```
-**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
+or from Python:
+```shell
+pip install text-generation
+```
+```python
+from text_generation import Client
+client = Client("http://127.0.0.1:8080")
+print(client.generate("What is Deep Learning?", max_new_tokens=17).generated_text)
+text = ""
+for response in client.generate_stream("What is Deep Learning?", max_new_tokens=17):
+    if not response.token.special:
+        text += response.token.text
+print(text)
+```
 ### API documentation
@@ -191,7 +212,7 @@ Be aware that the official Docker image has them enabled by default.
 ### Download
-First you need to download the weights:
+It is advised to download the weights ahead of time with the following command:
 ```shell
 make download-bloom
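
For readers following along without the Python client, the `/generate_stream` route above returns Server-Sent Events, one `data:` line per generated token. A rough sketch of consuming it with plain `requests` (assumes a server running locally on port 8080 as in the Docker example, and ignores error events for brevity):

```python
import json

import requests

payload = {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 17}}

with requests.post(
    "http://127.0.0.1:8080/generate_stream", json=payload, stream=True
) as resp:
    for line in resp.iter_lines():
        # SSE events look like: data:{"token": {...}, "generated_text": ...}
        if not line or not line.startswith(b"data:"):
            continue
        event = json.loads(line[len(b"data:"):])
        print(event["token"]["text"], end="", flush=True)
        if event.get("generated_text") is not None:
            print()  # the final event also carries the full generated_text
```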

clients/python/.gitignore

@ -0,0 +1,158 @@
# Byte-compiled / optimized / DLL files
__pycache__/
text_generation/__pycache__/
text_generation/pb/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
transformers
safetensors

clients/python/Makefile

@ -0,0 +1,6 @@
unit-tests:
python -m pytest --cov=text_generation tests
install:
pip install pip --upgrade
pip install -e .

clients/python/README.md

@ -0,0 +1,196 @@
# Text Generation
The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
`text-generation-inference` instance running on
[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub.
## Get Started
### Install
```shell
pip install text-generation
```
### Inference API Usage
```python
from text_generation import InferenceAPIClient
client = InferenceAPIClient("bigscience/bloomz")
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text
print(text)
# ' Rayleigh scattering'
```
or with the asynchronous client:
```python
from text_generation import InferenceAPIAsyncClient
client = InferenceAPIAsyncClient("bigscience/bloomz")
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text
print(text)
# ' Rayleigh scattering'
```
### Hugging Face Inference Endpoint usage
```python
from text_generation import Client
endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
client = Client(endpoint_url)
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text
print(text)
# ' Rayleigh scattering'
```
or with the asynchronous client:
```python
from text_generation import AsyncClient
endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
client = AsyncClient(endpoint_url)
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text
print(text)
# ' Rayleigh scattering'
```
### Types
```python
# Prompt tokens
class PrefillToken:
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
# Optional since the logprob of the first token cannot be computed
logprob: Optional[float]
# Generated tokens
class Token:
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
logprob: float
# Is the token a special token
# Can be used to ignore tokens when concatenating
special: bool
# Generation finish reason
class FinishReason(Enum):
# number of generated tokens == `max_new_tokens`
Length = "length"
# the model generated its end of sequence token
EndOfSequenceToken = "eos_token"
# the model generated a text included in `stop_sequences`
StopSequence = "stop_sequence"
# Additional sequences when using the `best_of` parameter
class BestOfSequence:
# Generated text
generated_text: str
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Generated tokens
tokens: List[Token]
# `generate` details
class Details:
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Generated tokens
tokens: List[Token]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
# `generate` return value
class Response:
# Generated text
generated_text: str
# Generation details
details: Details
# `generate_stream` details
class StreamDetails:
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# `generate_stream` return value
class StreamResponse:
# Generated token
token: Token
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]
# Generation details
# Only available when the generation is finished
details: Optional[StreamDetails]
```
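
The dataclass-style listing above mirrors what the client returns; a short usage sketch tying the pieces together (the endpoint URL is illustrative and a server must be reachable there):

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")
response = client.generate("Why is the sky blue?", max_new_tokens=5)

print(response.generated_text)
print(response.details.finish_reason)      # e.g. FinishReason.Length
print(response.details.generated_tokens)   # e.g. 5
for token in response.details.tokens:
    print(token.id, repr(token.text), token.logprob, token.special)
```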

clients/python/poetry.lock (generated)

File diff suppressed because it is too large.

clients/python/pyproject.toml

@ -0,0 +1,26 @@
[tool.poetry]
name = "text-generation"
version = "0.3.1"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
maintainers = ["Olivier Dehaene <olivier@huggingface.co>"]
readme = "README.md"
homepage = "https://github.com/huggingface/text-generation-inference"
repository = "https://github.com/huggingface/text-generation-inference"
[tool.poetry.dependencies]
python = "^3.7"
pydantic = "^1.10"
aiohttp = "^3.8"
huggingface-hub = ">= 0.12, < 1.0"
[tool.poetry.dev-dependencies]
pytest = "^6.2.5"
pytest-asyncio = "^0.17.2"
pytest-cov = "^3.0.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

clients/python/tests/conftest.py

@ -0,0 +1,51 @@
import pytest
from text_generation import __version__
from huggingface_hub.utils import build_hf_headers
@pytest.fixture
def flan_t5_xxl():
return "google/flan-t5-xxl"
@pytest.fixture
def fake_model():
return "fake/model"
@pytest.fixture
def unsupported_model():
return "gpt2"
@pytest.fixture
def base_url():
return "https://api-inference.huggingface.co/models"
@pytest.fixture
def bloom_url(base_url, bloom_model):
return f"{base_url}/{bloom_model}"
@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
return f"{base_url}/{flan_t5_xxl}"
@pytest.fixture
def fake_url(base_url, fake_model):
return f"{base_url}/{fake_model}"
@pytest.fixture
def unsupported_url(base_url, unsupported_model):
return f"{base_url}/{unsupported_model}"
@pytest.fixture(scope="session")
def hf_headers():
return build_hf_headers(
library_name="text-generation-tests", library_version=__version__
)

clients/python/tests/test_client.py

@ -0,0 +1,133 @@
import pytest
from text_generation import Client, AsyncClient
from text_generation.errors import NotFoundError, ValidationError
from text_generation.types import FinishReason, PrefillToken, Token
def test_generate(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate("test", max_new_tokens=1)
assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
assert len(response.details.tokens) == 1
assert response.details.tokens[0] == Token(
id=3, text=" ", logprob=-1.984375, special=False
)
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
assert response.details.seed is not None
assert response.details.best_of_sequences is not None
assert len(response.details.best_of_sequences) == 1
assert response.details.best_of_sequences[0].seed is not None
def test_generate_not_found(fake_url, hf_headers):
client = Client(fake_url, hf_headers)
with pytest.raises(NotFoundError):
client.generate("test")
def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
client.generate("test", max_new_tokens=10_000)
def test_generate_stream(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
responses = [
response for response in client.generate_stream("test", max_new_tokens=1)
]
assert len(responses) == 1
response = responses[0]
assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
def test_generate_stream_not_found(fake_url, hf_headers):
client = Client(fake_url, hf_headers)
with pytest.raises(NotFoundError):
list(client.generate_stream("test"))
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
list(client.generate_stream("test", max_new_tokens=10_000))
@pytest.mark.asyncio
async def test_generate_async(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
response = await client.generate("test", max_new_tokens=1)
assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
assert len(response.details.prefill) == 1
assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
assert len(response.details.tokens) == 1
assert response.details.tokens[0] == Token(
id=3, text=" ", logprob=-1.984375, special=False
)
@pytest.mark.asyncio
async def test_generate_async_not_found(fake_url, hf_headers):
client = AsyncClient(fake_url, hf_headers)
with pytest.raises(NotFoundError):
await client.generate("test")
@pytest.mark.asyncio
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
await client.generate("test", max_new_tokens=10_000)
@pytest.mark.asyncio
async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
responses = [
response async for response in client.generate_stream("test", max_new_tokens=1)
]
assert len(responses) == 1
response = responses[0]
assert response.generated_text == ""
assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1
assert response.details.seed is None
@pytest.mark.asyncio
async def test_generate_stream_async_not_found(fake_url, hf_headers):
client = AsyncClient(fake_url, hf_headers)
with pytest.raises(NotFoundError):
async for _ in client.generate_stream("test"):
pass
@pytest.mark.asyncio
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers)
with pytest.raises(ValidationError):
async for _ in client.generate_stream("test", max_new_tokens=10_000):
pass

clients/python/tests/test_errors.py

@ -0,0 +1,64 @@
from text_generation.errors import (
parse_error,
GenerationError,
IncompleteGenerationError,
OverloadedError,
ValidationError,
BadRequestError,
ShardNotReadyError,
ShardTimeoutError,
NotFoundError,
RateLimitExceededError,
UnknownError,
)
def test_generation_error():
payload = {"error_type": "generation", "error": "test"}
assert isinstance(parse_error(400, payload), GenerationError)
def test_incomplete_generation_error():
payload = {"error_type": "incomplete_generation", "error": "test"}
assert isinstance(parse_error(400, payload), IncompleteGenerationError)
def test_overloaded_error():
payload = {"error_type": "overloaded", "error": "test"}
assert isinstance(parse_error(400, payload), OverloadedError)
def test_validation_error():
payload = {"error_type": "validation", "error": "test"}
assert isinstance(parse_error(400, payload), ValidationError)
def test_bad_request_error():
payload = {"error": "test"}
assert isinstance(parse_error(400, payload), BadRequestError)
def test_shard_not_ready_error():
payload = {"error": "test"}
assert isinstance(parse_error(403, payload), ShardNotReadyError)
assert isinstance(parse_error(424, payload), ShardNotReadyError)
def test_shard_timeout_error():
payload = {"error": "test"}
assert isinstance(parse_error(504, payload), ShardTimeoutError)
def test_not_found_error():
payload = {"error": "test"}
assert isinstance(parse_error(404, payload), NotFoundError)
def test_rate_limit_exceeded_error():
payload = {"error": "test"}
assert isinstance(parse_error(429, payload), RateLimitExceededError)
def test_unknown_error():
payload = {"error": "test"}
assert isinstance(parse_error(500, payload), UnknownError)
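
These tests pin down the dispatch order of `parse_error`: the `error_type` field wins when present, otherwise the HTTP status code decides. A sketch of that mapping, written against the error classes the test imports; it is only an illustration of the contract, not the library's actual implementation:

```python
from typing import Any, Dict

from text_generation.errors import (
    BadRequestError,
    GenerationError,
    IncompleteGenerationError,
    NotFoundError,
    OverloadedError,
    RateLimitExceededError,
    ShardNotReadyError,
    ShardTimeoutError,
    UnknownError,
    ValidationError,
)


def parse_error_sketch(status_code: int, payload: Dict[str, Any]) -> Exception:
    # Hypothetical helper name; mirrors the behaviour asserted by the tests above.
    message = payload.get("error", "")
    by_type = {
        "generation": GenerationError,
        "incomplete_generation": IncompleteGenerationError,
        "overloaded": OverloadedError,
        "validation": ValidationError,
    }
    error_type = payload.get("error_type")
    if error_type in by_type:
        return by_type[error_type](message)
    if status_code == 400:
        return BadRequestError(message)
    if status_code in (403, 424):
        return ShardNotReadyError(message)
    if status_code == 404:
        return NotFoundError(message)
    if status_code == 429:
        return RateLimitExceededError(message)
    if status_code == 504:
        return ShardTimeoutError(message)
    return UnknownError(message)
```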

clients/python/tests/test_inference_api.py

@ -0,0 +1,34 @@
import pytest
from text_generation import (
InferenceAPIClient,
InferenceAPIAsyncClient,
Client,
AsyncClient,
)
from text_generation.errors import NotSupportedError
from text_generation.inference_api import get_supported_models
def test_get_supported_models():
assert isinstance(get_supported_models(), list)
def test_client(flan_t5_xxl):
client = InferenceAPIClient(flan_t5_xxl)
assert isinstance(client, Client)
def test_client_unsupported_model(unsupported_model):
with pytest.raises(NotSupportedError):
InferenceAPIClient(unsupported_model)
def test_async_client(flan_t5_xxl):
client = InferenceAPIAsyncClient(flan_t5_xxl)
assert isinstance(client, AsyncClient)
def test_async_client_unsupported_model(unsupported_model):
with pytest.raises(NotSupportedError):
InferenceAPIAsyncClient(unsupported_model)

clients/python/tests/test_types.py

@ -0,0 +1,82 @@
import pytest
from text_generation.types import Parameters, Request
from text_generation.errors import ValidationError
def test_parameters_validation():
# Test best_of
Parameters(best_of=1)
with pytest.raises(ValidationError):
Parameters(best_of=0)
with pytest.raises(ValidationError):
Parameters(best_of=-1)
Parameters(best_of=2, do_sample=True)
with pytest.raises(ValidationError):
Parameters(best_of=2)
# Test repetition_penalty
Parameters(repetition_penalty=1)
with pytest.raises(ValidationError):
Parameters(repetition_penalty=0)
with pytest.raises(ValidationError):
Parameters(repetition_penalty=-1)
# Test seed
Parameters(seed=1)
with pytest.raises(ValidationError):
Parameters(seed=-1)
# Test temperature
Parameters(temperature=1)
with pytest.raises(ValidationError):
Parameters(temperature=0)
with pytest.raises(ValidationError):
Parameters(temperature=-1)
# Test top_k
Parameters(top_k=1)
with pytest.raises(ValidationError):
Parameters(top_k=0)
with pytest.raises(ValidationError):
Parameters(top_k=-1)
# Test top_p
Parameters(top_p=0.5)
with pytest.raises(ValidationError):
Parameters(top_p=0)
with pytest.raises(ValidationError):
Parameters(top_p=-1)
with pytest.raises(ValidationError):
Parameters(top_p=1)
# Test truncate
Parameters(truncate=1)
with pytest.raises(ValidationError):
Parameters(truncate=0)
with pytest.raises(ValidationError):
Parameters(truncate=-1)
# Test typical_p
Parameters(typical_p=0.5)
with pytest.raises(ValidationError):
Parameters(typical_p=0)
with pytest.raises(ValidationError):
Parameters(typical_p=-1)
with pytest.raises(ValidationError):
Parameters(typical_p=1)
def test_request_validation():
Request(inputs="test")
with pytest.raises(ValidationError):
Request(inputs="")
Request(inputs="test", stream=True)
Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))
with pytest.raises(ValidationError):
Request(
inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
)

View File

@ -0,0 +1,18 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.3.0"
from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient

View File

@ -0,0 +1,487 @@
import json
import requests
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator
from text_generation.types import (
StreamResponse,
Response,
Request,
Parameters,
)
from text_generation.errors import parse_error
class Client:
"""Client to make calls to a text-generation-inference instance
Example:
```python
>>> from text_generation import Client
>>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
>>> client.generate("Why is the sky blue?").generated_text
' Rayleigh scattering'
>>> result = ""
>>> for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(
self,
base_url: str,
headers: Optional[Dict[str, str]] = None,
cookies: Optional[Dict[str, str]] = None,
timeout: int = 10,
):
"""
Args:
base_url (`str`):
text-generation-inference instance base url
headers (`Optional[Dict[str, str]]`):
Additional headers
cookies (`Optional[Dict[str, str]]`):
Cookies to include in the requests
timeout (`int`):
Timeout in seconds
"""
self.base_url = base_url
self.headers = headers
self.cookies = cookies
self.timeout = timeout
def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
) -> Response:
"""
Given a prompt, generate the following text
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
best_of (`int`):
Generate best_of sequences and return the one with the highest token logprobs
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
truncate (`int`):
Truncate input tokens to the given size
typical_p (`float`):
Typical Decoding mass
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Response: generated response
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
resp = requests.post(
self.base_url,
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return Response(**payload[0])
def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
best_of (`int`):
Generate best_of sequences and return the one with the highest token logprobs
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
truncate (`int`):
Truncate input tokens to the given size
typical_p (`float`):
Typical Decoding mass
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Iterator[StreamResponse]: stream of generated tokens
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
resp = requests.post(
self.base_url,
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
stream=True,
)
if resp.status_code != 200:
raise parse_error(resp.status_code, resp.json())
# Parse ServerSentEvents
for byte_payload in resp.iter_lines():
# Skip line
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
# Event data
if payload.startswith("data:"):
# Decode payload
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
# Parse payload
try:
response = StreamResponse(**json_payload)
except ValidationError:
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status_code, json_payload)
yield response
class AsyncClient:
"""Asynchronous Client to make calls to a text-generation-inference instance
Example:
```python
>>> from text_generation import AsyncClient
>>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
>>> response = await client.generate("Why is the sky blue?")
>>> response.generated_text
' Rayleigh scattering'
>>> result = ""
>>> async for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(
self,
base_url: str,
headers: Optional[Dict[str, str]] = None,
cookies: Optional[Dict[str, str]] = None,
timeout: int = 10,
):
"""
Args:
base_url (`str`):
text-generation-inference instance base url
headers (`Optional[Dict[str, str]]`):
Additional headers
cookies (`Optional[Dict[str, str]]`):
Cookies to include in the requests
timeout (`int`):
Timeout in seconds
"""
self.base_url = base_url
self.headers = headers
self.cookies = cookies
self.timeout = ClientTimeout(timeout * 60)
async def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
best_of (`int`):
Generate best_of sequences and return the one with the highest token logprobs
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
truncate (`int`):
Truncate input tokens to the given size
typical_p (`float`):
Typical Decoding mass
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
Response: generated response
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return Response(**payload[0])
async def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
Args:
prompt (`str`):
Input text
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
best_of (`int`):
Generate best_of sequences and return the one with the highest token logprobs
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
return_full_text (`bool`):
Whether to prepend the prompt to the generated text
seed (`int`):
Random sampling seed
stop_sequences (`List[str]`):
Stop generating tokens if a member of `stop_sequences` is generated
temperature (`float`):
The value used to modulate the logits distribution.
top_k (`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation.
truncate (`int`):
Truncate input tokens to the given size
typical_p (`float`):
Typical Decoding mass
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop_sequences if stop_sequences is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
if resp.status != 200:
raise parse_error(resp.status, await resp.json())
# Parse ServerSentEvents
async for byte_payload in resp.content:
# Skip line
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
# Event data
if payload.startswith("data:"):
# Decode payload
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
# Parse payload
try:
response = StreamResponse(**json_payload)
except ValidationError:
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status, json_payload)
yield response
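
A minimal usage sketch for the synchronous client, showing how server error payloads surface as the exception classes defined in `text_generation.errors`; the base URL is the same placeholder used in the docstrings and the prompt and parameters are illustrative:

```python
from text_generation import Client
from text_generation.errors import ValidationError, OverloadedError

# Placeholder base URL for a compatible deployment
client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz", timeout=30)

try:
    text = ""
    for event in client.generate_stream("Why is the sky blue?", max_new_tokens=32):
        if not event.token.special:
            text += event.token.text
    print(text)
except ValidationError as err:
    # Raised for error_type "validation" payloads and by client-side Parameters checks
    print(f"invalid request: {err}")
except OverloadedError:
    # Raised for error_type "overloaded" payloads (HTTP 429 from the service)
    print("model is overloaded, retry later")
```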

View File

@ -0,0 +1,106 @@
from typing import Dict
# Text Generation Inference Errors
class ValidationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class GenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
class OverloadedError(Exception):
def __init__(self, message: str):
super().__init__(message)
class IncompleteGenerationError(Exception):
def __init__(self, message: str):
super().__init__(message)
# API Inference Errors
class BadRequestError(Exception):
def __init__(self, message: str):
super().__init__(message)
class ShardNotReadyError(Exception):
def __init__(self, message: str):
super().__init__(message)
class ShardTimeoutError(Exception):
def __init__(self, message: str):
super().__init__(message)
class NotFoundError(Exception):
def __init__(self, message: str):
super().__init__(message)
class RateLimitExceededError(Exception):
def __init__(self, message: str):
super().__init__(message)
class NotSupportedError(Exception):
def __init__(self, model_id: str):
message = (
f"Model `{model_id}` is not available for inference with this client. \n"
"Use `huggingface_hub.inference_api.InferenceApi` instead."
)
super(NotSupportedError, self).__init__(message)
# Unknown error
class UnknownError(Exception):
def __init__(self, message: str):
super().__init__(message)
def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
"""
Parse error given an HTTP status code and a json payload
Args:
status_code (`int`):
HTTP status code
payload (`Dict[str, str]`):
Json payload
Returns:
Exception: parsed exception
"""
# Try to parse a Text Generation Inference error
message = payload["error"]
if "error_type" in payload:
error_type = payload["error_type"]
if error_type == "generation":
return GenerationError(message)
if error_type == "incomplete_generation":
return IncompleteGenerationError(message)
if error_type == "overloaded":
return OverloadedError(message)
if error_type == "validation":
return ValidationError(message)
# Try to parse a APIInference error
if status_code == 400:
return BadRequestError(message)
if status_code == 403 or status_code == 424:
return ShardNotReadyError(message)
if status_code == 504:
return ShardTimeoutError(message)
if status_code == 404:
return NotFoundError(message)
if status_code == 429:
return RateLimitExceededError(message)
# Fallback to an unknown error
return UnknownError(message)
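
A small illustration of the mapping implemented by `parse_error`; the payloads below are made up:

```python
from text_generation.errors import (
    parse_error,
    ValidationError,
    RateLimitExceededError,
    UnknownError,
)

# An explicit error_type takes precedence over the HTTP status code
assert isinstance(
    parse_error(422, {"error": "bad input", "error_type": "validation"}), ValidationError
)

# Without error_type, the status code decides
assert isinstance(parse_error(429, {"error": "too many requests"}), RateLimitExceededError)

# Anything unrecognised falls back to UnknownError
assert isinstance(parse_error(418, {"error": "teapot"}), UnknownError)
```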

View File

@ -0,0 +1,154 @@
import os
import requests
import base64
import json
import warnings
from typing import List, Optional
from huggingface_hub.utils import build_hf_headers
from text_generation import Client, AsyncClient, __version__
from text_generation.errors import NotSupportedError
INFERENCE_ENDPOINT = os.environ.get(
"HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)
SUPPORTED_MODELS = None
def get_supported_models() -> Optional[List[str]]:
"""
Get the list of supported text-generation models from GitHub
Returns:
Optional[List[str]]: supported models list or None if unable to get the list from GitHub
"""
global SUPPORTED_MODELS
if SUPPORTED_MODELS is not None:
return SUPPORTED_MODELS
response = requests.get(
"https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
timeout=5,
)
if response.status_code == 200:
file_content = response.json()["content"]
SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
return SUPPORTED_MODELS
warnings.warn("Could not retrieve list of supported models.")
return None
class InferenceAPIClient(Client):
"""Client to make calls to the HuggingFace Inference API.
Only supports a subset of the available text-generation or text2text-generation models that are served using
text-generation-inference
Example:
```python
>>> from text_generation import InferenceAPIClient
>>> client = InferenceAPIClient("bigscience/bloomz")
>>> client.generate("Why is the sky blue?").generated_text
' Rayleigh scattering'
>>> result = ""
>>> for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
"""
Init headers and API information
Args:
repo_id (`str`):
Id of repository (e.g. `bigscience/bloom`).
token (`str`, `optional`):
The API token to use as HTTP bearer authorization. This is not
the authentication token. You can find the token in
https://huggingface.co/settings/token. Alternatively, you can
find both your organizations and personal API tokens using
`HfApi().whoami(token)`.
timeout (`int`):
Timeout in seconds
"""
# Text Generation Inference client only supports a subset of the available hub models
supported_models = get_supported_models()
if supported_models is not None and repo_id not in supported_models:
raise NotSupportedError(repo_id)
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIClient, self).__init__(
base_url, headers=headers, timeout=timeout
)
class InferenceAPIAsyncClient(AsyncClient):
"""Aynschronous Client to make calls to the HuggingFace Inference API.
Only supports a subset of the available text-generation or text2text-generation models that are served using
text-generation-inference
Example:
```python
>>> from text_generation import InferenceAPIAsyncClient
>>> client = InferenceAPIAsyncClient("bigscience/bloomz")
>>> response = await client.generate("Why is the sky blue?")
>>> response.generated_text
' Rayleigh scattering'
>>> result = ""
>>> async for response in client.generate_stream("Why is the sky blue?"):
>>> if not response.token.special:
>>> result += response.token.text
>>> result
' Rayleigh scattering'
```
"""
def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):
"""
Init headers and API information
Args:
repo_id (`str`):
Id of repository (e.g. `bigscience/bloom`).
token (`str`, `optional`):
The API token to use as HTTP bearer authorization. This is not
the authentication token. You can find the token in
https://huggingface.co/settings/token. Alternatively, you can
find both your organizations and personal API tokens using
`HfApi().whoami(token)`.
timeout (`int`):
Timeout in seconds
"""
# Text Generation Inference client only supports a subset of the available hub models
supported_models = get_supported_models()
if supported_models is not None and repo_id not in supported_models:
raise NotSupportedError(repo_id)
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIAsyncClient, self).__init__(
base_url, headers=headers, timeout=timeout
)
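
A short sketch of the supported-model check performed by both constructors; the token is a placeholder and the model ids are only examples:

```python
from text_generation import InferenceAPIClient
from text_generation.errors import NotSupportedError
from text_generation.inference_api import get_supported_models

# None is returned when the list cannot be fetched from GitHub
print(get_supported_models())

try:
    client = InferenceAPIClient("bigscience/bloomz", token="hf_xxx")  # placeholder token
    print(client.generate("Why is the sky blue?", max_new_tokens=10).generated_text)
except NotSupportedError as err:
    # Raised when the repo is not listed in supported_models.json
    print(err)
```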

View File

@ -0,0 +1,223 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List
from text_generation.errors import ValidationError
class Parameters(BaseModel):
# Activate logits sampling
do_sample: bool = False
# Maximum number of generated tokens
max_new_tokens: int = 20
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None
# Whether to prepend the prompt to the generated text
return_full_text: bool = False
# Stop generating tokens if a member of `stop_sequences` is generated
stop: List[str] = []
# Random sampling seed
seed: Optional[int]
# The value used to modulate the logits distribution.
temperature: Optional[float]
# The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_k: Optional[int]
# If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
# higher are kept for generation.
top_p: Optional[float]
# Truncate input tokens to the given size
truncate: Optional[int]
# Typical Decoding mass
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
typical_p: Optional[float]
# Generate best_of sequences and return the one with the highest token logprobs
best_of: Optional[int]
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
watermark: bool = False
# Get generation details
details: bool = False
@validator("best_of")
def valid_best_of(cls, field_value, values):
if field_value is not None:
if field_value <= 0:
raise ValidationError("`best_of` must be strictly positive")
sampling = (
values["do_sample"]
| (values["temperature"] is not None)
| (values["top_k"] is not None)
| (values["top_p"] is not None)
| (values["typical_p"] is not None)
)
if field_value > 1 and not sampling:
raise ValidationError("you must use sampling when `best_of` is > 1")
return field_value
@validator("repetition_penalty")
def valid_repetition_penalty(cls, v):
if v is not None and v <= 0:
raise ValidationError("`repetition_penalty` must be strictly positive")
return v
@validator("seed")
def valid_seed(cls, v):
if v is not None and v < 0:
raise ValidationError("`seed` must be positive")
return v
@validator("temperature")
def valid_temp(cls, v):
if v is not None and v <= 0:
raise ValidationError("`temperature` must be strictly positive")
return v
@validator("top_k")
def valid_top_k(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_k` must be strictly positive")
return v
@validator("top_p")
def valid_top_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`top_p` must be > 0.0 and < 1.0")
return v
@validator("truncate")
def valid_truncate(cls, v):
if v is not None and v <= 0:
raise ValidationError("`truncate` must be strictly positive")
return v
@validator("typical_p")
def valid_typical_p(cls, v):
if v is not None and (v <= 0 or v >= 1.0):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v
class Request(BaseModel):
# Prompt
inputs: str
# Generation parameters
parameters: Optional[Parameters]
# Whether to stream output tokens
stream: bool = False
@validator("inputs")
def valid_input(cls, v):
if not v:
raise ValidationError("`inputs` cannot be empty")
return v
@validator("stream")
def valid_best_of_stream(cls, field_value, values):
parameters = values["parameters"]
if (
parameters is not None
and parameters.best_of is not None
and parameters.best_of > 1
and field_value
):
raise ValidationError(
"`best_of` != 1 is not supported when `stream` == True"
)
return field_value
# Prompt tokens
class PrefillToken(BaseModel):
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
# Optional since the logprob of the first token cannot be computed
logprob: Optional[float]
# Generated tokens
class Token(BaseModel):
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
logprob: float
# Is the token a special token
# Can be used to ignore tokens when concatenating
special: bool
# Generation finish reason
class FinishReason(Enum):
# number of generated tokens == `max_new_tokens`
Length = "length"
# the model generated its end of sequence token
EndOfSequenceToken = "eos_token"
# the model generated a text included in `stop_sequences`
StopSequence = "stop_sequence"
# Additional sequences when using the `best_of` parameter
class BestOfSequence(BaseModel):
# Generated text
generated_text: str
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Generated tokens
tokens: List[Token]
# `generate` details
class Details(BaseModel):
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Prompt tokens
prefill: List[PrefillToken]
# Generated tokens
tokens: List[Token]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
# `generate` return value
class Response(BaseModel):
# Generated text
generated_text: str
# Generation details
details: Details
# `generate_stream` details
class StreamDetails(BaseModel):
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# `generate_stream` return value
class StreamResponse(BaseModel):
# Generated token
token: Token
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]
# Generation details
# Only available when the generation is finished
details: Optional[StreamDetails]
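
A minimal sketch of the request-side checks these models enforce (values are illustrative):

```python
from text_generation.types import Parameters, Request
from text_generation.errors import ValidationError

# best_of > 1 is only valid when sampling is enabled
params = Parameters(best_of=2, do_sample=True, temperature=0.7, details=True)
request = Request(inputs="Why is the sky blue?", parameters=params)
print(request.dict())

# best_of > 1 cannot be combined with streaming
try:
    Request(inputs="test", parameters=params, stream=True)
except ValidationError as err:
    print(err)
```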

View File

@ -11,7 +11,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "0.2.1" "version": "0.4.0"
}, },
"paths": { "paths": {
"/generate": { "/generate": {
@ -38,23 +38,17 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/GenerateResponse" "$ref": "#/components/schemas/GenerateResponse"
} }
} }
} }
}
}, },
"422": { "422": {
"description": "Input validation error", "description": "Input validation error",
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse" "$ref": "#/components/schemas/ErrorResponse"
}
}, },
"example": { "example": {
"error": "Input validation error" "error": "Input validation error"
@ -67,10 +61,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse" "$ref": "#/components/schemas/ErrorResponse"
}
}, },
"example": { "example": {
"error": "Request failed during generation" "error": "Request failed during generation"
@ -83,10 +74,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse" "$ref": "#/components/schemas/ErrorResponse"
}
}, },
"example": { "example": {
"error": "Model is overloaded" "error": "Model is overloaded"
@ -99,10 +87,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse" "$ref": "#/components/schemas/ErrorResponse"
}
}, },
"example": { "example": {
"error": "Incomplete generation" "error": "Incomplete generation"
@ -136,25 +121,19 @@
"200": { "200": {
"description": "Generated Text", "description": "Generated Text",
"content": { "content": {
"text/event-stream ": { "text/event-stream": {
"schema": { "schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/StreamResponse" "$ref": "#/components/schemas/StreamResponse"
} }
} }
} }
}
}, },
"422": { "422": {
"description": "Input validation error", "description": "Input validation error",
"content": { "content": {
"text/event-stream ": { "text/event-stream": {
"schema": { "schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse" "$ref": "#/components/schemas/ErrorResponse"
}
}, },
"example": { "example": {
"error": "Input validation error" "error": "Input validation error"
@ -165,12 +144,9 @@
"424": { "424": {
"description": "Generation Error", "description": "Generation Error",
"content": { "content": {
"text/event-stream ": { "text/event-stream": {
"schema": { "schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse" "$ref": "#/components/schemas/ErrorResponse"
}
}, },
"example": { "example": {
"error": "Request failed during generation" "error": "Request failed during generation"
@ -181,12 +157,9 @@
"429": { "429": {
"description": "Model is overloaded", "description": "Model is overloaded",
"content": { "content": {
"text/event-stream ": { "text/event-stream": {
"schema": { "schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse" "$ref": "#/components/schemas/ErrorResponse"
}
}, },
"example": { "example": {
"error": "Model is overloaded" "error": "Model is overloaded"
@ -197,12 +170,9 @@
"500": { "500": {
"description": "Incomplete generation", "description": "Incomplete generation",
"content": { "content": {
"text/event-stream ": { "text/event-stream": {
"schema": { "schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse" "$ref": "#/components/schemas/ErrorResponse"
}
}, },
"example": { "example": {
"error": "Incomplete generation" "error": "Incomplete generation"
@ -213,17 +183,90 @@
}, },
"deprecated": false "deprecated": false
} }
},
"/metrics": {
"get": {
"tags": [
"Text Generation Inference"
],
"summary": "Prometheus metrics scrape endpoint",
"description": "Prometheus metrics scrape endpoint",
"operationId": "metrics",
"responses": {
"200": {
"description": "Prometheus Metrics",
"content": {
"text/plain": {
"schema": {
"type": "string"
}
}
}
}
},
"deprecated": false
}
} }
}, },
"components": { "components": {
"schemas": { "schemas": {
"BestOfSequence": {
"type": "object",
"required": [
"generated_text",
"finish_reason",
"generated_tokens",
"prefill",
"tokens"
],
"properties": {
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
"generated_text": {
"type": "string",
"example": "test"
},
"generated_tokens": {
"type": "integer",
"format": "int32",
"example": 1
},
"prefill": {
"type": "array",
"items": {
"$ref": "#/components/schemas/PrefillToken"
}
},
"seed": {
"type": "integer",
"format": "int64",
"example": 42,
"nullable": true
},
"tokens": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Token"
}
}
}
},
"Details": { "Details": {
"type": "object", "type": "object",
"required": [ "required": [
"finish_reason", "finish_reason",
"generated_tokens" "generated_tokens",
"prefill",
"tokens"
], ],
"properties": { "properties": {
"best_of_sequences": {
"type": "array",
"items": {
"$ref": "#/components/schemas/BestOfSequence"
}
},
"finish_reason": { "finish_reason": {
"$ref": "#/components/schemas/FinishReason" "$ref": "#/components/schemas/FinishReason"
}, },
@ -235,13 +278,14 @@
"prefill": { "prefill": {
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/Token" "$ref": "#/components/schemas/PrefillToken"
} }
}, },
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
"example": 42 "example": 42,
"nullable": true
}, },
"tokens": { "tokens": {
"type": "array", "type": "array",
@ -254,11 +298,15 @@
"ErrorResponse": { "ErrorResponse": {
"type": "object", "type": "object",
"required": [ "required": [
"error" "error",
"error_type"
], ],
"properties": { "properties": {
"error": { "error": {
"type": "string" "type": "string"
},
"error_type": {
"type": "string"
} }
} }
}, },
@ -273,6 +321,13 @@
"GenerateParameters": { "GenerateParameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"best_of": {
"type": "integer",
"default": "null",
"example": 1,
"nullable": true,
"exclusiveMinimum": 0.0
},
"details": { "details": {
"type": "boolean", "type": "boolean",
"default": "true" "default": "true"
@ -297,9 +352,19 @@
"nullable": true, "nullable": true,
"exclusiveMinimum": 0.0 "exclusiveMinimum": 0.0
}, },
"return_full_text": {
"type": "boolean",
"default": "null",
"example": false,
"nullable": true
},
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64" "format": "int64",
"default": "null",
"example": "null",
"nullable": true,
"exclusiveMinimum": 0.0
}, },
"stop": { "stop": {
"type": "array", "type": "array",
@ -335,6 +400,26 @@
"nullable": true, "nullable": true,
"maximum": 1.0, "maximum": 1.0,
"exclusiveMinimum": 0.0 "exclusiveMinimum": 0.0
},
"truncate": {
"type": "integer",
"default": "null",
"example": "null",
"nullable": true
},
"typical_p": {
"type": "number",
"format": "float",
"default": "null",
"example": 0.95,
"nullable": true,
"maximum": 1.0,
"exclusiveMinimum": 0.0
},
"watermark": {
"type": "boolean",
"default": "false",
"example": true
} }
} }
}, },
@ -368,6 +453,31 @@
} }
} }
}, },
"PrefillToken": {
"type": "object",
"required": [
"id",
"text",
"logprob"
],
"properties": {
"id": {
"type": "integer",
"format": "int32",
"example": 0
},
"logprob": {
"type": "number",
"format": "float",
"example": -0.34,
"nullable": true
},
"text": {
"type": "string",
"example": "test"
}
}
},
"StreamDetails": { "StreamDetails": {
"type": "object", "type": "object",
"required": [ "required": [
@ -386,7 +496,8 @@
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
"example": 42 "example": 42,
"nullable": true
} }
} }
}, },
@ -415,7 +526,8 @@
"required": [ "required": [
"id", "id",
"text", "text",
"logprob" "logprob",
"special"
], ],
"properties": { "properties": {
"id": { "id": {
@ -429,6 +541,10 @@
"example": -0.34, "example": -0.34,
"nullable": true "nullable": true
}, },
"special": {
"type": "boolean",
"example": "false"
},
"text": { "text": {
"type": "string", "type": "string",
"example": "test" "example": "test"

View File

@ -1,6 +1,6 @@
[package] [package]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "0.2.1" version = "0.4.0"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
description = "Text Generation Launcher" description = "Text Generation Launcher"

View File

@ -1,6 +1,7 @@
use clap::Parser; use clap::Parser;
use serde_json::Value; use serde_json::Value;
use std::env; use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Read}; use std::io::{BufRead, BufReader, Read};
use std::path::Path; use std::path::Path;
use std::process::ExitCode; use std::process::ExitCode;
@ -12,7 +13,7 @@ use std::thread;
use std::thread::sleep; use std::thread::sleep;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{fs, io}; use std::{fs, io};
use subprocess::{Popen, PopenConfig, PopenError, Redirection}; use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
/// App Configuration /// App Configuration
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -23,13 +24,21 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
revision: Option<String>, revision: Option<String>,
#[clap(long, env)] #[clap(long, env)]
sharded: Option<bool>,
#[clap(long, env)]
num_shard: Option<usize>, num_shard: Option<usize>,
#[clap(long, env)] #[clap(long, env)]
quantize: bool, quantize: bool,
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
max_concurrent_requests: usize, max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "1000", long, env)] #[clap(default_value = "1000", long, env)]
max_input_length: usize, max_input_length: usize,
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
#[clap(default_value = "32", long, env)] #[clap(default_value = "32", long, env)]
max_batch_size: usize, max_batch_size: usize,
#[clap(default_value = "20", long, env)] #[clap(default_value = "20", long, env)]
@ -43,38 +52,112 @@ struct Args {
#[clap(default_value = "29500", long, env)] #[clap(default_value = "29500", long, env)]
master_port: usize, master_port: usize,
#[clap(long, env)] #[clap(long, env)]
huggingface_hub_cache: Option<String>,
#[clap(long, env)]
weights_cache_override: Option<String>,
#[clap(long, env)]
disable_custom_kernels: bool,
#[clap(long, env)]
json_output: bool, json_output: bool,
#[clap(long, env)] #[clap(long, env)]
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Vec<String>,
#[clap(long, env)]
watermark_gamma: Option<f32>,
#[clap(long, env)]
watermark_delta: Option<f32>,
} }
fn main() -> ExitCode { fn main() -> ExitCode {
// Pattern match configuration // Pattern match configuration
let args = Args::parse();
if args.json_output {
tracing_subscriber::fmt().json().init();
} else {
tracing_subscriber::fmt().compact().init();
}
tracing::info!("{:?}", args);
let Args { let Args {
model_id, model_id,
revision, revision,
sharded,
num_shard, num_shard,
quantize, quantize,
max_concurrent_requests, max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens,
max_batch_size, max_batch_size,
max_waiting_tokens, max_waiting_tokens,
port, port,
shard_uds_path, shard_uds_path,
master_addr, master_addr,
master_port, master_port,
huggingface_hub_cache,
weights_cache_override,
disable_custom_kernels,
json_output, json_output,
otlp_endpoint, otlp_endpoint,
} = Args::parse(); cors_allow_origin,
watermark_gamma,
watermark_delta,
} = args;
if json_output { // get the number of shards given `sharded` and `num_shard`
tracing_subscriber::fmt().json().init(); let num_shard = if let Some(sharded) = sharded {
// sharded is set
match sharded {
// sharded is set and true
true => {
match num_shard {
None => {
// try to default to the number of available GPUs
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
let n_devices = num_cuda_devices()
.expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
if n_devices <= 1 {
panic!("`sharded` is true but only found {n_devices} CUDA devices");
}
n_devices
}
Some(num_shard) => {
// we can't have only one shard while sharded
if num_shard <= 1 {
panic!("`sharded` is true but `num_shard` <= 1");
}
num_shard
}
}
}
// sharded is set and false
false => {
let num_shard = num_shard.unwrap_or(1);
// we can't have more than one shard while not sharded
if num_shard != 1 {
panic!("`sharded` is false but `num_shard` != 1");
}
num_shard
}
}
} else { } else {
tracing_subscriber::fmt().compact().init(); match num_shard {
// get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard
None => num_cuda_devices().unwrap_or(1),
Some(num_shard) => num_shard,
}
};
if num_shard < 1 {
panic!("`num_shard` cannot be < 1");
} }
// By default we only have one master shard if num_shard > 1 {
let num_shard = num_shard.unwrap_or(1); tracing::info!("Sharding model on {num_shard} processes");
}
// Signal handler // Signal handler
let running = Arc::new(AtomicBool::new(true)); let running = Arc::new(AtomicBool::new(true));
@ -84,6 +167,121 @@ fn main() -> ExitCode {
}) })
.expect("Error setting Ctrl-C handler"); .expect("Error setting Ctrl-C handler");
// Check if model_id is a local model
let local_path = Path::new(&model_id);
let is_local_model = local_path.exists() && local_path.is_dir();
// Download weights for sharded models
if !is_local_model && weights_cache_override.is_none() && num_shard > 1 {
let mut download_argv = vec![
"text-generation-server".to_string(),
"download-weights".to_string(),
model_id.clone(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
// Model optional revision
if let Some(ref revision) = revision {
download_argv.push("--revision".to_string());
download_argv.push(revision.to_string())
}
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// If huggingface_hub_cache is set, pass it to the shard
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// Enable hf transfer for insane download speeds
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
// Start process
tracing::info!("Starting download process.");
let mut download_process = match Popen::create(
&download_argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
env: Some(env),
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
}
}
return ExitCode::FAILURE;
}
};
// Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap();
thread::spawn(move || {
// Enter download tracing span
let stdout = BufReader::new(download_stdout);
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
if let Some(text) = value.get("text") {
// Format escaped newlines
tracing::info!("{}", text.to_string().replace("\\n", ""));
}
}
}
});
loop {
if let Some(status) = download_process.poll() {
match status {
ExitStatus::Exited(exit_code) => {
if exit_code == 0 {
tracing::info!("Successfully downloaded weights.");
break;
} else {
let mut err = String::new();
download_process
.stderr
.take()
.unwrap()
.read_to_string(&mut err)
.unwrap();
tracing::error!("Download encountered an error: {err}");
return ExitCode::FAILURE;
}
}
_ => {
tracing::error!("Download process exited with an unknown status.");
return ExitCode::FAILURE;
}
}
}
if !running.load(Ordering::SeqCst) {
download_process.terminate().unwrap();
tracing::info!("Waiting for download process to gracefully shutdown");
download_process
.wait_timeout(Duration::from_secs(90))
.unwrap();
tracing::info!("Download process terminated");
return ExitCode::SUCCESS;
}
sleep(Duration::from_millis(100));
}
}
// Shared shutdown bool // Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false)); let shutdown = Arc::new(Mutex::new(false));
// Shared shutdown channel // Shared shutdown channel
@ -99,6 +297,8 @@ fn main() -> ExitCode {
let revision = revision.clone(); let revision = revision.clone();
let uds_path = shard_uds_path.clone(); let uds_path = shard_uds_path.clone();
let master_addr = master_addr.clone(); let master_addr = master_addr.clone();
let huggingface_hub_cache = huggingface_hub_cache.clone();
let weights_cache_override = weights_cache_override.clone();
let status_sender = status_sender.clone(); let status_sender = status_sender.clone();
let shutdown = shutdown.clone(); let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone(); let shutdown_sender = shutdown_sender.clone();
@ -113,6 +313,11 @@ fn main() -> ExitCode {
num_shard, num_shard,
master_addr, master_addr,
master_port, master_port,
huggingface_hub_cache,
weights_cache_override,
disable_custom_kernels,
watermark_gamma,
watermark_delta,
otlp_endpoint, otlp_endpoint,
status_sender, status_sender,
shutdown, shutdown,
@ -161,8 +366,14 @@ fn main() -> ExitCode {
"text-generation-router".to_string(), "text-generation-router".to_string(),
"--max-concurrent-requests".to_string(), "--max-concurrent-requests".to_string(),
max_concurrent_requests.to_string(), max_concurrent_requests.to_string(),
"--max-best-of".to_string(),
max_best_of.to_string(),
"--max-stop-sequences".to_string(),
max_stop_sequences.to_string(),
"--max-input-length".to_string(), "--max-input-length".to_string(),
max_input_length.to_string(), max_input_length.to_string(),
"--max-total-tokens".to_string(),
max_total_tokens.to_string(),
"--max-batch-size".to_string(), "--max-batch-size".to_string(),
max_batch_size.to_string(), max_batch_size.to_string(),
"--max-waiting-tokens".to_string(), "--max-waiting-tokens".to_string(),
@ -185,6 +396,12 @@ fn main() -> ExitCode {
argv.push(otlp_endpoint); argv.push(otlp_endpoint);
} }
// CORS origins
for origin in cors_allow_origin.into_iter() {
argv.push("--cors-allow-origin".to_string());
argv.push(origin);
}
let mut webserver = match Popen::create( let mut webserver = match Popen::create(
&argv, &argv,
PopenConfig { PopenConfig {
@ -232,7 +449,7 @@ fn main() -> ExitCode {
while running.load(Ordering::SeqCst) { while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
tracing::error!("Shard {} failed:\n{}", rank, err); tracing::error!("Shard {rank} failed:\n{err}");
exit_code = ExitCode::FAILURE; exit_code = ExitCode::FAILURE;
break; break;
}; };
@ -275,6 +492,11 @@ fn shard_manager(
world_size: usize, world_size: usize,
master_addr: String, master_addr: String,
master_port: usize, master_port: usize,
huggingface_hub_cache: Option<String>,
weights_cache_override: Option<String>,
disable_custom_kernels: bool,
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>, status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>, shutdown: Arc<Mutex<bool>>,
@ -319,43 +541,54 @@ fn shard_manager(
shard_argv.push(otlp_endpoint); shard_argv.push(otlp_endpoint);
} }
let mut env = vec![ // Copy current process env
("RANK".into(), rank.to_string().into()), let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
("WORLD_SIZE".into(), world_size.to_string().into()),
("MASTER_ADDR".into(), master_addr.into()),
("MASTER_PORT".into(), master_port.to_string().into()),
("SAFETENSORS_FAST_GPU".into(), "1".into()),
("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()),
];
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard // Torch Distributed Env vars
env.push(("RANK".into(), rank.to_string().into()));
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
env.push(("MASTER_ADDR".into(), master_addr.into()));
env.push(("MASTER_PORT".into(), master_port.to_string().into()));
env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
// Safetensors load fast
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// Enable hf transfer for insane download speeds
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container // Useful when running inside a docker container
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") { if let Some(huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
}; };
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard // If weights_cache_override is some, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint // Useful when running inside a HuggingFace Inference Endpoint
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") { if let Some(weights_cache_override) = weights_cache_override {
env.push(( env.push((
"WEIGHTS_CACHE_OVERRIDE".into(), "WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(), weights_cache_override.into(),
)); ));
}; };
// If the NCCL_SHM_DISABLE env var is set, pass it to the shard // If disable_custom_kernels is true, pass it to the shard as an env var
// needed when running NCCL inside a docker container and when you can't increase shm size if disable_custom_kernels {
if let Ok(nccl_shm_disalbe) = env::var("NCCL_SHM_DISABLE") { env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
env.push(("NCCL_SHM_DISABLE".into(), nccl_shm_disalbe.into())); }
};
// If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard // Watermark Gamma
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") { if let Some(watermark_gamma) = watermark_gamma {
env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into())); env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
}; }
// Watermark Delta
if let Some(watermark_delta) = watermark_delta {
env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
}
// Start process // Start process
tracing::info!("Starting shard {}", rank); tracing::info!("Starting shard {rank}");
let mut p = match Popen::create( let mut p = match Popen::create(
&shard_argv, &shard_argv,
PopenConfig { PopenConfig {
@ -419,17 +652,17 @@ fn shard_manager(
if *shutdown.lock().unwrap() { if *shutdown.lock().unwrap() {
p.terminate().unwrap(); p.terminate().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90)); let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {} terminated", rank); tracing::info!("Shard {rank} terminated");
return; return;
} }
// Shard is ready // Shard is ready
if uds.exists() && !ready { if uds.exists() && !ready {
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed()); tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap(); status_sender.send(ShardStatus::Ready).unwrap();
ready = true; ready = true;
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) { } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {} to be ready...", rank); tracing::info!("Waiting for shard {rank} to be ready...");
wait_time = Instant::now(); wait_time = Instant::now();
} }
sleep(Duration::from_millis(100)); sleep(Duration::from_millis(100));
@ -449,3 +682,11 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
// This will block till all shutdown_sender are dropped // This will block till all shutdown_sender are dropped
let _ = shutdown_receiver.recv(); let _ = shutdown_receiver.recv();
} }
fn num_cuda_devices() -> Option<usize> {
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
let n_devices = cuda_visible_devices.split(',').count();
return Some(n_devices);
}
None
}
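
For readers skimming the launcher diff, a small Python sketch of the shard-count resolution implemented above (an illustration only, not part of the launcher; the function name is hypothetical):

```python
import os
from typing import Optional


def resolve_num_shard(sharded: Optional[bool], num_shard: Optional[int]) -> int:
    """Illustrative mirror of how the launcher derives the shard count."""

    def cuda_devices() -> Optional[int]:
        devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        return len(devices.split(",")) if devices else None

    if sharded is True:
        resolved = num_shard if num_shard is not None else cuda_devices()
        if resolved is None or resolved <= 1:
            raise ValueError("`sharded` is true but fewer than 2 shards are available")
        return resolved
    if sharded is False:
        if num_shard not in (None, 1):
            raise ValueError("`sharded` is false but `num_shard` != 1")
        return 1
    # --sharded not set: fall back to --num-shard, then CUDA_VISIBLE_DEVICES, then 1
    resolved = num_shard if num_shard is not None else (cuda_devices() or 1)
    if resolved < 1:
        raise ValueError("`num_shard` cannot be < 1")
    return resolved
```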

View File

@ -1,122 +1,142 @@
{ {
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
"details": { "details": {
"finish_reason": "length", "finish_reason": "length",
"generated_tokens": 20, "generated_tokens": 20,
"seed": null,
"prefill": [ "prefill": [
{ {
"id": 10264, "id": 10264,
"logprob": null, "text": "Test",
"text": "Test" "logprob": null
}, },
{ {
"id": 8821, "id": 8821,
"logprob": -11.894989, "text": " request",
"text": " request" "logprob": -11.894989
} }
], ],
"seed": null,
"tokens": [ "tokens": [
{ {
"id": 17, "id": 17,
"text": ".",
"logprob": -1.8267672, "logprob": -1.8267672,
"text": "." "special": false
}, },
{ {
"id": 1587, "id": 1587,
"text": "get",
"logprob": -2.4674969, "logprob": -2.4674969,
"text": "get" "special": false
}, },
{ {
"id": 11, "id": 11,
"text": "(",
"logprob": -1.906001, "logprob": -1.906001,
"text": "(" "special": false
}, },
{ {
"id": 5, "id": 5,
"text": "\"",
"logprob": -1.2279545, "logprob": -1.2279545,
"text": "\"" "special": false
}, },
{ {
"id": 4899, "id": 4899,
"text": "action",
"logprob": -4.170299, "logprob": -4.170299,
"text": "action" "special": false
}, },
{ {
"id": 5, "id": 5,
"text": "\"",
"logprob": -0.32478866, "logprob": -0.32478866,
"text": "\"" "special": false
}, },
{ {
"id": 12, "id": 12,
"text": ")",
"logprob": -1.0773665, "logprob": -1.0773665,
"text": ")" "special": false
}, },
{ {
"id": 30, "id": 30,
"text": ";",
"logprob": -0.27640742, "logprob": -0.27640742,
"text": ";" "special": false
}, },
{ {
"id": 837, "id": 837,
"text": "\n ",
"logprob": -1.6970354, "logprob": -1.6970354,
"text": "\n " "special": false
}, },
{ {
"id": 1320, "id": 1320,
"text": " if",
"logprob": -1.4495516, "logprob": -1.4495516,
"text": " if" "special": false
}, },
{ {
"id": 375, "id": 375,
"text": " (",
"logprob": -0.23609057, "logprob": -0.23609057,
"text": " (" "special": false
}, },
{ {
"id": 4899, "id": 4899,
"text": "action",
"logprob": -1.1916996, "logprob": -1.1916996,
"text": "action" "special": false
}, },
{ {
"id": 3535, "id": 3535,
"text": " ==",
"logprob": -0.8918753, "logprob": -0.8918753,
"text": " ==" "special": false
}, },
{ {
"id": 5109, "id": 5109,
"text": " null",
"logprob": -0.3933342, "logprob": -0.3933342,
"text": " null" "special": false
}, },
{ {
"id": 12, "id": 12,
"text": ")",
"logprob": -0.43212673, "logprob": -0.43212673,
"text": ")" "special": false
}, },
{ {
"id": 731, "id": 731,
"text": " {",
"logprob": -0.17702064, "logprob": -0.17702064,
"text": " {" "special": false
}, },
{ {
"id": 1260, "id": 1260,
"text": "\n ",
"logprob": -0.07027565, "logprob": -0.07027565,
"text": "\n " "special": false
}, },
{ {
"id": 10519, "id": 10519,
"text": " throw",
"logprob": -1.3915029, "logprob": -1.3915029,
"text": " throw" "special": false
}, },
{ {
"id": 2084, "id": 2084,
"text": " new",
"logprob": -0.04201372, "logprob": -0.04201372,
"text": " new" "special": false
}, },
{ {
"id": 150858, "id": 150858,
"text": " RuntimeException",
"logprob": -1.7329919, "logprob": -1.7329919,
"text": " RuntimeException" "special": false
} }
] ]
}, }
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
} }

View File

@ -14,6 +14,7 @@ pub struct Token {
id: u32, id: u32,
text: String, text: String,
logprob: Option<f32>, logprob: Option<f32>,
special: bool,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -136,6 +137,7 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
{ {
assert_eq!(token.id, expected_token.id); assert_eq!(token.id, expected_token.id);
assert_eq!(token.text, expected_token.text); assert_eq!(token.text, expected_token.text);
assert_eq!(token.special, expected_token.special);
if let Some(logprob) = token.logprob { if let Some(logprob) = token.logprob {
let expected_logprob = expected_token.logprob.unwrap(); let expected_logprob = expected_token.logprob.unwrap();
assert_float_eq!(logprob, expected_logprob, abs <= 0.001); assert_float_eq!(logprob, expected_logprob, abs <= 0.001);

View File

@ -1,117 +1,137 @@
{ {
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
"details": { "details": {
"finish_reason": "length", "finish_reason": "length",
"generated_tokens": 20, "generated_tokens": 20,
"seed": null,
"prefill": [ "prefill": [
{ {
"id": 0, "id": 0,
"logprob": null, "text": "<pad>",
"text": "<pad>" "logprob": null
} }
], ],
"seed": null,
"tokens": [ "tokens": [
{ {
"id": 259, "id": 259,
"text": " ",
"logprob": -1.3656927, "logprob": -1.3656927,
"text": "" "special": false
}, },
{ {
"id": 215100, "id": 215100,
"text": "\"\"\"",
"logprob": -2.6551573, "logprob": -2.6551573,
"text": "\"\"\"" "special": false
}, },
{ {
"id": 46138, "id": 46138,
"text": "Test",
"logprob": -1.8059857, "logprob": -1.8059857,
"text": "Test" "special": false
}, },
{ {
"id": 287, "id": 287,
"text": " the",
"logprob": -1.2102449, "logprob": -1.2102449,
"text": "the" "special": false
}, },
{ {
"id": 259, "id": 259,
"text": " ",
"logprob": -1.6057279, "logprob": -1.6057279,
"text": "" "special": false
}, },
{ {
"id": 49076, "id": 49076,
"text": "contents",
"logprob": -3.6060903, "logprob": -3.6060903,
"text": "contents" "special": false
}, },
{ {
"id": 304, "id": 304,
"text": " of",
"logprob": -0.5270343, "logprob": -0.5270343,
"text": "of" "special": false
}, },
{ {
"id": 287, "id": 287,
"text": " the",
"logprob": -0.62522805, "logprob": -0.62522805,
"text": "the" "special": false
}, },
{ {
"id": 259, "id": 259,
"text": " ",
"logprob": -1.4069618, "logprob": -1.4069618,
"text": "" "special": false
}, },
{ {
"id": 49076, "id": 49076,
"text": "contents",
"logprob": -2.621994, "logprob": -2.621994,
"text": "contents" "special": false
}, },
{ {
"id": 304, "id": 304,
"text": " of",
"logprob": -1.3172221, "logprob": -1.3172221,
"text": "of" "special": false
}, },
{ {
"id": 287, "id": 287,
"text": " the",
"logprob": -0.3501925, "logprob": -0.3501925,
"text": "the" "special": false
}, },
{ {
"id": 259, "id": 259,
"text": " ",
"logprob": -0.7219573, "logprob": -0.7219573,
"text": "" "special": false
}, },
{ {
"id": 49076, "id": 49076,
"text": "contents",
"logprob": -1.0494149, "logprob": -1.0494149,
"text": "contents" "special": false
}, },
{ {
"id": 260, "id": 260,
"text": ".",
"logprob": -1.0803378, "logprob": -1.0803378,
"text": "." "special": false
}, },
{ {
"id": 259, "id": 259,
"text": " ",
"logprob": -0.32933083, "logprob": -0.32933083,
"text": "" "special": false
}, },
{ {
"id": 215100, "id": 215100,
"text": "\"\"\"",
"logprob": -0.11268901, "logprob": -0.11268901,
"text": "\"\"\"" "special": false
}, },
{ {
"id": 2978, "id": 2978,
"text": " test",
"logprob": -1.5846587, "logprob": -1.5846587,
"text": "test" "special": false
}, },
{ {
"id": 290, "id": 290,
"text": "_",
"logprob": -0.49796978, "logprob": -0.49796978,
"text": "_" "special": false
}, },
{ {
"id": 4125, "id": 4125,
"text": "test",
"logprob": -2.0026445, "logprob": -2.0026445,
"text": "test" "special": false
} }
] ]
}, }
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
} }
View File
@ -34,12 +34,16 @@ message NextTokenChooserParameters {
uint32 top_k = 2; uint32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3; float top_p = 3;
/// restricting to the locally typical tokens summing to typical_p
float typical_p = 4;
/// apply sampling on the logits /// apply sampling on the logits
bool do_sample = 4; bool do_sample = 5;
/// random seed for sampling /// random seed for sampling
uint64 seed = 5; uint64 seed = 6;
/// repetition penalty /// repetition penalty
float repetition_penalty = 6; float repetition_penalty = 7;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
} }
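
For reference, an editorial sketch (not part of the diff) of how the renumbered message is filled from Rust through the prost-generated struct in text_generation_client — the same construction the queue tests further down in this commit use; field numbers in the comments follow the proto above:

    use text_generation_client::NextTokenChooserParameters;

    fn greedy_parameters() -> NextTokenChooserParameters {
        // Greedy decoding: sampling off, the two new fields left at their "disabled" values.
        NextTokenChooserParameters {
            temperature: 1.0,
            top_k: 0,                // field 2, 0 means no top-k filtering
            top_p: 1.0,              // field 3
            typical_p: 1.0,          // new field 4
            do_sample: false,        // renumbered to 5
            seed: 42,                // renumbered to 6
            repetition_penalty: 1.0, // renumbered to 7
            watermark: false,        // new field 8
        }
    }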
message StoppingCriteriaParameters { message StoppingCriteriaParameters {
@ -54,12 +58,10 @@ message Request {
uint64 id = 1; uint64 id = 1;
/// The generation context /// The generation context
string inputs = 2; string inputs = 2;
/// The number of tokens inside inputs
uint32 input_length = 3;
/// Next Token Chooser Parameters /// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4; NextTokenChooserParameters parameters = 3;
/// Stopping Criteria Parameters /// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5; StoppingCriteriaParameters stopping_parameters = 4;
} }
message Batch { message Batch {
@ -108,8 +110,10 @@ message Generation {
float token_logprob = 4; float token_logprob = 4;
/// Text /// Text
string token_text = 5; string token_text = 5;
/// Is it a special token
bool token_is_special = 6;
/// Complete generated text /// Complete generated text
GeneratedText generated_text = 6; GeneratedText generated_text = 7;
} }
message PrefillRequest { message PrefillRequest {
View File
@ -1,6 +1,6 @@
[package] [package]
name = "text-generation-router" name = "text-generation-router"
version = "0.2.1" version = "0.4.0"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
description = "Text Generation Webserver" description = "Text Generation Webserver"
@ -19,17 +19,21 @@ axum-tracing-opentelemetry = "0.9.0"
text-generation-client = { path = "client" } text-generation-client = { path = "client" }
clap = { version = "4.1.4", features = ["derive", "env"] } clap = { version = "4.1.4", features = ["derive", "env"] }
futures = "0.3.26" futures = "0.3.26"
metrics = "0.20.1"
metrics-exporter-prometheus = { version = "0.11.0", features = [] }
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0" opentelemetry-otlp = "0.11.0"
parking_lot = "0.12.1" parking_lot = "0.12.1"
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.11.14", features = [] }
serde = "1.0.152" serde = "1.0.152"
serde_json = "1.0.93" serde_json = "1.0.93"
thiserror = "1.0.38" thiserror = "1.0.38"
tokenizers = "0.13.2" tokenizers = "0.13.2"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11" tokio-stream = "0.1.11"
tower-http = { version = "0.3.5", features = ["cors"] }
tracing = "0.1.37" tracing = "0.1.37"
tracing-opentelemetry = "0.18.0" tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
View File
@ -1,6 +1,6 @@
[package] [package]
name = "text-generation-client" name = "text-generation-client"
version = "0.2.1" version = "0.4.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
View File
@ -1,6 +1,6 @@
[package] [package]
name = "grpc-metadata" name = "grpc-metadata"
version = "0.1.0" version = "0.4.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
View File
@ -1,9 +1,9 @@
/// Batching and inference logic /// Batching and inference logic
use crate::validation::{Validation, ValidationError}; use crate::validation::{Validation, ValidationError};
use crate::GenerateRequest;
use crate::{Entry, Queue, Token}; use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
use nohash_hasher::IntMap; use nohash_hasher::IntMap;
use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{ use text_generation_client::{
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient, Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
@ -81,6 +81,7 @@ impl Infer {
.limit_concurrent_requests .limit_concurrent_requests
.try_acquire_owned() .try_acquire_owned()
.map_err(|err| { .map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
tracing::error!("{err}"); tracing::error!("{err}");
err err
})?; })?;
@ -138,7 +139,7 @@ impl Infer {
.into_iter() .into_iter()
.zip(tokens.logprobs.into_iter()) .zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter()) .zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| Token { id, text, logprob }) .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect(); .collect();
} }
// Push last token // Push last token
@ -172,10 +173,48 @@ impl Infer {
}) })
} else { } else {
let err = InferError::IncompleteGeneration; let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
tracing::error!("{err}"); tracing::error!("{err}");
Err(err) Err(err)
} }
} }
/// Add best_of new requests to the queue and return an InferResponse of the sequence with
/// the highest log probability per token
#[instrument(skip(self))]
pub(crate) async fn generate_best_of(
&self,
request: GenerateRequest,
best_of: usize,
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
// validate best_of parameter separately
let best_of = self.validation.validate_best_of(best_of)?;
// create multiple generate requests
let mut infer_responses: Vec<InferResponse> =
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
// get the sequence with the highest log probability per token
let mut max_index = 0;
let mut max_logprob: f32 = f32::MIN;
for (i, response) in infer_responses.iter().enumerate() {
// mean logprobs of the generated tokens
let sequence_logprob = response
.tokens
.iter()
.map(|token| token.logprob)
.sum::<f32>()
/ response.tokens.len() as f32;
// set best sequence
if sequence_logprob > max_logprob {
max_index = i;
max_logprob = sequence_logprob;
}
}
let best_response = infer_responses.remove(max_index);
Ok((best_response, infer_responses))
}
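
A quick worked example of that selection rule (editorial, with made-up logprobs): the candidate with the higher mean per-token log probability wins, so longer sequences are not penalized by a raw sum.

    fn mean_logprob(logprobs: &[f32]) -> f32 {
        logprobs.iter().sum::<f32>() / logprobs.len() as f32
    }

    fn main() {
        let candidate_a = [-0.2_f32, -0.4, -0.3]; // mean = -0.30
        let candidate_b = [-0.1_f32, -0.1, -1.6]; // mean = -0.60
        assert!(mean_logprob(&candidate_a) > mean_logprob(&candidate_b));
        // generate_best_of would return candidate_a as the best response here.
    }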
} }
/// Batching logic /// Batching logic
@ -190,7 +229,11 @@ async fn batching_task(
shared: Arc<Shared>, shared: Arc<Shared>,
) { ) {
// Minimum batch size after which we try to add more requests // Minimum batch size after which we try to add more requests
let limit_min_batch_size = (max_batch_size / 2) as u32; let limit_min_batch_size = if max_batch_size > 1 {
(max_batch_size / 2) as u32
} else {
0
};
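
With the router default shown later in this diff (max_batch_size = 32), limit_min_batch_size is 16, so new requests are only merged into a running batch once it has drained to 16 or fewer entries; the new branch makes the single-request configuration (max_batch_size = 1) explicit, where no follow-up merging is attempted at all.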
// Infinite loop // Infinite loop
loop { loop {
@ -201,7 +244,7 @@ async fn batching_task(
// This batch might be smaller than the maximum batch size if there are not enough requests // This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue // waiting in the queue
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await { while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries) let mut cached_batch = prefill(&mut client, batch, &mut entries)
.instrument(span) .instrument(span)
.await; .await;
let mut waiting_tokens = 1; let mut waiting_tokens = 1;
@ -212,6 +255,7 @@ async fn batching_task(
// Get current batch info // Get current batch info
let batch_size = batch.size; let batch_size = batch.size;
let mut batches = vec![batch]; let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
// If the current batch is too small, we try to add more requests to it // If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size { if batch_size <= limit_min_batch_size {
@ -234,15 +278,15 @@ async fn batching_task(
// because a new batch is being computed // because a new batch is being computed
let entry_waiting_span = let entry_waiting_span =
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size); info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
// Add relationship // Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span); entry_waiting_span.follows_from(&span);
// Update entry // Update entry
entry.temp_span = Some(entry_waiting_span); entry.temp_span = Some(entry_waiting_span);
}); });
// Generate one token for this new batch to have the attention past in cache // Generate one token for this new batch to have the attention past in cache
let new_cached_batch = let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
wrap_future(client.prefill(new_batch), &mut new_entries)
.instrument(span) .instrument(span)
.await; .await;
// Reset waiting counter // Reset waiting counter
@ -262,35 +306,66 @@ async fn batching_task(
// Create a new span to link the batch back to this entry // Create a new span to link the batch back to this entry
let entry_batch_span = let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size); info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationship // Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span); entry_batch_span.follows_from(&next_batch_span);
// Update entry // Update entry
entry.temp_span = Some(entry_batch_span); entry.temp_span = Some(entry_batch_span);
}); });
cached_batch = wrap_future(client.decode(batches), &mut entries) cached_batch = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span) .instrument(next_batch_span)
.await; .await;
waiting_tokens += 1; waiting_tokens += 1;
} }
metrics::gauge!("tgi_batch_current_size", 0.0);
} }
} }
} }
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
#[instrument(skip_all)] #[instrument(skip_all)]
async fn wrap_future( async fn prefill(
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>, client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>, entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> { ) -> Option<Batch> {
match future.await { let start_time = Instant::now();
match client.prefill(batch).await {
Ok((generations, next_batch)) => { Ok((generations, next_batch)) => {
send_generations(generations, entries); send_generations(generations, entries);
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
next_batch next_batch
} }
// If we have an error, we discard the whole batch // If we have an error, we discard the whole batch
Err(err) => { Err(err) => {
send_errors(err, entries); send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
None
}
}
}
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<Batch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
let start_time = Instant::now();
match client.decode(batches).await {
Ok((generations, next_batch)) => {
send_generations(generations, entries);
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
None None
} }
} }
@ -303,6 +378,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry // Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string()); let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
tracing::error!("{err}"); tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone. // unwrap_or is valid here as we don't care if the receiver is gone.
@ -340,6 +416,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
id: generation.token_id, id: generation.token_id,
text: generation.token_text, text: generation.token_text,
logprob: generation.token_logprob, logprob: generation.token_logprob,
special: generation.token_is_special,
}; };
if let Some(generated_text) = generation.generated_text { if let Some(generated_text) = generation.generated_text {
@ -388,7 +465,7 @@ pub(crate) enum InferStreamResponse {
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct InferResponse { pub(crate) struct InferResponse {
pub(crate) prefill: Vec<Token>, pub(crate) prefill: Vec<PrefillToken>,
pub(crate) tokens: Vec<Token>, pub(crate) tokens: Vec<Token>,
pub(crate) generated_text: GeneratedText, pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant, pub(crate) queued: Instant,
@ -406,3 +483,14 @@ pub enum InferError {
#[error("Incomplete generation")] #[error("Incomplete generation")]
IncompleteGeneration, IncompleteGeneration,
} }
impl InferError {
pub(crate) fn error_type(&self) -> &str {
match self {
InferError::GenerationError(_) => "generation",
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
}
}
}
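
In practice a rejected request now carries a machine-readable category alongside the message, e.g. something like {"error": "Input validation error", "error_type": "validation"}, and the same pair is attached to the SSE error events emitted by the server routes below.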
View File
@ -12,6 +12,9 @@ use validation::Validation;
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters { pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
pub best_of: Option<usize>,
#[serde(default)] #[serde(default)]
#[schema( #[schema(
exclusive_minimum = 0.0, exclusive_minimum = 0.0,
@ -40,39 +43,64 @@ pub(crate) struct GenerateParameters {
example = 0.95 example = 0.95
)] )]
pub top_p: Option<f32>, pub top_p: Option<f32>,
#[serde(default = "default_do_sample")] #[serde(default)]
#[schema(
exclusive_minimum = 0.0,
maximum = 1.0,
nullable = true,
default = "null",
example = 0.95
)]
pub typical_p: Option<f32>,
#[serde(default)]
#[schema(default = "false", example = true)] #[schema(default = "false", example = true)]
pub do_sample: bool, pub do_sample: bool,
#[serde(default = "default_max_new_tokens")] #[serde(default = "default_max_new_tokens")]
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_new_tokens: u32, pub max_new_tokens: u32,
#[serde(default)] #[serde(default)]
#[schema(inline, max_items = 4, example = json!(["photographer"]))] #[schema(nullable = true, default = "null", example = false)]
pub return_full_text: Option<bool>,
#[serde(default)]
#[schema(inline, max_items = 4, example = json ! (["photographer"]))]
pub stop: Vec<String>, pub stop: Vec<String>,
#[serde(default)] #[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub truncate: Option<usize>,
#[serde(default)]
#[schema(default = "false", example = true)]
pub watermark: bool,
#[serde(default)]
#[schema(default = "true")] #[schema(default = "true")]
pub details: bool, pub details: bool,
#[serde(default)] #[serde(default)]
#[schema(
exclusive_minimum = 0,
nullable = true,
default = "null",
example = "null"
)]
pub seed: Option<u64>, pub seed: Option<u64>,
} }
fn default_do_sample() -> bool {
false
}
fn default_max_new_tokens() -> u32 { fn default_max_new_tokens() -> u32 {
20 20
} }
fn default_parameters() -> GenerateParameters { fn default_parameters() -> GenerateParameters {
GenerateParameters { GenerateParameters {
best_of: None,
temperature: None, temperature: None,
repetition_penalty: None, repetition_penalty: None,
top_k: None, top_k: None,
top_p: None, top_p: None,
do_sample: default_do_sample(), typical_p: None,
do_sample: false,
max_new_tokens: default_max_new_tokens(), max_new_tokens: default_max_new_tokens(),
stop: vec![], return_full_text: None,
stop: Vec::new(),
truncate: None,
watermark: false,
details: false, details: false,
seed: None, seed: None,
} }
@ -86,14 +114,46 @@ pub(crate) struct GenerateRequest {
pub parameters: GenerateParameters, pub parameters: GenerateParameters,
} }
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct CompatGenerateRequest {
#[schema(example = "My name is Olivier and I")]
pub inputs: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
#[serde(default)]
#[allow(dead_code)]
pub stream: bool,
}
impl From<CompatGenerateRequest> for GenerateRequest {
fn from(req: CompatGenerateRequest) -> Self {
Self {
inputs: req.inputs,
parameters: req.parameters,
}
}
}
#[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken {
#[schema(example = 0)]
id: u32,
#[schema(example = "test")]
text: String,
#[schema(nullable = true, example = - 0.34)]
logprob: f32,
}
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
pub struct Token { pub struct Token {
#[schema(example = 0)] #[schema(example = 0)]
id: u32, id: u32,
#[schema(example = "test")] #[schema(example = "test")]
text: String, text: String,
#[schema(nullable = true, example = -0.34)] #[schema(nullable = true, example = - 0.34)]
logprob: f32, logprob: f32,
#[schema(example = "false")]
special: bool,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
@ -108,16 +168,32 @@ pub(crate) enum FinishReason {
StopSequence, StopSequence,
} }
#[derive(Serialize, ToSchema)]
pub(crate) struct BestOfSequence {
#[schema(example = "test")]
pub generated_text: String,
#[schema(example = "length")]
pub finish_reason: FinishReason,
#[schema(example = 1)]
pub generated_tokens: u32,
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
}
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct Details { pub(crate) struct Details {
#[schema(example = "length")] #[schema(example = "length")]
pub finish_reason: FinishReason, pub finish_reason: FinishReason,
#[schema(example = 1)] #[schema(example = 1)]
pub generated_tokens: u32, pub generated_tokens: u32,
#[schema(example = 42)] #[schema(nullable = true, example = 42)]
pub seed: Option<u64>, pub seed: Option<u64>,
pub prefill: Option<Vec<Token>>, pub prefill: Vec<PrefillToken>,
pub tokens: Option<Vec<Token>>, pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of_sequences: Option<Vec<BestOfSequence>>,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
@ -134,7 +210,7 @@ pub(crate) struct StreamDetails {
pub finish_reason: FinishReason, pub finish_reason: FinishReason,
#[schema(example = 1)] #[schema(example = 1)]
pub generated_tokens: u32, pub generated_tokens: u32,
#[schema(example = 42)] #[schema(nullable = true, example = 42)]
pub seed: Option<u64>, pub seed: Option<u64>,
} }
@ -149,6 +225,6 @@ pub(crate) struct StreamResponse {
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct ErrorResponse { pub(crate) struct ErrorResponse {
#[schema(inline)]
pub error: String, pub error: String,
pub error_type: String,
} }
View File
@ -1,4 +1,5 @@
/// Text Generation Inference webserver entrypoint /// Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser; use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace; use opentelemetry::sdk::trace;
@ -7,9 +8,11 @@ use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue}; use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig; use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use text_generation_client::ShardedClient; use text_generation_client::ShardedClient;
use text_generation_router::server; use text_generation_router::server;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer}; use tracing_subscriber::{EnvFilter, Layer};
@ -20,8 +23,14 @@ use tracing_subscriber::{EnvFilter, Layer};
struct Args { struct Args {
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
max_concurrent_requests: usize, max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "1000", long, env)] #[clap(default_value = "1000", long, env)]
max_input_length: usize, max_input_length: usize,
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
#[clap(default_value = "32", long, env)] #[clap(default_value = "32", long, env)]
max_batch_size: usize, max_batch_size: usize,
#[clap(default_value = "20", long, env)] #[clap(default_value = "20", long, env)]
@ -38,6 +47,8 @@ struct Args {
json_output: bool, json_output: bool,
#[clap(long, env)] #[clap(long, env)]
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
} }
fn main() -> Result<(), std::io::Error> { fn main() -> Result<(), std::io::Error> {
@ -46,7 +57,10 @@ fn main() -> Result<(), std::io::Error> {
// Pattern match configuration // Pattern match configuration
let Args { let Args {
max_concurrent_requests, max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens,
max_batch_size, max_batch_size,
max_waiting_tokens, max_waiting_tokens,
port, port,
@ -55,17 +69,37 @@ fn main() -> Result<(), std::io::Error> {
validation_workers, validation_workers,
json_output, json_output,
otlp_endpoint, otlp_endpoint,
cors_allow_origin,
} = args; } = args;
if validation_workers == 0 { if validation_workers == 0 {
panic!("validation_workers must be > 0"); panic!("validation_workers must be > 0");
} }
// Download and instantiate tokenizer // CORS allowed origins
// map inside the Option to parse each String into a HeaderValue
// Finally, convert to AllowOrigin
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
AllowOrigin::list(
cors_allow_origin
.iter()
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
)
});
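
A hypothetical invocation would be --cors-allow-origin https://ui.example.com (repeatable, and also settable through the environment thanks to clap's env attribute); an origin that does not parse as a HeaderValue panics here via unwrap, and when the flag is omitted server::run falls back to AllowOrigin::any() further down.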
// Tokenizer instance
// This will only be used to validate payloads // This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
let tokenizer =
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
{
// Load local tokenizer
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
} else {
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime // We need to download it outside of the Tokio runtime
let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap(); Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
};
// Launch Tokio runtime // Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread() tokio::runtime::Builder::new_multi_thread()
@ -75,6 +109,27 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async { .block_on(async {
init_logging(otlp_endpoint, json_output); init_logging(otlp_endpoint, json_output);
// Get pipeline tag
let model_info = reqwest::get(format!(
"https://huggingface.co/api/models/{tokenizer_name}"
))
.await
.expect("Could not connect to hf.co")
.text()
.await
.expect("error when retrieving model info from hf.co");
let model_info: serde_json::Value =
serde_json::from_str(&model_info).expect("unable to parse model info");
// if pipeline-tag == text-generation we default to return_full_text = true
let compat_return_full_text = match model_info.get("pipeline_tag") {
None => {
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
false
}
Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"),
};
// Instantiate sharded client from the master unix socket // Instantiate sharded client from the master unix socket
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await .await
@ -91,14 +146,19 @@ fn main() -> Result<(), std::io::Error> {
// Run server // Run server
server::run( server::run(
compat_return_full_text,
max_concurrent_requests, max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens,
max_batch_size, max_batch_size,
max_waiting_tokens, max_waiting_tokens,
sharded_client, sharded_client,
tokenizer, tokenizer,
validation_workers, validation_workers,
addr, addr,
cors_allow_origin,
) )
.await; .await;
Ok(()) Ok(())
View File
@ -132,6 +132,7 @@ impl State {
// Push entry in the queue // Push entry in the queue
self.entries.push((self.next_id, entry)); self.entries.push((self.next_id, entry));
self.next_id += 1; self.next_id += 1;
metrics::increment_gauge!("tgi_queue_size", 1.0);
} }
// Get the next batch // Get the next batch
@ -164,7 +165,8 @@ impl State {
// Create a new span to link the batch back to this entry // Create a new span to link the batch back to this entry
let entry_batch_span = let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size); info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationship // Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span); entry_batch_span.follows_from(&next_batch_span);
// Update entry // Update entry
entry.temp_span = Some(entry_batch_span); entry.temp_span = Some(entry_batch_span);
@ -172,7 +174,6 @@ impl State {
batch_requests.push(Request { batch_requests.push(Request {
id, id,
inputs: entry.request.inputs.clone(), inputs: entry.request.inputs.clone(),
input_length: entry.request.input_length,
parameters: Some(entry.request.parameters.clone()), parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()), stopping_parameters: Some(entry.request.stopping_parameters.clone()),
}); });
@ -190,6 +191,8 @@ impl State {
// Increment batch id // Increment batch id
self.next_batch_id += 1; self.next_batch_id += 1;
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
Some((batch_entries, batch, next_batch_span)) Some((batch_entries, batch, next_batch_span))
} }
} }
@ -223,14 +226,15 @@ mod tests {
Entry { Entry {
request: ValidGenerateRequest { request: ValidGenerateRequest {
inputs: "".to_string(), inputs: "".to_string(),
input_length: 0,
parameters: NextTokenChooserParameters { parameters: NextTokenChooserParameters {
temperature: 0.0, temperature: 0.0,
top_k: 0, top_k: 0,
top_p: 0.0, top_p: 0.0,
typical_p: 0.0,
do_sample: false, do_sample: false,
seed: 0, seed: 0,
repetition_penalty: 0.0, repetition_penalty: 0.0,
watermark: false,
}, },
stopping_parameters: StoppingCriteriaParameters { stopping_parameters: StoppingCriteriaParameters {
max_new_tokens: 0, max_new_tokens: 0,
View File
@ -1,17 +1,20 @@
/// HTTP Server logic /// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse}; use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{ use crate::{
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
Infer, StreamDetails, StreamResponse, Token, Validation, GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
StreamResponse, Token, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse; use axum::response::{IntoResponse, Response};
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Json, Router}; use axum::{http, Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream; use futures::Stream;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use text_generation_client::ShardedClient; use text_generation_client::ShardedClient;
@ -19,29 +22,61 @@ use tokenizers::Tokenizer;
use tokio::signal; use tokio::signal;
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument}; use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
/// Compatibility route with api-inference and AzureML
#[instrument(skip(infer))]
async fn compat_generate(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
req: Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let mut req = req.0;
// default return_full_text given the pipeline_tag
if req.parameters.return_full_text.is_none() {
req.parameters.return_full_text = Some(default_return_full_text.0)
}
// switch on stream
if req.stream {
Ok(generate_stream(infer, Json(req.into()))
.await
.into_response())
} else {
let (headers, generation) = generate(infer, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation.0])).into_response())
}
}
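
As a rough illustration of how a client would exercise this compatibility route, here is an editorial sketch (not part of the commit). It assumes the router is listening locally on port 3000 and that reqwest, tokio (with the macros feature) and serde_json are available on the client side; the request shape mirrors CompatGenerateRequest defined earlier in this diff.

    use serde_json::json;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Same shape as CompatGenerateRequest: inputs + parameters + stream flag.
        let body = json!({
            "inputs": "My name is Olivier and I",
            "parameters": { "max_new_tokens": 20 },
            "stream": false
        });
        let text = reqwest::Client::new()
            .post("http://127.0.0.1:3000/") // assumed local address and port
            .header("content-type", "application/json")
            .body(body.to_string())
            .send()
            .await?
            .text()
            .await?;
        // Non-streaming responses come back wrapped in a JSON array (api-inference compatibility).
        println!("{text}");
        Ok(())
    }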
/// Health check method /// Health check method
#[instrument(skip(infer))] #[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
// be a bit too slow for a health check. // be a bit too slow for a health check.
// What we should do instead if check if the gRPC channels are still healthy. // What we should do instead is check if the gRPC channels are still healthy.
// Send a small inference request // Send a small inference request
infer infer
.generate(GenerateRequest { .generate(GenerateRequest {
inputs: "liveness".to_string(), inputs: "liveness".to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None,
temperature: None, temperature: None,
repetition_penalty: None, repetition_penalty: None,
top_k: None, top_k: None,
top_p: None, top_p: None,
typical_p: None,
do_sample: false, do_sample: false,
max_new_tokens: 1, max_new_tokens: 1,
return_full_text: None,
stop: Vec::new(), stop: Vec::new(),
truncate: None,
watermark: false,
details: false, details: false,
seed: None, seed: None,
}, },
@ -57,15 +92,15 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
path = "/generate", path = "/generate",
request_body = GenerateRequest, request_body = GenerateRequest,
responses( responses(
(status = 200, description = "Generated Text", body = [GenerateResponse]), (status = 200, description = "Generated Text", body = GenerateResponse),
(status = 424, description = "Generation Error", body = [ErrorResponse], (status = 424, description = "Generation Error", body = ErrorResponse,
example = json!({"error": "Request failed during generation"})), example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = [ErrorResponse], (status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json!({"error": "Model is overloaded"})), example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = [ErrorResponse], (status = 422, description = "Input validation error", body = ErrorResponse,
example = json!({"error": "Input validation error"})), example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = [ErrorResponse], (status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json!({"error": "Incomplete generation"})), example = json ! ({"error": "Incomplete generation"})),
) )
)] )]
#[instrument( #[instrument(
@ -82,23 +117,64 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
async fn generate( async fn generate(
infer: Extension<Infer>, infer: Extension<Infer>,
req: Json<GenerateRequest>, req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
let start_time = Instant::now(); let start_time = Instant::now();
// Inference let compute_characters = req.0.inputs.chars().count();
let mut add_prompt = None;
if req.0.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.0.inputs.clone());
}
let details = req.0.parameters.details; let details = req.0.parameters.details;
let response = infer.generate(req.0).await?;
// Inference
let (response, best_of_responses) = match req.0.parameters.best_of {
Some(best_of) if best_of > 1 => {
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
(response, Some(best_of_responses))
}
_ => (infer.generate(req.0).await?, None),
};
// Token details // Token details
let details = match details { let details = match details {
true => Some(Details { true => {
// convert best_of_responses
let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
responses
.into_iter()
.map(|response: InferResponse| {
// Add prompt if return_full_text
let mut output_text = response.generated_text.text;
if let Some(prompt) = &add_prompt {
output_text = prompt.clone() + &output_text;
}
BestOfSequence {
generated_text: output_text,
finish_reason: FinishReason::from(
response.generated_text.finish_reason,
),
generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill,
tokens: response.tokens,
seed: response.generated_text.seed,
}
})
.collect()
});
Some(Details {
finish_reason: FinishReason::from(response.generated_text.finish_reason), finish_reason: FinishReason::from(response.generated_text.finish_reason),
generated_tokens: response.generated_text.generated_tokens, generated_tokens: response.generated_text.generated_tokens,
prefill: Some(response.prefill), prefill: response.prefill,
tokens: Some(response.tokens), tokens: response.tokens,
seed: response.generated_text.seed, seed: response.generated_text.seed,
}), best_of_sequences,
})
}
false => None, false => None,
}; };
@ -111,6 +187,15 @@ async fn generate(
// Headers // Headers
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
headers.insert(
"x-compute-time",
total_time.as_millis().to_string().parse().unwrap(),
);
headers.insert(
"x-compute-characters",
compute_characters.to_string().parse().unwrap(),
);
headers.insert( headers.insert(
"x-total-time", "x-total-time",
total_time.as_millis().to_string().parse().unwrap(), total_time.as_millis().to_string().parse().unwrap(),
@ -141,9 +226,26 @@ async fn generate(
span.record("seed", format!("{:?}", response.generated_text.seed)); span.record("seed", format!("{:?}", response.generated_text.seed));
tracing::info!("Output: {}", response.generated_text.text); tracing::info!("Output: {}", response.generated_text.text);
// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::histogram!("tgi_request_duration", total_time);
metrics::histogram!("tgi_request_validation_duration", validation_time);
metrics::histogram!("tgi_request_queue_duration", queue_time);
metrics::histogram!("tgi_request_inference_duration", inference_time);
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
metrics::histogram!(
"tgi_request_generated_tokens",
response.generated_text.generated_tokens as f64
);
// Send response // Send response
let mut output_text = response.generated_text.text;
if let Some(prompt) = add_prompt {
output_text = prompt + &output_text;
}
let response = GenerateResponse { let response = GenerateResponse {
generated_text: response.generated_text.text, generated_text: output_text,
details, details,
}; };
Ok((headers, Json(response))) Ok((headers, Json(response)))
@ -156,20 +258,20 @@ async fn generate(
path = "/generate_stream", path = "/generate_stream",
request_body = GenerateRequest, request_body = GenerateRequest,
responses( responses(
(status = 200, description = "Generated Text", body = [StreamResponse], (status = 200, description = "Generated Text", body = StreamResponse,
content_type="text/event-stream "), content_type = "text/event-stream"),
(status = 424, description = "Generation Error", body = [ErrorResponse], (status = 424, description = "Generation Error", body = ErrorResponse,
example = json!({"error": "Request failed during generation"}), example = json ! ({"error": "Request failed during generation"}),
content_type="text/event-stream "), content_type = "text/event-stream"),
(status = 429, description = "Model is overloaded", body = [ErrorResponse], (status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json!({"error": "Model is overloaded"}), example = json ! ({"error": "Model is overloaded"}),
content_type="text/event-stream "), content_type = "text/event-stream"),
(status = 422, description = "Input validation error", body = [ErrorResponse], (status = 422, description = "Input validation error", body = ErrorResponse,
example = json!({"error": "Input validation error"}), example = json ! ({"error": "Input validation error"}),
content_type="text/event-stream "), content_type = "text/event-stream"),
(status = 500, description = "Incomplete generation", body = [ErrorResponse], (status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json!({"error": "Incomplete generation"}), example = json ! ({"error": "Incomplete generation"}),
content_type="text/event-stream "), content_type = "text/event-stream"),
) )
)] )]
#[instrument( #[instrument(
@ -186,16 +288,35 @@ async fn generate(
async fn generate_stream( async fn generate_stream(
infer: Extension<Infer>, infer: Extension<Infer>,
req: Json<GenerateRequest>, req: Json<GenerateRequest>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> { ) -> (
HeaderMap,
Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
let span = tracing::Span::current(); let span = tracing::Span::current();
let start_time = Instant::now(); let start_time = Instant::now();
let compute_characters = req.0.inputs.chars().count();
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
headers.insert(
"x-compute-characters",
compute_characters.to_string().parse().unwrap(),
);
let stream = async_stream::stream! { let stream = async_stream::stream! {
// Inference // Inference
let mut end_reached = false; let mut end_reached = false;
let mut error = false; let mut error = false;
let mut add_prompt = None;
if req.0.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.0.inputs.clone());
}
let details = req.0.parameters.details; let details = req.0.parameters.details;
let best_of = req.0.parameters.best_of.unwrap_or(1);
if best_of == 1 {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
Ok(mut response_stream) => { Ok(mut response_stream) => {
// Server-Sent Event stream // Server-Sent Event stream
@ -241,30 +362,47 @@ async fn generate_stream(
let time_per_token = inference_time / generated_text.generated_tokens; let time_per_token = inference_time / generated_text.generated_tokens;
// Tracing metadata // Tracing metadata
span.record("total_time", format!("{:?}", total_time)); span.record("total_time", format!("{total_time:?}"));
span.record("validation_time", format!("{:?}", validation_time)); span.record("validation_time", format!("{validation_time:?}"));
span.record("queue_time", format!("{:?}", queue_time)); span.record("queue_time", format!("{queue_time:?}"));
span.record("inference_time", format!("{:?}", inference_time)); span.record("inference_time", format!("{inference_time:?}"));
span.record("time_per_token", format!("{:?}", time_per_token)); span.record("time_per_token", format!("{time_per_token:?}"));
span.record("seed", format!("{:?}", generated_text.seed)); span.record("seed", format!("{:?}", generated_text.seed));
tracing::info!(parent: &span, "Output: {}", generated_text.text); tracing::info!(parent: &span, "Output: {}", generated_text.text);
// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::histogram!("tgi_request_duration", total_time);
metrics::histogram!("tgi_request_validation_duration", validation_time);
metrics::histogram!("tgi_request_queue_duration", queue_time);
metrics::histogram!("tgi_request_inference_duration", inference_time);
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
// StreamResponse // StreamResponse
end_reached = true; end_reached = true;
let mut output_text = generated_text.text;
if let Some(prompt) = add_prompt {
output_text = prompt + &output_text;
}
let stream_token = StreamResponse { let stream_token = StreamResponse {
token, token,
generated_text: Some(generated_text.text), generated_text: Some(output_text),
details details
}; };
yield Ok(Event::default().json_data(stream_token).unwrap()) yield Ok(Event::default().json_data(stream_token).unwrap());
break;
} }
} }
} }
// yield error // yield error
Err(err) => { Err(err) => {
error = true; error = true;
yield Ok(Event::from(err)) yield Ok(Event::from(err));
break;
} }
} }
} }
@ -272,32 +410,55 @@ async fn generate_stream(
// yield error // yield error
Err(err) => { Err(err) => {
error = true; error = true;
yield Ok(Event::from(err)) yield Ok(Event::from(err));
} }
} }
// Check if generation reached the end // Check if generation reached the end
// Skip if we already sent an error // Skip if we already sent an error
if !end_reached && !error { if !end_reached && !error {
let err = InferError::IncompleteGeneration; let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
tracing::error!("{err}"); tracing::error!("{err}");
yield Ok(Event::from(err)) yield Ok(Event::from(err));
}
} else {
let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
} }
}; };
Sse::new(stream).keep_alive(KeepAlive::default()) (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
}
/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render()
} }
/// Serving method /// Serving method
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
compat_return_full_text: bool,
max_concurrent_requests: usize, max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize,
max_batch_size: usize, max_batch_size: usize,
max_waiting_tokens: usize, max_waiting_tokens: usize,
client: ShardedClient, client: ShardedClient,
tokenizer: Tokenizer, tokenizer: Tokenizer,
validation_workers: usize, validation_workers: usize,
addr: SocketAddr, addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
) { ) {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -305,13 +466,16 @@ pub async fn run(
paths( paths(
generate, generate,
generate_stream, generate_stream,
metrics,
), ),
components( components(
schemas( schemas(
GenerateRequest, GenerateRequest,
GenerateParameters, GenerateParameters,
PrefillToken,
Token, Token,
GenerateResponse, GenerateResponse,
BestOfSequence,
Details, Details,
FinishReason, FinishReason,
StreamResponse, StreamResponse,
@ -333,7 +497,14 @@ pub async fn run(
struct ApiDoc; struct ApiDoc;
// Create state // Create state
let validation = Validation::new(validation_workers, tokenizer, max_input_length); let validation = Validation::new(
validation_workers,
tokenizer,
max_best_of,
max_stop_sequences,
max_input_length,
max_total_tokens,
);
let infer = Infer::new( let infer = Infer::new(
client, client,
validation, validation,
@ -342,16 +513,33 @@ pub async fn run(
max_concurrent_requests, max_concurrent_requests,
); );
// Prometheus handler
let builder = PrometheusBuilder::new();
let prom_handle = builder
.install_recorder()
.expect("failed to install metrics recorder");
// CORS layer
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
let cors_layer = CorsLayer::new()
.allow_methods([Method::GET, Method::POST])
.allow_headers([http::header::CONTENT_TYPE])
.allow_origin(allow_origin);
// Create router // Create router
let app = Router::new() let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
.route("/", post(generate)) .route("/", post(compat_generate))
.route("/generate", post(generate)) .route("/generate", post(generate))
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
.route("/", get(health)) .route("/", get(health))
.route("/health", get(health)) .route("/health", get(health))
.route("/metrics", get(metrics))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer)) .layer(Extension(infer))
.layer(opentelemetry_tracing_layer()); .layer(Extension(prom_handle))
.layer(opentelemetry_tracing_layer())
.layer(cors_layer);
// Run server // Run server
axum::Server::bind(&addr) axum::Server::bind(&addr)
@ -415,6 +603,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
status_code, status_code,
Json(ErrorResponse { Json(ErrorResponse {
error: err.to_string(), error: err.to_string(),
error_type: err.error_type().to_string(),
}), }),
) )
} }
@ -425,6 +614,7 @@ impl From<InferError> for Event {
Event::default() Event::default()
.json_data(ErrorResponse { .json_data(ErrorResponse {
error: err.to_string(), error: err.to_string(),
error_type: err.error_type().to_string(),
}) })
.unwrap() .unwrap()
} }
View File
@ -1,3 +1,4 @@
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
/// Payload validation logic /// Payload validation logic
use crate::{GenerateParameters, GenerateRequest}; use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng; use rand::rngs::ThreadRng;
@ -5,33 +6,44 @@ use rand::Rng;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tracing::{instrument, Span}; use tracing::{instrument, Span};
const MAX_MAX_NEW_TOKENS: u32 = 512;
const MAX_STOP_SEQUENCES: usize = 4;
/// Validation /// Validation
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Validation { pub struct Validation {
/// maximum value for the best_of parameter
#[allow(dead_code)]
max_best_of: usize,
/// Channel to communicate with the background validation task /// Channel to communicate with the background validation task
sender: mpsc::Sender<ValidationRequest>, sender: mpsc::UnboundedSender<ValidationRequest>,
} }
impl Validation { impl Validation {
pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self { pub(crate) fn new(
workers: usize,
tokenizer: Tokenizer,
max_best_of: usize,
max_stop_sequences: usize,
max_input_length: usize,
max_total_tokens: usize,
) -> Self {
// Create channel // Create channel
let (validation_sender, validation_receiver) = mpsc::channel(128); let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
// Launch background validation task // Launch background validation task
tokio::spawn(validation_task( tokio::spawn(validation_task(
workers, workers,
tokenizer, tokenizer,
max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens,
validation_receiver, validation_receiver,
)); ));
Self { Self {
max_best_of,
sender: validation_sender, sender: validation_sender,
} }
} }
@ -48,12 +60,25 @@ impl Validation {
// Unwrap is safe here // Unwrap is safe here
self.sender self.sender
.send((request, sender, Span::current())) .send((request, sender, Span::current()))
.await
.unwrap(); .unwrap();
// Await on response channel // Await on response channel
// Unwrap is safe here // Unwrap is safe here
receiver.await.unwrap() receiver.await.unwrap()
} }
/// Validate the best_of parameter
#[instrument(skip_all)]
pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
if self.max_best_of == 1 && best_of != 1 {
return Err(ValidationError::BestOfDisabled);
}
if best_of > self.max_best_of {
return Err(ValidationError::BestOf(self.max_best_of, best_of));
}
Ok(best_of)
}
} }
/// Validation task /// Validation task
@ -61,8 +86,10 @@ impl Validation {
async fn validation_task( async fn validation_task(
workers: usize, workers: usize,
tokenizer: Tokenizer, tokenizer: Tokenizer,
max_stop_sequences: usize,
max_input_length: usize, max_input_length: usize,
mut receiver: mpsc::Receiver<ValidationRequest>, max_total_tokens: usize,
mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
) { ) {
let mut workers_senders = Vec::with_capacity(workers); let mut workers_senders = Vec::with_capacity(workers);
@ -75,7 +102,13 @@ async fn validation_task(
// Spawn worker // Spawn worker
tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
validation_worker(tokenizer_clone, max_input_length, worker_receiver) validation_worker(
tokenizer_clone,
max_stop_sequences,
max_input_length,
max_total_tokens,
worker_receiver,
)
}); });
} }
@ -95,7 +128,9 @@ async fn validation_task(
/// the tokenizer /// the tokenizer
fn validation_worker( fn validation_worker(
tokenizer: Tokenizer, tokenizer: Tokenizer,
max_stop_sequences: usize,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize,
mut receiver: mpsc::Receiver<ValidationRequest>, mut receiver: mpsc::Receiver<ValidationRequest>,
) { ) {
// Seed rng // Seed rng
@ -106,7 +141,16 @@ fn validation_worker(
parent_span.in_scope(|| { parent_span.in_scope(|| {
response_tx response_tx
.send( .send(
validate(request, &tokenizer, max_input_length, &mut rng).map_err(|err| { validate(
request,
&tokenizer,
max_stop_sequences,
max_input_length,
max_total_tokens,
&mut rng,
)
.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}"); tracing::error!("{err}");
err err
}), }),
@ -119,21 +163,39 @@ fn validation_worker(
fn validate( fn validate(
request: GenerateRequest, request: GenerateRequest,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
max_stop_sequences: usize,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize,
rng: &mut ThreadRng, rng: &mut ThreadRng,
) -> Result<ValidGenerateRequest, ValidationError> { ) -> Result<ValidGenerateRequest, ValidationError> {
let GenerateParameters { let GenerateParameters {
best_of,
temperature, temperature,
repetition_penalty, repetition_penalty,
top_k, top_k,
top_p, top_p,
typical_p,
do_sample, do_sample,
max_new_tokens, max_new_tokens,
stop: stop_sequences, stop: stop_sequences,
truncate,
seed, seed,
watermark,
.. ..
} = request.parameters; } = request.parameters;
// sampling must be true when best_of > 1
let best_of = best_of.unwrap_or(1);
let sampling = do_sample
|| temperature.is_some()
|| top_k.is_some()
|| top_p.is_some()
|| typical_p.is_some();
if best_of > 1 && !sampling {
return Err(BestOfSampling);
}
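
Concretely: a request with best_of = 2 and no sampling parameters is rejected here with BestOfSampling, while the same request with, say, temperature = 0.7 or do_sample = true passes this check; combining best_of > 1 with an explicit seed is rejected further down with BestOfSeed.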
let temperature = temperature.unwrap_or(1.0); let temperature = temperature.unwrap_or(1.0);
if temperature <= 0.0 { if temperature <= 0.0 {
return Err(ValidationError::Temperature); return Err(ValidationError::Temperature);
@@ -144,30 +206,42 @@ fn validate(
         return Err(ValidationError::RepetitionPenalty);
     }
-    let top_p = top_p.unwrap_or(1.0);
-    if top_p <= 0.0 || top_p > 1.0 {
-        return Err(ValidationError::TopP);
-    }
+    // Different because the proto default value is not a valid value
+    // for the user
+    let top_p = top_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TopP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
+    let typical_p = typical_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TypicalP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
-    // Different because the proto default value is 0 while it is not a valid value
-    // for the user
-    let top_k: u32 = match top_k {
-        None => Ok(0),
-        Some(top_k) => {
-            if top_k <= 0 {
-                return Err(ValidationError::TopK);
-            }
-            Ok(top_k as u32)
-        }
-    }?;
+    let top_k: u32 = top_k
+        .map(|value| {
+            if value <= 0 {
+                return Err(ValidationError::TopK);
+            }
+            Ok(value as u32)
+        })
+        .unwrap_or(Ok(0))?;
-    if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
-        return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
-    }
+    if max_new_tokens == 0 {
+        return Err(ValidationError::MaxNewTokens);
+    }
-    if stop_sequences.len() > MAX_STOP_SEQUENCES {
+    if stop_sequences.len() > max_stop_sequences {
         return Err(ValidationError::StopSequence(
-            MAX_STOP_SEQUENCES,
+            max_stop_sequences,
             stop_sequences.len(),
         ));
     }
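top_p, typical_p and top_k now share one validation shape: `Option::map` runs the range check only when the client actually sent a value, and `unwrap_or(Ok(default))` supplies the neutral default otherwise, because the protobuf default (0) is not a legal user value. A small self-contained sketch of that pattern for a top_p-like parameter (the `ParamError` type is hypothetical, for illustration only):

// Hedged sketch of the Option-based range check used above.
#[derive(Debug, PartialEq)]
enum ParamError {
    TopP,
}

fn validate_top_p(top_p: Option<f32>) -> Result<f32, ParamError> {
    top_p
        .map(|value| {
            // The proto default (0.0) is not a legal user value, hence the open interval.
            if value <= 0.0 || value >= 1.0 {
                return Err(ParamError::TopP);
            }
            Ok(value)
        })
        // `None` means "not requested": fall back to the neutral default of 1.0.
        .unwrap_or(Ok(1.0))
}

fn main() {
    assert_eq!(validate_top_p(None), Ok(1.0));
    assert_eq!(validate_top_p(Some(0.9)), Ok(0.9));
    assert_eq!(validate_top_p(Some(1.0)), Err(ParamError::TopP));
}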
@ -175,41 +249,82 @@ fn validate(
// If seed is None, assign a random one // If seed is None, assign a random one
let seed = match seed { let seed = match seed {
None => rng.gen(), None => rng.gen(),
Some(seed) => seed, Some(seed) => {
if best_of > 1 {
return Err(BestOfSeed);
}
seed
}
}; };
// Check if inputs is empty
if request.inputs.is_empty() {
return Err(EmptyInput);
}
// Check if truncate is strictly positive and less than max_input_length
let truncate = truncate
.map(|value| {
if value == 0 || value > max_input_length {
return Err(ValidationError::Truncate(max_input_length, value));
}
Ok(Some(value))
})
.unwrap_or(Ok(None))?;
// Get the number of tokens in the input // Get the number of tokens in the input
match tokenizer.encode(request.inputs.clone(), true) { let mut encoding = tokenizer
Ok(encoding) => { .encode(request.inputs.clone(), true)
let input_length = encoding.len(); .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
let (inputs, input_length) = if let Some(truncate) = truncate {
// truncate encoding and decode new inputs
encoding.truncate(truncate, 0, TruncationDirection::Left);
let inputs = tokenizer
.decode(Vec::from(encoding.get_ids()), false)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
(inputs, encoding.len())
} else {
(request.inputs, encoding.len())
};
if input_length > max_input_length { if input_length > max_input_length {
Err(ValidationError::InputLength(input_length, max_input_length)) return Err(ValidationError::InputLength(max_input_length, input_length));
} else { }
let total_tokens = input_length + max_new_tokens as usize;
if total_tokens > max_total_tokens {
return Err(ValidationError::MaxTotalTokens(
max_total_tokens,
input_length,
max_new_tokens,
));
}
// Return ValidGenerateRequest // Return ValidGenerateRequest
let parameters = NextTokenChooserParameters { let parameters = NextTokenChooserParameters {
temperature, temperature,
repetition_penalty, repetition_penalty,
top_k, top_k,
top_p, top_p,
typical_p,
do_sample, do_sample,
seed, seed,
watermark,
}; };
let stopping_parameters = StoppingCriteriaParameters { let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens, max_new_tokens,
stop_sequences, stop_sequences,
}; };
metrics::histogram!("tgi_request_input_length", input_length as f64);
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
Ok(ValidGenerateRequest { Ok(ValidGenerateRequest {
inputs: request.inputs, inputs,
input_length: input_length as u32,
parameters, parameters,
stopping_parameters, stopping_parameters,
}) })
}
}
Err(err) => Err(ValidationError::Tokenizer(err.to_string())),
}
} }
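The rewritten tail of `validate` introduces a token budget: the prompt may be left-truncated to `truncate` tokens, must still fit `max_input_length`, and together with `max_new_tokens` must not exceed `max_total_tokens`. A stand-alone sketch of just that arithmetic, with plain integers in place of the tokenizer `Encoding` (the `check_budget` and `BudgetError` names are hypothetical):

// Hedged sketch of the truncate + total-token budget rule above.
#[derive(Debug, PartialEq)]
enum BudgetError {
    InputTooLong { max: usize, given: usize },
    OverBudget { max: usize, input: usize, new: u32 },
}

fn check_budget(
    mut input_length: usize,
    truncate: Option<usize>,
    max_new_tokens: u32,
    max_input_length: usize,
    max_total_tokens: usize,
) -> Result<usize, BudgetError> {
    // Left truncation keeps at most `truncate` of the most recent prompt tokens.
    if let Some(truncate) = truncate {
        input_length = input_length.min(truncate);
    }
    if input_length > max_input_length {
        return Err(BudgetError::InputTooLong { max: max_input_length, given: input_length });
    }
    if input_length + max_new_tokens as usize > max_total_tokens {
        return Err(BudgetError::OverBudget { max: max_total_tokens, input: input_length, new: max_new_tokens });
    }
    Ok(input_length)
}

fn main() {
    // 900 prompt tokens truncated to 800, plus 200 new tokens, fits a 1000-token budget.
    assert_eq!(check_budget(900, Some(800), 200, 1000, 1000), Ok(800));
    // Without truncation the same request is rejected.
    assert!(matches!(
        check_budget(900, None, 200, 1000, 1000),
        Err(BudgetError::OverBudget { .. })
    ));
}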
 type ValidationRequest = (
@@ -221,26 +336,43 @@ type ValidationRequest = (
 #[derive(Debug)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
+    pub input_length: u32,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }
 #[derive(Error, Debug)]
 pub enum ValidationError {
+    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
+    BestOf(usize, usize),
+    #[error("`best_of` != 1 is not allowed for this endpoint")]
+    BestOfDisabled,
+    #[error("you must use sampling when `best_of` is > 1")]
+    BestOfSampling,
+    #[error("`seed` must not be set when `best_of` > 1")]
+    BestOfSeed,
+    #[error("`best_of` != 1 is not supported when streaming tokens")]
+    BestOfStream,
-    #[error("temperature must be strictly positive")]
+    #[error("`temperature` must be strictly positive")]
     Temperature,
-    #[error("repetition_penalty must be strictly positive")]
+    #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
-    #[error("top_p must be > 0.0 and <= 1.0")]
+    #[error("`top_p` must be > 0.0 and < 1.0")]
     TopP,
-    #[error("top_k must be strictly positive")]
+    #[error("`top_k` must be strictly positive")]
     TopK,
-    #[error("max_new_tokens must be strictly positive and <= {0}")]
-    MaxNewTokens(u32),
+    #[error("`truncate` must be strictly positive and less than {0}. Given: {1}")]
+    Truncate(usize, usize),
+    #[error("`typical_p` must be > 0.0 and < 1.0")]
+    TypicalP,
+    #[error("`max_new_tokens` must be strictly positive")]
+    MaxNewTokens,
+    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
+    MaxTotalTokens(usize, usize, u32),
-    #[error("inputs must have less than {1} tokens. Given: {0}")]
+    #[error("`inputs` must have less than {0} tokens. Given: {1}")]
     InputLength(usize, usize),
+    #[error("`inputs` cannot be empty")]
+    EmptyInput,
-    #[error("stop supports up to {0} stop sequences. Given: {1}")]
+    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
     StopSequence(usize, usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
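All the new variants follow the `thiserror` pattern already used by `ValidationError`: the `#[error(...)]` string is the exact message returned to the client, with positional fields filled in. A cut-down, runnable sketch showing how two of the parameterised messages render (requires the `thiserror` crate; `SketchValidationError` is a hypothetical name, not the router's type):

// Hedged sketch: same thiserror message style, reduced to two variants.
use thiserror::Error;

#[derive(Error, Debug)]
enum SketchValidationError {
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    #[error("`stop` supports up to {0} stop sequences. Given: {1}")]
    StopSequence(usize, usize),
}

fn main() {
    let err = SketchValidationError::MaxTotalTokens(1000, 900, 200);
    // Prints the message with the budget, prompt length and max_new_tokens filled in.
    println!("{err}");
    println!("{}", SketchValidationError::StopSequence(4, 6));
}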

server/.gitignore (vendored)
@@ -1,7 +1,7 @@
 # Byte-compiled / optimized / DLL files
 __pycache__/
-text_generation/__pycache__/
-text_generation/pb/__pycache__/
+text_generation_server/__pycache__/
+text_generation_server/pb/__pycache__/
 *.py[cod]
 *$py.class

server/Makefile
@@ -1,20 +1,22 @@
+transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.51.1 --no-cache-dir
-	mkdir text_generation/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
-	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch text_generation/pb/__init__.py
+	mkdir text_generation_server/pb || true
+	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
+	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch text_generation_server/pb/__init__.py
 install-transformers:
 	# Install specific version of transformers with custom cuda kernels
 	pip uninstall transformers -y || true
 	rm -rf transformers || true
-	rm -rf transformers-text_generation_inference || true
-	curl -L -O https://github.com/OlivierDehaene/transformers/archive/refs/heads/text_generation_inference.zip
-	unzip text_generation_inference.zip
-	rm text_generation_inference.zip
-	mv transformers-text_generation_inference transformers
+	rm -rf transformers-$(transformers_commit) || true
+	curl -L -O https://github.com/OlivierDehaene/transformers/archive/$(transformers_commit).zip
+	unzip $(transformers_commit).zip
+	rm $(transformers_commit).zip
+	mv transformers-$(transformers_commit) transformers
 	cd transformers && python setup.py install
 install-torch:
@@ -26,4 +28,4 @@ install: gen-server install-torch install-transformers
 	pip install -e . --no-cache-dir
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

server/poetry.lock (generated)
@@ -145,30 +145,30 @@ testing = ["protobuf (>=3.6.0)"]
 [[package]]
 name = "grpcio"
-version = "1.51.1"
+version = "1.51.3"
 description = "HTTP/2-based RPC framework"
 category = "main"
 optional = false
 python-versions = ">=3.7"
 [package.extras]
-protobuf = ["grpcio-tools (>=1.51.1)"]
+protobuf = ["grpcio-tools (>=1.51.3)"]
 [[package]]
 name = "grpcio-reflection"
-version = "1.51.1"
+version = "1.51.3"
 description = "Standard Protobuf Reflection Service for gRPC"
 category = "main"
 optional = false
 python-versions = ">=3.6"
 [package.dependencies]
-grpcio = ">=1.51.1"
+grpcio = ">=1.51.3"
 protobuf = ">=4.21.6"
 [[package]]
 name = "grpcio-status"
-version = "1.51.1"
+version = "1.51.3"
 description = "Status proto mapping for gRPC"
 category = "main"
 optional = false
@@ -176,22 +176,30 @@ python-versions = ">=3.6"
 [package.dependencies]
 googleapis-common-protos = ">=1.5.5"
-grpcio = ">=1.51.1"
+grpcio = ">=1.51.3"
 protobuf = ">=4.21.6"
 [[package]]
 name = "grpcio-tools"
-version = "1.51.1"
+version = "1.51.3"
 description = "Protobuf code generator for gRPC"
 category = "dev"
 optional = false
 python-versions = ">=3.7"
 [package.dependencies]
-grpcio = ">=1.51.1"
+grpcio = ">=1.51.3"
 protobuf = ">=4.21.6,<5.0dev"
 setuptools = "*"
+[[package]]
+name = "hf-transfer"
+version = "0.1.2"
+description = ""
+category = "main"
+optional = false
+python-versions = ">=3.7"
 [[package]]
 name = "idna"
 version = "3.4"
@@ -428,7 +436,7 @@ testing = ["pytest", "pytest-benchmark"]
 [[package]]
 name = "protobuf"
-version = "4.21.12"
+version = "4.22.0"
 description = ""
 category = "main"
 optional = false
@@ -511,7 +519,7 @@ torch = ["torch"]
 [[package]]
 name = "setuptools"
-version = "67.2.0"
+version = "67.4.0"
 description = "Easily download, build, install, upgrade, and uninstall Python packages"
 category = "main"
 optional = false
@@ -567,7 +575,7 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.
 [[package]]
 name = "typing-extensions"
-version = "4.4.0"
+version = "4.5.0"
 description = "Backported and Experimental Type Hints for Python 3.7+"
 category = "main"
 optional = false
@@ -610,7 +618,7 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
 [[package]]
 name = "wrapt"
-version = "1.14.1"
+version = "1.15.0"
 description = "Module for decorators, wrappers and monkey patching."
 category = "main"
 optional = false
@@ -622,7 +630,7 @@ bnb = ["bitsandbytes"]
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.9"
-content-hash = "f3cab6881b52045770a90ec9be7415a0ee499d9e980892d544f68073700cf321"
+content-hash = "521dc9f3c283dc56f7d2e2f96759919ff27ab49ffd3ae7cd26317b209e7fa98d"
 [metadata.files]
 accelerate = [
@@ -760,106 +768,127 @@ grpc-interceptor = [
 (grpcio, grpcio-reflection, grpcio-status and grpcio-tools wheel hashes regenerated for the 1.51.1 to 1.51.3 bump, and new entries added for hf-transfer 0.1.2; the per-wheel sha256 entries for the grpcio packages are omitted from this excerpt)
hf-transfer = [
{file = "hf_transfer-0.1.2-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:2b9189a4a460646ee135ee771f39c0f695d3d5bf08b7ff1dcfe374227520e994"},
{file = "hf_transfer-0.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:654fcaba4e7084caa1e97430982ea968935a72916ee0f4afc60e356f89774099"},
{file = "hf_transfer-0.1.2-cp310-none-win_amd64.whl", hash = "sha256:eb29e7b3707b5cac02e689c89111685ebcdaa3cebba02eb7ac1b0f076357da72"},
{file = "hf_transfer-0.1.2-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0bfca9bd84e925e978a0f157df488704c17a0b9ad240b2859262faba0c74cd40"},
{file = "hf_transfer-0.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d00c5473b35227b2f113fd43ff13cbac9539f2e6779fa0680a887b0aac31c389"},
{file = "hf_transfer-0.1.2-cp311-none-win_amd64.whl", hash = "sha256:1aaf5937aa433b7d09ce5bf60967ec22b7d3982957b00516a8dc2aaa66384372"},
{file = "hf_transfer-0.1.2-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:b0aa760a55995ad59ea17e395babafdc56c4e664be0c2d2055664199dd913da1"},
{file = "hf_transfer-0.1.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:889dd15e8472daf66e266eb056e31a485af3c35f95a483bb43489a0f6e44c359"},
{file = "hf_transfer-0.1.2-cp37-none-win_amd64.whl", hash = "sha256:30df586e18ec8a8e67e3201b9038210d94cb3c03c1cbd97673b9c78ede227178"},
{file = "hf_transfer-0.1.2-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:cc97eb97f929f96bed896cd3af9bbdf121c15ac6d63524b9fc9312fd2929099a"},
{file = "hf_transfer-0.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:583c2c80210a60dafed9a81ba50c389878aee6c34b2dd375cd84522658f29ad8"},
{file = "hf_transfer-0.1.2-cp38-none-win_amd64.whl", hash = "sha256:6dff58f50d1435b0346f31a32f1f9e2301986521c1d0b51e47a3c82b96d02156"},
{file = "hf_transfer-0.1.2-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:d6db1a8f539133f7a893bb32721916fe72b4d2aa3eb7604581ba1f03b8167c90"},
{file = "hf_transfer-0.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f284e3f775d215c9a8d3d1c6f6b1001b1e7990d73ae5fd9aea6c9bce9ea79285"},
{file = "hf_transfer-0.1.2-cp39-none-win_amd64.whl", hash = "sha256:8625beabebc582eafc4141a5ecb9f1183b728d4f63767f01fdcf1e2fbafe6d43"},
{file = "hf_transfer-0.1.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:947dd1b8b22ac10723b2887ed4b5ef929f7d4dd850b0d66c0c6954a9a85afb06"},
{file = "hf_transfer-0.1.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90a020f41dfae4629186c284888cd5adbebe402e2497a88351416ab93c7df9a8"},
{file = "hf_transfer-0.1.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5eb89698746a29805bfc60126b9a008e6ba08a82ef9bb122a6544e84f748e8a4"},
{file = "hf_transfer-0.1.2.tar.gz", hash = "sha256:6bf847f4c19c7d8d9f9bbb8a7ed52e1271bbf0c1bd920357db0c274ccc69f21d"},
 ]
 idna = [
     {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
@@ -965,20 +994,19 @@ pluggy = [
 (protobuf wheel hashes regenerated for the 4.21.12 to 4.22.0 bump; per-wheel sha256 entries omitted from this excerpt)
@@ -1089,8 +1117,8 @@ safetensors = [
 (setuptools hashes regenerated for 67.2.0 to 67.4.0)
@@ -1124,8 +1152,8 @@ typer = [
 (typing-extensions hashes regenerated for 4.4.0 to 4.5.0)
@@ -1140,68 +1168,79 @@ win32-setctime = [
 (wrapt wheel hashes regenerated for the 1.14.1 to 1.15.0 bump; per-wheel sha256 entries omitted)
{file = "wrapt-1.14.1-cp36-cp36m-win32.whl", hash = "sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed"}, {file = "wrapt-1.15.0-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:75669d77bb2c071333417617a235324a1618dba66f82a750362eccbe5b61d248"},
{file = "wrapt-1.14.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471"}, {file = "wrapt-1.15.0-cp35-cp35m-win32.whl", hash = "sha256:fbec11614dba0424ca72f4e8ba3c420dba07b4a7c206c8c8e4e73f2e98f4c559"},
{file = "wrapt-1.14.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248"}, {file = "wrapt-1.15.0-cp35-cp35m-win_amd64.whl", hash = "sha256:fd69666217b62fa5d7c6aa88e507493a34dec4fa20c5bd925e4bc12fce586639"},
{file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68"}, {file = "wrapt-1.15.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b0724f05c396b0a4c36a3226c31648385deb6a65d8992644c12a4963c70326ba"},
{file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d"}, {file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbeccb1aa40ab88cd29e6c7d8585582c99548f55f9b2581dfc5ba68c59a85752"},
{file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77"}, {file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:38adf7198f8f154502883242f9fe7333ab05a5b02de7d83aa2d88ea621f13364"},
{file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7"}, {file = "wrapt-1.15.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:578383d740457fa790fdf85e6d346fda1416a40549fe8db08e5e9bd281c6a475"},
{file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015"}, {file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:a4cbb9ff5795cd66f0066bdf5947f170f5d63a9274f99bdbca02fd973adcf2a8"},
{file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a"}, {file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:af5bd9ccb188f6a5fdda9f1f09d9f4c86cc8a539bd48a0bfdc97723970348418"},
{file = "wrapt-1.14.1-cp37-cp37m-win32.whl", hash = "sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853"}, {file = "wrapt-1.15.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:b56d5519e470d3f2fe4aa7585f0632b060d532d0696c5bdfb5e8319e1d0f69a2"},
{file = "wrapt-1.14.1-cp37-cp37m-win_amd64.whl", hash = "sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c"}, {file = "wrapt-1.15.0-cp36-cp36m-win32.whl", hash = "sha256:77d4c1b881076c3ba173484dfa53d3582c1c8ff1f914c6461ab70c8428b796c1"},
{file = "wrapt-1.14.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456"}, {file = "wrapt-1.15.0-cp36-cp36m-win_amd64.whl", hash = "sha256:077ff0d1f9d9e4ce6476c1a924a3332452c1406e59d90a2cf24aeb29eeac9420"},
{file = "wrapt-1.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f"}, {file = "wrapt-1.15.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5c5aa28df055697d7c37d2099a7bc09f559d5053c3349b1ad0c39000e611d317"},
{file = "wrapt-1.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc"}, {file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a8564f283394634a7a7054b7983e47dbf39c07712d7b177b37e03f2467a024e"},
{file = "wrapt-1.14.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1"}, {file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780c82a41dc493b62fc5884fb1d3a3b81106642c5c5c78d6a0d4cbe96d62ba7e"},
{file = "wrapt-1.14.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af"}, {file = "wrapt-1.15.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e169e957c33576f47e21864cf3fc9ff47c223a4ebca8960079b8bd36cb014fd0"},
{file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b"}, {file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b02f21c1e2074943312d03d243ac4388319f2456576b2c6023041c4d57cd7019"},
{file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0"}, {file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f2e69b3ed24544b0d3dbe2c5c0ba5153ce50dcebb576fdc4696d52aa22db6034"},
{file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57"}, {file = "wrapt-1.15.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d787272ed958a05b2c86311d3a4135d3c2aeea4fc655705f074130aa57d71653"},
{file = "wrapt-1.14.1-cp38-cp38-win32.whl", hash = "sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5"}, {file = "wrapt-1.15.0-cp37-cp37m-win32.whl", hash = "sha256:02fce1852f755f44f95af51f69d22e45080102e9d00258053b79367d07af39c0"},
{file = "wrapt-1.14.1-cp38-cp38-win_amd64.whl", hash = "sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d"}, {file = "wrapt-1.15.0-cp37-cp37m-win_amd64.whl", hash = "sha256:abd52a09d03adf9c763d706df707c343293d5d106aea53483e0ec8d9e310ad5e"},
{file = "wrapt-1.14.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383"}, {file = "wrapt-1.15.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cdb4f085756c96a3af04e6eca7f08b1345e94b53af8921b25c72f096e704e145"},
{file = "wrapt-1.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7"}, {file = "wrapt-1.15.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:230ae493696a371f1dbffaad3dafbb742a4d27a0afd2b1aecebe52b740167e7f"},
{file = "wrapt-1.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86"}, {file = "wrapt-1.15.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63424c681923b9f3bfbc5e3205aafe790904053d42ddcc08542181a30a7a51bd"},
{file = "wrapt-1.14.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735"}, {file = "wrapt-1.15.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6bcbfc99f55655c3d93feb7ef3800bd5bbe963a755687cbf1f490a71fb7794b"},
{file = "wrapt-1.14.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b"}, {file = "wrapt-1.15.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c99f4309f5145b93eca6e35ac1a988f0dc0a7ccf9ccdcd78d3c0adf57224e62f"},
{file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3"}, {file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b130fe77361d6771ecf5a219d8e0817d61b236b7d8b37cc045172e574ed219e6"},
{file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3"}, {file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:96177eb5645b1c6985f5c11d03fc2dbda9ad24ec0f3a46dcce91445747e15094"},
{file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe"}, {file = "wrapt-1.15.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5fe3e099cf07d0fb5a1e23d399e5d4d1ca3e6dfcbe5c8570ccff3e9208274f7"},
{file = "wrapt-1.14.1-cp39-cp39-win32.whl", hash = "sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5"}, {file = "wrapt-1.15.0-cp38-cp38-win32.whl", hash = "sha256:abd8f36c99512755b8456047b7be10372fca271bf1467a1caa88db991e7c421b"},
{file = "wrapt-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb"}, {file = "wrapt-1.15.0-cp38-cp38-win_amd64.whl", hash = "sha256:b06fa97478a5f478fb05e1980980a7cdf2712015493b44d0c87606c1513ed5b1"},
{file = "wrapt-1.14.1.tar.gz", hash = "sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d"}, {file = "wrapt-1.15.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2e51de54d4fb8fb50d6ee8327f9828306a959ae394d3e01a1ba8b2f937747d86"},
{file = "wrapt-1.15.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0970ddb69bba00670e58955f8019bec4a42d1785db3faa043c33d81de2bf843c"},
{file = "wrapt-1.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76407ab327158c510f44ded207e2f76b657303e17cb7a572ffe2f5a8a48aa04d"},
{file = "wrapt-1.15.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd525e0e52a5ff16653a3fc9e3dd827981917d34996600bbc34c05d048ca35cc"},
{file = "wrapt-1.15.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d37ac69edc5614b90516807de32d08cb8e7b12260a285ee330955604ed9dd29"},
{file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:078e2a1a86544e644a68422f881c48b84fef6d18f8c7a957ffd3f2e0a74a0d4a"},
{file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:2cf56d0e237280baed46f0b5316661da892565ff58309d4d2ed7dba763d984b8"},
{file = "wrapt-1.15.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7dc0713bf81287a00516ef43137273b23ee414fe41a3c14be10dd95ed98a2df9"},
{file = "wrapt-1.15.0-cp39-cp39-win32.whl", hash = "sha256:46ed616d5fb42f98630ed70c3529541408166c22cdfd4540b88d5f21006b0eff"},
{file = "wrapt-1.15.0-cp39-cp39-win_amd64.whl", hash = "sha256:eef4d64c650f33347c1f9266fa5ae001440b232ad9b98f1f43dfe7a79435c0a6"},
{file = "wrapt-1.15.0-py3-none-any.whl", hash = "sha256:64b1df0f83706b4ef4cfb4fb0e4c2669100fd7ecacfb59e091fad300d4e04640"},
{file = "wrapt-1.15.0.tar.gz", hash = "sha256:d06730c6aed78cee4126234cf2d071e01b44b915e725a6cb439a879ec9754a3a"},
] ]

View File

@ -1,11 +1,11 @@
[tool.poetry] [tool.poetry]
name = "text-generation" name = "text-generation-server"
version = "0.2.1" version = "0.4.0"
description = "Text Generation Inference Python gRPC Server" description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]
[tool.poetry.scripts] [tool.poetry.scripts]
text-generation-server = 'text_generation.cli:app' text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.9" python = "^3.9"
@ -22,6 +22,7 @@ loguru = "^0.6.0"
opentelemetry-api = "^1.15.0" opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0" opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2"
[tool.poetry.extras] [tool.poetry.extras]
bnb = ["bitsandbytes"] bnb = ["bitsandbytes"]
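The new hf-transfer dependency only takes effect when huggingface_hub is told to use it. A minimal sketch of opting in, assuming hf_transfer is installed and using bigscience/bloom-560m purely as an example:

import os

# The flag must be set before huggingface_hub reads its configuration constants.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import hf_hub_download

# With the flag set, downloads go through the Rust-based hf_transfer backend.
path = hf_hub_download(repo_id="bigscience/bloom-560m", filename="model.safetensors")
print(path)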

View File

@ -1,6 +1,6 @@
import pytest import pytest
from text_generation.pb import generate_pb2 from text_generation_server.pb import generate_pb2
@pytest.fixture @pytest.fixture
@ -10,6 +10,7 @@ def default_pb_parameters():
repetition_penalty=1.0, repetition_penalty=1.0,
top_k=0, top_k=0,
top_p=1.0, top_p=1.0,
typical_p=1.0,
do_sample=False, do_sample=False,
) )

View File

@ -4,9 +4,9 @@ import torch
from copy import copy from copy import copy
from transformers import AutoTokenizer from transformers import AutoTokenizer
from text_generation.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation.models.bloom import BloomCausalLMBatch, BLOOM from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -24,7 +24,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
id=0, id=0,
inputs="Test", inputs="Test",
input_length=1,
parameters=default_pb_parameters, parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters, stopping_parameters=default_pb_stop_parameters,
) )
@ -65,8 +64,8 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
assert batch.input_ids[0][-1] == 10264 assert batch.input_ids[0][-1] == 10264
assert torch.all(batch.input_ids[0][:-1] == 3) assert torch.all(batch.input_ids[0][:-1] == 3)
assert batch.attention_mask[0][-1] == 1 assert batch.attention_mask[0][0] == 1
assert torch.all(batch.attention_mask[0][:-1] == 0) assert torch.all(batch.attention_mask[0][1:] == 0)
assert batch.past_key_values is None assert batch.past_key_values is None
@ -77,7 +76,7 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
assert batch.size == default_pb_batch.size assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_sequence_length == batch.input_lengths[0] assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_bloom_batch): def test_batch_concatenate_no_prefill(default_bloom_batch):
@ -98,22 +97,19 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
assert not next_batch.keys_head_dim_last assert not next_batch.keys_head_dim_last
assert len(next_batch.all_input_ids) == next_batch.size assert len(next_batch.all_input_ids) == next_batch.size
assert ( assert len(next_batch.all_input_ids[0]) == sequence_length + 1
len(next_batch.all_input_ids[0]) assert len(next_batch.attention_mask[0]) == 11
== len(next_batch.attention_mask[0])
== sequence_length + 1
)
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264) assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
assert torch.all(next_batch.all_input_ids[0][:-2] == 3) assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
assert torch.all(next_batch.attention_mask[0][-2:] == 1) assert torch.all(next_batch.attention_mask[0][:2] == 1)
assert torch.all(next_batch.attention_mask[0][:-2] == 0) assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1) assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids[0, 0] == 10264 assert next_batch.input_ids[0, 0] == 10264
assert next_batch.input_lengths == [2] assert next_batch.input_lengths == [2]
assert next_batch.max_sequence_length == next_batch.input_lengths[0] assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None assert next_batch.past_key_values is not None
assert all( assert all(
@ -213,15 +209,19 @@ def test_batch_concatenate(
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1]) assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(next_batch.attention_mask[0] == 1) assert torch.all(
assert torch.all(next_batch.attention_mask[1:, -2:] == 1) next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
assert torch.all(next_batch.attention_mask[1:, :-2] == 0) )
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0 assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids == 10264) assert torch.all(next_batch.input_ids == 10264)
assert next_batch.input_lengths == [3, 2, 2] assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_sequence_length == 3 assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0] assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests assert next_batch.requests[1:] == next_batch_1.requests
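The new assertions follow from the mask now being allocated for the whole request lifetime rather than the current sequence length. A rough arithmetic sketch, assuming the stop-parameter fixture uses max_new_tokens=10 (not shown in this hunk):

# "Test" tokenizes to a single id for the BLOOM tokenizer, so:
max_input_length = 1
max_new_tokens = 10                      # assumed fixture value
padding_right_offset = max_new_tokens

mask_width = max_input_length + padding_right_offset
assert mask_width == 11                  # matches len(next_batch.attention_mask[0]) == 11

# After one generate_token step only columns [0, 1] are 1; the reserved
# right padding stays 0 until more tokens are generated.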

View File

@ -4,8 +4,8 @@ import torch
from copy import copy from copy import copy
from transformers import AutoTokenizer from transformers import AutoTokenizer
from text_generation.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation.models.causal_lm import CausalLM, CausalLMBatch from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -25,7 +25,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
id=0, id=0,
inputs="Test", inputs="Test",
input_length=1,
parameters=default_pb_parameters, parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters, stopping_parameters=default_pb_stop_parameters,
) )
@ -62,8 +61,8 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
assert batch.input_ids[0][-1] == 14402 assert batch.input_ids[0][-1] == 14402
assert torch.all(batch.input_ids[0][:-1] == 50256) assert torch.all(batch.input_ids[0][:-1] == 50256)
assert batch.attention_mask[0][-1] == 1 assert batch.attention_mask[0, 0] == 1
assert torch.all(batch.attention_mask[0][:-1] == 0) assert torch.all(batch.attention_mask[0, 1:] == 0)
assert batch.past_key_values is None assert batch.past_key_values is None
@ -74,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
assert batch.size == default_pb_batch.size assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_sequence_length == batch.input_lengths[0] assert batch.max_input_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_causal_lm_batch): def test_batch_concatenate_no_prefill(default_causal_lm_batch):
@ -94,23 +93,20 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
assert isinstance(next_batch, CausalLMBatch) assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == next_batch.size assert len(next_batch.all_input_ids) == next_batch.size
assert ( assert len(next_batch.all_input_ids[0]) == sequence_length + 1
len(next_batch.all_input_ids[0]) assert len(next_batch.attention_mask[0]) == 11
== len(next_batch.attention_mask[0])
== sequence_length + 1
)
assert next_batch.all_input_ids[0][-1] == 13 assert next_batch.all_input_ids[0][-1] == 13
assert next_batch.all_input_ids[0][-2] == 14402 assert next_batch.all_input_ids[0][-2] == 14402
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256) assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
assert torch.all(next_batch.attention_mask[0][-2:] == 1) assert torch.all(next_batch.attention_mask[0][0:2] == 1)
assert torch.all(next_batch.attention_mask[0][:-2] == 0) assert torch.all(next_batch.attention_mask[0][2:] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1) assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids[0, 0] == 13 assert next_batch.input_ids[0, 0] == 13
assert next_batch.input_lengths == [2] assert next_batch.input_lengths == [2]
assert next_batch.max_sequence_length == next_batch.input_lengths[0] assert next_batch.max_input_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None assert next_batch.past_key_values is not None
assert all( assert all(
@ -210,16 +206,20 @@ def test_batch_concatenate(
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1]) assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(next_batch.attention_mask[0] == 1) assert torch.all(
assert torch.all(next_batch.attention_mask[1:, -2:] == 1) next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
assert torch.all(next_batch.attention_mask[1:, :-2] == 0) )
assert torch.all(
next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
)
assert torch.all(next_batch.attention_mask[1:, 3:] == 0)
assert next_batch.batch_id == 0 assert next_batch.batch_id == 0
assert next_batch.input_ids[0, 0] == 12355 assert next_batch.input_ids[0, 0] == 12355
assert torch.all(next_batch.input_ids[1:] == 13) assert torch.all(next_batch.input_ids[1:] == 13)
assert next_batch.input_lengths == [3, 2, 2] assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_sequence_length == 3 assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0] assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests assert next_batch.requests[1:] == next_batch_1.requests

View File

@ -1,8 +1,8 @@
import pytest import pytest
from text_generation.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation.models.santacoder import SantaCoder from text_generation_server.models.santacoder import SantaCoder
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -15,7 +15,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
id=0, id=0,
inputs="def", inputs="def",
input_length=1,
parameters=default_pb_parameters, parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters, stopping_parameters=default_pb_stop_parameters,
) )
@ -31,7 +30,6 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
id=0, id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>", inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
input_length=5,
parameters=default_pb_parameters, parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters, stopping_parameters=default_pb_stop_parameters,
) )
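The second fixture exercises SantaCoder's fill-in-the-middle prompting, where the prefix and suffix are wrapped in sentinel tokens and the model generates the middle. A small illustrative helper (build_fim_prompt is hypothetical; only the sentinel strings come from the fixture above):

def build_fim_prompt(prefix: str, suffix: str) -> str:
    # Hypothetical helper mirroring the fixture's prompt layout.
    return f"<fim-prefix>{prefix}<fim-suffix>{suffix}<fim-middle>"

assert build_fim_prompt("def", "world") == "<fim-prefix>def<fim-suffix>world<fim-middle>"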

View File

@ -5,8 +5,8 @@ from copy import copy
from transformers import AutoTokenizer from transformers import AutoTokenizer
from text_generation.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -28,7 +28,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
id=0, id=0,
inputs="Test", inputs="Test",
input_length=2,
parameters=default_pb_parameters, parameters=default_pb_parameters,
stopping_parameters=default_pb_stop_parameters, stopping_parameters=default_pb_stop_parameters,
) )
@ -106,7 +105,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
assert len(generations) == len(next_batch) assert len(generations) == len(next_batch)
assert isinstance(next_batch, Seq2SeqLMBatch) assert isinstance(next_batch, Seq2SeqLMBatch)
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids) assert next_batch.input_ids is None
assert torch.equal( assert torch.equal(
next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
) )
@ -148,7 +147,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
assert all([generation.generated_text is None for generation in generations]) assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 259 for generation in generations]) assert all([generation.token_id.item() == 259 for generation in generations])
assert all([generation.token_text == "" for generation in generations]) assert all([generation.token_text == " " for generation in generations])
assert generations[0].request_id == 0 assert generations[0].request_id == 0
@ -220,11 +219,6 @@ def test_batch_concatenate(
assert next_batch.batch_id == 0 assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids[:, 0] == 4268)
assert torch.all(next_batch.input_ids[:, 1] == 1)
assert torch.all(next_batch.attention_mask == 1)
assert torch.equal( assert torch.equal(
next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0] next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
) )
@ -233,9 +227,10 @@ def test_batch_concatenate(
next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
) )
assert torch.all(next_batch.decoder_attention_mask[0] == 1) assert torch.all(next_batch.decoder_attention_mask[0, :3] == 1)
assert torch.all(next_batch.decoder_attention_mask[0, 3:] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0) assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1) assert torch.all(next_batch.decoder_attention_mask[1:, 1:3] == 1)
assert torch.equal( assert torch.equal(
next_batch.encoder_last_hidden_state[0], next_batch.encoder_last_hidden_state[0],

View File

@ -0,0 +1,21 @@
from text_generation_server.utils.hub import (
download_weights,
weight_hub_files,
weight_files,
)
from text_generation_server.utils.convert import convert_files
def test_convert_files():
model_id = "bigscience/bloom-560m"
pt_filenames = weight_hub_files(model_id, extension=".bin")
local_pt_files = download_weights(pt_filenames, model_id)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
]
convert_files(local_pt_files, local_st_files)
found_st_files = weight_files(model_id)
assert all([p in found_st_files for p in local_st_files])

View File

@ -0,0 +1,40 @@
import pytest
from text_generation_server.utils.hub import (
weight_hub_files,
download_weights,
weight_files,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
with pytest.raises(EntryNotFoundError):
weight_hub_files("bigscience/bloom", extension=".errors")
def test_download_weights():
model_id = "bigscience/bloom-560m"
filenames = weight_hub_files(model_id)
files = download_weights(filenames, model_id)
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -1,14 +1,6 @@
import pytest from text_generation_server.utils.tokens import (
from huggingface_hub.utils import RevisionNotFoundError
from text_generation.utils import (
weight_hub_files,
download_weights,
weight_files,
StopSequenceCriteria, StopSequenceCriteria,
StoppingCriteria, StoppingCriteria,
LocalEntryNotFoundError,
FinishReason, FinishReason,
) )
@ -41,31 +33,3 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH) assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
filenames = weight_hub_files("bigscience/bloom", extension=".errors")
assert filenames == []
def test_download_weights():
files = download_weights("bigscience/bloom-560m")
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -1,68 +0,0 @@
import os
import sys
import typer
from pathlib import Path
from loguru import logger
from typing import Optional
from text_generation import server, utils
from text_generation.tracing import setup_tracing
app = typer.Typer()
@app.command()
def serve(
model_id: str,
revision: Optional[str] = None,
sharded: bool = False,
quantize: bool = False,
uds_path: Path = "/tmp/text-generation",
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
):
if sharded:
assert (
os.getenv("RANK", None) is not None
), "RANK must be set when sharded is True"
assert (
os.getenv("WORLD_SIZE", None) is not None
), "WORLD_SIZE must be set when sharded is True"
assert (
os.getenv("MASTER_ADDR", None) is not None
), "MASTER_ADDR must be set when sharded is True"
assert (
os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True"
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
# Setup OpenTelemetry distributed tracing
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
server.serve(model_id, revision, sharded, quantize, uds_path)
@app.command()
def download_weights(
model_id: str,
revision: Optional[str] = None,
extension: str = ".safetensors",
):
utils.download_weights(model_id, revision, extension)
if __name__ == "__main__":
app()

View File

@ -1,24 +0,0 @@
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase
from text_generation.models.types import Batch, GeneratedText
B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
self.tokenizer = tokenizer
self.device = device
@property
@abstractmethod
def batch_type(self) -> Type[B]:
raise NotImplementedError
@abstractmethod
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError

View File

@ -1,283 +0,0 @@
import concurrent
import os
import re
import torch
import torch.distributed
from datetime import timedelta
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopPLogitsWarper,
TopKLogitsWarper,
)
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens
class Greedy:
def __call__(self, logits):
return logits.argmax()
class NextTokenChooser:
def __init__(
self,
temperature=1.0,
repetition_penalty=1.0,
top_k=None,
top_p=None,
do_sample=False,
seed=0,
device="cpu",
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
sampling = do_sample
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
warpers.append(TemperatureLogitsWarper(temperature))
sampling = True
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k))
sampling = True
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p))
sampling = True
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
self.warpers = warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
next_id = self.choice(scores[-1])
return next_id.view(1, 1), logprobs
@classmethod
def from_pb(
cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
) -> "NextTokenChooser":
return NextTokenChooser(
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
top_k=pb.top_k,
top_p=pb.top_p,
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
)
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
self.regex = re.compile(f".*{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
)
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size
def weight_hub_files(model_id, revision=None, extension=".safetensors"):
"""Get the safetensors filenames on the hub"""
api = HfApi()
info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
return filenames
def try_to_load_from_cache(model_id, revision, filename):
"""Try to load a file from the Hugging Face cache"""
if revision is None:
revision = "main"
object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir():
# No cache for this model
return None
refs_dir = repo_cache / "refs"
snapshots_dir = repo_cache / "snapshots"
no_exist_dir = repo_cache / ".no_exist"
# Resolve refs (for instance to convert main to the associated commit sha)
if refs_dir.is_dir():
revision_file = refs_dir / revision
if revision_file.exists():
with revision_file.open() as f:
revision = f.read()
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return _CACHED_NO_EXIST
# Check if revision folder exists
if not snapshots_dir.exists():
return None
cached_shas = os.listdir(snapshots_dir)
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
# Check if file exists in cache
cached_file = snapshots_dir / revision / filename
return str(cached_file) if cached_file.is_file() else None
def weight_files(model_id, revision=None, extension=".safetensors"):
"""Get the local safetensors filenames"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_id, revision, extension)
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(
model_id, revision=revision, filename=filename
)
if cache_file is None:
raise LocalEntryNotFoundError(
f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `text-generation-server download-weights {model_id}` first."
)
files.append(cache_file)
return files
def download_weights(model_id, revision=None, extension=".safetensors"):
"""Download the safetensors files from the hub"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_id, revision, extension)
download_function = partial(
hf_hub_download,
repo_id=model_id,
local_files_only=False,
)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_function, filename=filename, revision=revision)
for filename in filenames
]
files = [
future.result()
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
]
return files
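Taken together, the classes above drive per-request decoding: NextTokenChooser warps the logits and picks the next id, while StoppingCriteria decides when the request is finished. A minimal sketch of how they compose, reusing the definitions from this module with toy tensors:

import torch

chooser = NextTokenChooser()                      # defaults resolve to greedy decoding
stopping = StoppingCriteria(
    eos_token_id=2,
    stop_sequence_criterias=[StopSequenceCriteria("\n\n")],
    max_new_tokens=3,
)

input_ids = torch.tensor([[5, 6, 7]])             # [batch=1, seq_len=3]
scores = torch.randn(3, 10)                       # [seq_len, vocab] toy logits

next_id, logprobs = chooser(input_ids, scores)    # next_id has shape [1, 1]
stop, reason = stopping(next_id.item(), " token")
# stop flips to True on EOS, a matched stop sequence, or after max_new_tokens steps.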

View File

@ -1,6 +1,6 @@
from typing import Dict, Optional, TypeVar from typing import Dict, Optional, TypeVar
from text_generation.models.types import Batch from text_generation_server.models.types import Batch
B = TypeVar("B", bound=Batch) B = TypeVar("B", bound=Batch)

View File

@ -0,0 +1,115 @@
import os
import sys
import typer
from pathlib import Path
from loguru import logger
from typing import Optional
from text_generation_server import server, utils
from text_generation_server.tracing import setup_tracing
app = typer.Typer()
@app.command()
def serve(
model_id: str,
revision: Optional[str] = None,
sharded: bool = False,
quantize: bool = False,
uds_path: Path = "/tmp/text-generation",
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
):
if sharded:
assert (
os.getenv("RANK", None) is not None
), "RANK must be set when sharded is True"
assert (
os.getenv("WORLD_SIZE", None) is not None
), "WORLD_SIZE must be set when sharded is True"
assert (
os.getenv("MASTER_ADDR", None) is not None
), "MASTER_ADDR must be set when sharded is True"
assert (
os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True"
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
# Setup OpenTelemetry distributed tracing
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
server.serve(model_id, revision, sharded, quantize, uds_path)
@app.command()
def download_weights(
model_id: str,
revision: Optional[str] = None,
extension: str = ".safetensors",
logger_level: str = "INFO",
json_output: bool = False,
):
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
# Test if files were already downloaded
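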
try:
utils.weight_files(model_id, revision, extension)
logger.info(
"Files are already present in the local cache. " "Skipping download."
)
return
# Local files not found
except utils.LocalEntryNotFoundError:
pass
# Download weights directly
try:
filenames = utils.weight_hub_files(model_id, revision, extension)
utils.download_weights(filenames, model_id, revision)
except utils.EntryNotFoundError as e:
if not extension == ".safetensors":
raise e
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights instead."
)
# Try to see if there are pytorch weights
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
# Download pytorch weights
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
for p in local_pt_files
]
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files)
if __name__ == "__main__":
app()
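The download-weights command above first checks the local cache, then tries .safetensors files from the Hub, and only falls back to downloading .bin files and converting them. A minimal sketch of invoking it programmatically through typer's test runner (the model id is just an example):

from typer.testing import CliRunner

from text_generation_server.cli import app

runner = CliRunner()
# Equivalent to: text-generation-server download-weights bigscience/bloom-560m
result = runner.invoke(app, ["download-weights", "bigscience/bloom-560m"])
print(result.exit_code)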

View File

@ -3,14 +3,14 @@ import torch
from transformers import AutoConfig from transformers import AutoConfig
from typing import Optional from typing import Optional
from text_generation.models.model import Model from text_generation_server.models.model import Model
from text_generation.models.causal_lm import CausalLM from text_generation_server.models.causal_lm import CausalLM
from text_generation.models.bloom import BLOOM, BLOOMSharded from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation.models.galactica import Galactica, GalacticaSharded from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation.models.santacoder import SantaCoder from text_generation_server.models.santacoder import SantaCoder
from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation.models.t5 import T5Sharded from text_generation_server.models.t5 import T5Sharded
__all__ = [ __all__ = [
"Model", "Model",
@ -19,7 +19,6 @@ __all__ = [
"CausalLM", "CausalLM",
"Galactica", "Galactica",
"GalacticaSharded", "GalacticaSharded",
"GPTNeox",
"GPTNeoxSharded", "GPTNeoxSharded",
"Seq2SeqLM", "Seq2SeqLM",
"SantaCoder", "SantaCoder",
@ -41,6 +40,15 @@ torch.set_grad_enabled(False)
def get_model( def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: bool model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model: ) -> Model:
if "facebook/galactica" in model_id:
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
else:
return Galactica(model_id, revision, quantize=quantize)
if "santacoder" in model_id:
return SantaCoder(model_id, revision, quantize)
config = AutoConfig.from_pretrained(model_id, revision=revision) config = AutoConfig.from_pretrained(model_id, revision=revision)
if config.model_type == "bloom": if config.model_type == "bloom":
@ -48,24 +56,19 @@ def get_model(
return BLOOMSharded(model_id, revision, quantize=quantize) return BLOOMSharded(model_id, revision, quantize=quantize)
else: else:
return BLOOM(model_id, revision, quantize=quantize) return BLOOM(model_id, revision, quantize=quantize)
elif config.model_type == "gpt_neox":
if config.model_type == "gpt_neox":
if sharded: if sharded:
return GPTNeoxSharded(model_id, revision, quantize=quantize) return GPTNeoxSharded(model_id, revision, quantize=quantize)
else: else:
return GPTNeox(model_id, revision, quantize=quantize) return CausalLM(model_id, revision, quantize=quantize)
elif config.model_type == "t5":
if config.model_type == "t5":
if sharded: if sharded:
return T5Sharded(model_id, revision, quantize=quantize) return T5Sharded(model_id, revision, quantize=quantize)
else: else:
return Seq2SeqLM(model_id, revision, quantize=quantize) return Seq2SeqLM(model_id, revision, quantize=quantize)
elif model_id.startswith("facebook/galactica"):
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
else:
return Galactica(model_id, revision, quantize=quantize)
elif "santacoder" in model_id:
return SantaCoder(model_id, revision, quantize)
else:
if sharded: if sharded:
raise ValueError("sharded is not supported for AutoModel") raise ValueError("sharded is not supported for AutoModel")
try: try:
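With the reordered dispatch, model ids are matched by name first (galactica, santacoder) and only then by config.model_type. A small sketch of the resulting behaviour, assuming the weights are already cached locally so loading succeeds:

from text_generation_server.models import get_model
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.santacoder import SantaCoder

# Dispatched on config.model_type == "bloom".
model = get_model("bigscience/bloom-560m", revision=None, sharded=False, quantize=False)
assert isinstance(model, BLOOM)

# Dispatched on the model id alone, before the config is even inspected.
model = get_model("bigcode/santacoder", revision=None, sharded=False, quantize=False)
assert isinstance(model, SantaCoder)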

View File

@ -17,13 +17,12 @@ from transformers.models.bloom.parallel_layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation.models import CausalLM from text_generation_server.models import CausalLM
from text_generation.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
download_weights,
) )
HAS_BITS_AND_BYTES = True HAS_BITS_AND_BYTES = True
@ -59,9 +58,6 @@ class BLOOMSharded(BLOOM):
def __init__( def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
if not model_id.startswith("bigscience/bloom"):
raise ValueError(f"Model {model_id} is not supported")
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = self.rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -80,14 +76,8 @@ class BLOOMSharded(BLOOM):
) )
config.pad_token_id = 3 config.pad_token_id = 3
# Only download weights for small models
if self.master and model_id == "bigscience/bloom-560m":
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights(): with init_empty_weights():
model = AutoModelForCausalLM.from_config(config) model = AutoModelForCausalLM.from_config(config)

View File

@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type from typing import Optional, Tuple, List, Type
from text_generation.models import Model from text_generation_server.models import Model
from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText from text_generation_server.models.types import (
from text_generation.pb import generate_pb2 Batch,
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling PrefillTokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -36,7 +41,8 @@ class CausalLMBatch(Batch):
# Metadata used for padding # Metadata used for padding
size: int size: int
max_sequence_length: int max_input_length: int
padding_right_offset: int
# Past metadata # Past metadata
keys_head_dim_last: bool = True keys_head_dim_last: bool = True
@ -61,22 +67,36 @@ class CausalLMBatch(Batch):
input_lengths = [] input_lengths = []
# Parse batch # Parse batch
padding_right_offset = 0
for r in pb.requests: for r in pb.requests:
inputs.append(r.inputs) inputs.append(r.inputs)
input_lengths.append(r.input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append( stopping_criteria = StoppingCriteria.from_pb(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
) )
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
inputs, inputs,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False, return_token_type_ids=False,
).to(device) ).to(device)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
input_ids = tokenized_inputs["input_ids"]
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_input_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
@ -84,24 +104,30 @@ class CausalLMBatch(Batch):
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
requests=pb.requests, requests=pb.requests,
input_ids=tokenized_inputs["input_ids"], input_ids=input_ids,
attention_mask=tokenized_inputs["attention_mask"], attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_values=None, past_key_values=None,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
input_lengths=input_lengths, input_lengths=input_lengths.tolist(),
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=pb.size, size=pb.size,
max_sequence_length=max(input_lengths), max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
) )
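The from_pb change above pre-allocates the attention mask for the whole request lifetime: max_input_length columns for the prompt plus padding_right_offset columns reserved for tokens still to be generated. A toy shape calculation with made-up numbers:

# Two requests with different prompt lengths and stopping criteria:
input_lengths = [3, 5]                       # tokens per request after tokenization
max_new_tokens = [20, 12]                    # per-request stopping criteria

max_input_length = max(input_lengths)        # 5
padding_right_offset = max(max_new_tokens)   # 20

# Allocated once, with room for the whole generation:
mask_shape = (len(input_lengths), max_input_length + padding_right_offset)
assert mask_shape == (2, 25)
# Columns [:max_input_length] are copied from the tokenizer output;
# the remaining columns stay 0 until tokens are generated.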
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
# Used for padding # Used for padding
total_batch_size = sum(batch.size for batch in batches) total_batch_size = 0
max_sequence_length = max(batch.max_sequence_length for batch in batches) max_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += batch.size
max_input_length = max(max_input_length, batch.max_input_length)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes # Batch attributes
requests = [] requests = []
@ -144,13 +170,24 @@ class CausalLMBatch(Batch):
# Create padded tensor # Create padded tensor
if attention_mask is None: if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros( attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_sequence_length), (total_batch_size, max_input_length + padding_right_offset),
) )
# We need to slice the attention mask to remove padding from previous steps # We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
batch_left_offset = (
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
)
attention_mask[ attention_mask[
start_index:end_index, -batch.max_sequence_length : start_index:end_index,
] = batch.attention_mask[:, -batch.max_sequence_length :] left_offset:-padding_right_offset,
] = batch.attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
# Create empty tensor # Create empty tensor
# position_ids is always of shape [batch_size, 1] # position_ids is always of shape [batch_size, 1]
@ -172,7 +209,7 @@ class CausalLMBatch(Batch):
padded_past_values_shape = ( padded_past_values_shape = (
total_batch_size, total_batch_size,
num_heads, num_heads,
max_sequence_length - 1, max_input_length - 1,
head_dim, head_dim,
) )
@ -184,7 +221,7 @@ class CausalLMBatch(Batch):
total_batch_size, total_batch_size,
num_heads, num_heads,
head_dim, head_dim,
max_sequence_length - 1, max_input_length - 1,
) )
# This will run only once per layer # This will run only once per layer
@ -198,20 +235,20 @@ class CausalLMBatch(Batch):
past_key_values[j][0][ past_key_values[j][0][
start_index:end_index, start_index:end_index,
:, :,
-(batch.max_sequence_length - 1) :, -(batch.max_input_length - 1) :,
:, :,
] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :] ] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
else: else:
past_key_values[j][0][ past_key_values[j][0][
start_index:end_index, start_index:end_index,
:, :,
:, :,
-(batch.max_sequence_length - 1) :, -(batch.max_input_length - 1) :,
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :] ] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
past_key_values[j][1][ past_key_values[j][1][
start_index:end_index, :, -(batch.max_sequence_length - 1) :, : start_index:end_index, :, -(batch.max_input_length - 1) :, :
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :] ] = past_values[:, :, -(batch.max_input_length - 1) :, :]
start_index += batch.size start_index += batch.size
@ -227,7 +264,8 @@ class CausalLMBatch(Batch):
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=total_batch_size, size=total_batch_size,
max_sequence_length=max_sequence_length, max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last, keys_head_dim_last=batches[0].keys_head_dim_last,
) )
@ -294,9 +332,12 @@ class CausalLM(Model):
def generate_token( def generate_token(
self, batch: CausalLMBatch self, batch: CausalLMBatch
) -> Tuple[List[Generation], Optional[CausalLMBatch]]: ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
logits, past = self.forward( logits, past = self.forward(
batch.input_ids, batch.input_ids,
batch.attention_mask, attention_mask,
batch.position_ids, batch.position_ids,
batch.past_key_values, batch.past_key_values,
) )
@ -311,7 +352,7 @@ class CausalLM(Model):
# Metadata # Metadata
next_batch_size = 0 next_batch_size = 0
next_batch_max_sequence_length = 0 next_batch_max_input_length = 0
# Results # Results
generations: List[Generation] = [] generations: List[Generation] = []
@ -347,10 +388,8 @@ class CausalLM(Model):
# Generated token # Generated token
next_token_logprob = logprobs[-1, next_token_id] next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze() next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode( next_token_text = self.decode_token(
next_token_id_squeezed, next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
) )
# Evaluate stopping criteria # Evaluate stopping criteria
@ -381,8 +420,8 @@ class CausalLM(Model):
next_batch_all_input_ids.append(all_input_ids) next_batch_all_input_ids.append(all_input_ids)
next_batch_size += 1 next_batch_size += 1
next_batch_input_lengths.append(new_input_length) next_batch_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max( next_batch_max_input_length = max(
next_batch_max_sequence_length, new_input_length next_batch_max_input_length, new_input_length
) )
# Prefill # Prefill
@ -409,6 +448,7 @@ class CausalLM(Model):
next_token_id_squeezed, next_token_id_squeezed,
next_token_logprob, next_token_logprob,
next_token_text, next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text, generated_text,
) )
@ -448,14 +488,8 @@ class CausalLM(Model):
next_batch_next_token_choosers = batch.next_token_choosers next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids # Update attention_mask as we added a new token to input_ids
next_batch_attention_mask = torch.cat( next_batch_attention_mask[:, -batch.padding_right_offset] = 1
[
next_batch_attention_mask,
next_batch_attention_mask.new_ones(next_batch_size, 1),
],
dim=1,
)
# Update position_ids # Update position_ids
next_batch_position_ids = next_batch_position_ids[:, -1:] + 1 next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
@ -472,7 +506,8 @@ class CausalLM(Model):
next_token_choosers=next_batch_next_token_choosers, next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias, stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size, size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length, max_input_length=next_batch_max_input_length,
padding_right_offset=batch.padding_right_offset - 1,
keys_head_dim_last=batch.keys_head_dim_last, keys_head_dim_last=batch.keys_head_dim_last,
) )
return generations, next_batch return generations, next_batch
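The net effect of the changes above: instead of growing the attention mask with torch.cat at every decode step, the batch pre-allocates max_input_length + padding_right_offset columns (padding_right_offset being the largest max_new_tokens in the batch), slices off the unused right padding before each forward pass, and flips a single zero to one per generated token. A minimal sketch with toy values (plain PyTorch, not the real CausalLMBatch):

import torch

# Two requests: prompt lengths 3 and 5, max_new_tokens 4 and 2 (toy values).
input_lengths = torch.tensor([3, 5])
max_input_length = int(input_lengths.max())        # 5
padding_right_offset = 4                           # max(max_new_tokens) across the batch

# Allocate once: left padding for the prompts, right padding for the tokens to come.
attention_mask = torch.zeros((2, max_input_length + padding_right_offset), dtype=torch.long)
attention_mask[0, max_input_length - 3 : max_input_length] = 1   # request 0, left padded
attention_mask[1, :max_input_length] = 1                         # request 1 fills the prefix

for _ in range(2):                                                # two decode steps
    current_mask = attention_mask[:, :-padding_right_offset]     # what the model actually sees
    # ... forward pass with current_mask would happen here ...
    attention_mask[:, -padding_right_offset] = 1                  # mark the freshly generated token
    padding_right_offset -= 1                                     # mirrors padding_right_offset - 1 in the next batch

print(attention_mask)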

View File

@ -2,7 +2,7 @@ import re
import torch import torch
import torch.distributed import torch.distributed
from typing import List, Optional, Type from typing import List, Optional, Type, Tuple
from accelerate import init_empty_weights from accelerate import init_empty_weights
from safetensors import safe_open from safetensors import safe_open
@ -18,15 +18,14 @@ from transformers.models.opt.parallel_layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation.models import CausalLM from text_generation_server.models import CausalLM
from text_generation.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation.utils import ( from text_generation_server.utils import (
NextTokenChooser, NextTokenChooser,
StoppingCriteria, StoppingCriteria,
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
download_weights,
) )
HAS_BITS_AND_BYTES = True HAS_BITS_AND_BYTES = True
@ -97,24 +96,37 @@ class GalacticaCausalLMBatch(CausalLMBatch):
input_lengths = [] input_lengths = []
# Parse batch # Parse batch
max_sequence_length = 0
padding_right_offset = 0
for r in pb.requests: for r in pb.requests:
# Add escape_custom_split_sequence to the CausalLMBatch logic # Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs)) inputs.append(escape_custom_split_sequence(r.inputs))
input_lengths.append(r.input_length) input_lengths.append(r.input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append( stopping_criteria = StoppingCriteria.from_pb(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
max_sequence_length = max(max_sequence_length, r.input_length)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
) )
# Tokenize batch # Tokenize batch
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
inputs, inputs,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False, return_token_type_ids=False,
).to(device) ).to(device)
input_ids = tokenized_inputs["input_ids"]
# Allocate maximum attention_mask
attention_mask = input_ids.new_zeros(
(pb.size, max_sequence_length + padding_right_offset)
)
# Copy tokenizer attention_mask into fully allocated attention_mask
attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1) all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
@ -122,8 +134,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
requests=pb.requests, requests=pb.requests,
input_ids=tokenized_inputs["input_ids"], input_ids=input_ids,
attention_mask=tokenized_inputs["attention_mask"], attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_values=None, past_key_values=None,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
@ -131,7 +143,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=pb.size, size=pb.size,
max_sequence_length=max(input_lengths), max_sequence_length=max_sequence_length,
padding_right_offset=padding_right_offset,
) )
@ -146,14 +159,25 @@ class Galactica(CausalLM):
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
) )
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
class GalacticaSharded(Galactica): class GalacticaSharded(Galactica):
def __init__( def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
if not model_id.startswith("facebook/galactica"):
raise ValueError(f"Model {model_id} is not supported")
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = self.rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -172,14 +196,8 @@ class GalacticaSharded(Galactica):
) )
tokenizer.pad_token_id = config.pad_token_id tokenizer.pad_token_id = config.pad_token_id
# Only download weights for small models
if self.master and model_id == "facebook/galactica-125m":
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights(): with init_empty_weights():
model = AutoModelForCausalLM.from_config(config) model = AutoModelForCausalLM.from_config(config)
@ -329,7 +347,6 @@ class GalacticaSharded(Galactica):
outputs = self.model.forward( outputs = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True, use_cache=True,
) )
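The pad_to_multiple_of=8 passed to the tokenizer above (on CUDA only) rounds the padded batch width up to a multiple of 8, which tends to be friendlier to fp16 tensor-core kernels. A rough sketch of the resulting widths (plain arithmetic; the real padding is done inside the Hugging Face tokenizer):

from typing import Optional

def padded_length(longest_prompt: int, multiple: Optional[int]) -> int:
    # Width of the tokenized batch after padding to a multiple (None means no rounding).
    if multiple is None:
        return longest_prompt
    return ((longest_prompt + multiple - 1) // multiple) * multiple

for longest_prompt in (5, 8, 13):
    print(longest_prompt, "->", padded_length(longest_prompt, 8))   # 5 -> 8, 8 -> 8, 13 -> 16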

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.distributed import torch.distributed
from typing import List, Optional, Tuple from typing import List, Optional
from accelerate import init_empty_weights from accelerate import init_empty_weights
from safetensors import safe_open from safetensors import safe_open
@ -16,11 +16,10 @@ from transformers.models.gpt_neox.parallel_layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation.models import CausalLM from text_generation_server.models import CausalLM
from text_generation.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
download_weights,
) )
HAS_BITS_AND_BYTES = True HAS_BITS_AND_BYTES = True
@ -31,23 +30,7 @@ except Exception as e:
HAS_BITS_AND_BYTES = False HAS_BITS_AND_BYTES = False
class GPTNeox(CausalLM): class GPTNeoxSharded(CausalLM):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
class GPTNeoxSharded(GPTNeox):
def __init__( def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
@ -69,14 +52,8 @@ class GPTNeoxSharded(GPTNeox):
model_id, revision=revision, tp_parallel=True model_id, revision=revision, tp_parallel=True
) )
# Only master download weights
if self.master:
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights(): with init_empty_weights():
model = AutoModelForCausalLM.from_config(config) model = AutoModelForCausalLM.from_config(config)
@ -231,6 +208,7 @@ class GPTNeoxSharded(GPTNeox):
outputs = self.model.forward( outputs = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True, use_cache=True,
) )

View File

@ -0,0 +1,43 @@
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase
from text_generation_server.models.types import Batch, GeneratedText
B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids)
self.device = device
# see `decode_token` method
self.tokenizer.add_special_tokens(
{"additional_special_tokens": ["<decode-token>"]}
)
self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids(
"<decode-token>"
)
self.special_decode_token_length = len("<decode-token>")
@property
@abstractmethod
def batch_type(self) -> Type[B]:
raise NotImplementedError
@abstractmethod
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError
def decode_token(self, token_id: int) -> str:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
# append token to special decode token and decode both
result = self.tokenizer.decode(
[self.special_decode_token_id, token_id], skip_special_tokens=False
)
# slice to remove special decode token
return result[self.special_decode_token_length :]
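decode_token works around tokenizers that fold or drop leading whitespace when a token is decoded on its own: it decodes a sentinel special token together with the new token and slices the sentinel's text back off. A standalone sketch of the same trick (the gpt2 checkpoint is only an example and has to be cached or downloadable):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})
sentinel_id = tokenizer.convert_tokens_to_ids("<decode-token>")
sentinel_len = len("<decode-token>")

def decode_token(token_id: int) -> str:
    # Decode the sentinel together with the new token, then strip the sentinel text,
    # so the incremental piece keeps its leading space if it has one.
    text = tokenizer.decode([sentinel_id, token_id], skip_special_tokens=False)
    return text[sentinel_len:]

token_ids = tokenizer(" Hello world")["input_ids"]
print("".join(decode_token(t) for t in token_ids))   # reassembles " Hello world"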

View File

@ -1,10 +1,10 @@
import torch import torch
import torch.distributed import torch.distributed
from typing import Optional, List, Tuple from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation.models import CausalLM from text_generation_server.models import CausalLM
FIM_PREFIX = "<fim-prefix>" FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>" FIM_MIDDLE = "<fim-middle>"

View File

@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type from typing import Optional, Tuple, List, Type
from text_generation.models import Model from text_generation_server.models import Model
from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens from text_generation_server.models.types import (
from text_generation.pb import generate_pb2 GeneratedText,
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling Batch,
Generation,
PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -42,9 +47,10 @@ class Seq2SeqLMBatch(Batch):
size: int size: int
max_input_length: int max_input_length: int
max_decoder_input_length: int max_decoder_input_length: int
padding_right_offset: int
def to_pb(self) -> generate_pb2.Batch: def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf""" """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
return generate_pb2.Batch( return generate_pb2.Batch(
id=self.batch_id, id=self.batch_id,
requests=self.requests, requests=self.requests,
@ -58,36 +64,41 @@ class Seq2SeqLMBatch(Batch):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
device: torch.device, device: torch.device,
) -> "Seq2SeqLMBatch": ) -> "Seq2SeqLMBatch":
"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch""" """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
inputs = [] inputs = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
input_lengths = []
decoder_input_ids = [] decoder_input_ids = []
decoder_input_lengths = [] decoder_input_lengths = []
# Parse batch # Parse batch
padding_right_offset = 0
for r in pb.requests: for r in pb.requests:
inputs.append(r.inputs) inputs.append(r.inputs)
input_lengths.append(r.input_length)
# Decoder sequence only contains the bos_token # Decoder sequence only contains the bos_token
decoder_input_ids.append(tokenizer.bos_token_id) decoder_input_ids.append(tokenizer.bos_token_id)
decoder_input_lengths.append(1) decoder_input_lengths.append(1)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append( stopping_criteria = StoppingCriteria.from_pb(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
) )
# Tokenize batch # Tokenize batch
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
inputs, inputs,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=False, return_token_type_ids=False,
).to(device) ).to(device)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
# Convert decoder_input_ids to torch tensor of size [batch_size, 1] # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1) decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
@ -100,13 +111,14 @@ class Seq2SeqLMBatch(Batch):
decoder_attention_mask=None, decoder_attention_mask=None,
encoder_last_hidden_state=None, encoder_last_hidden_state=None,
past_key_values=None, past_key_values=None,
input_lengths=input_lengths, input_lengths=input_lengths.tolist(),
decoder_input_lengths=decoder_input_lengths, decoder_input_lengths=decoder_input_lengths,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
size=len(pb.requests), size=len(pb.requests),
max_input_length=max(input_lengths), max_input_length=max_input_length.item(),
max_decoder_input_length=1, max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
) )
@classmethod @classmethod
@ -115,11 +127,17 @@ class Seq2SeqLMBatch(Batch):
"""Concatenate multiple batches together by padding internal torch tensors""" """Concatenate multiple batches together by padding internal torch tensors"""
# Used for padding # Used for padding
total_batch_size = sum(batch.size for batch in batches) total_batch_size = 0
max_input_length = max(batch.max_input_length for batch in batches) max_input_length = 0
max_decoder_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += batch.size
max_input_length = max(max_input_length, batch.max_input_length)
max_decoder_input_length = max( max_decoder_input_length = max(
batch.max_decoder_input_length for batch in batches max_decoder_input_length, batch.max_decoder_input_length
) )
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes # Batch attributes
requests = [] requests = []
@ -129,7 +147,6 @@ class Seq2SeqLMBatch(Batch):
stopping_criterias = [] stopping_criterias = []
# Batch tensors # Batch tensors
input_ids = None
attention_mask = None attention_mask = None
decoder_input_ids = None decoder_input_ids = None
decoder_attention_mask = None decoder_attention_mask = None
@ -155,16 +172,6 @@ class Seq2SeqLMBatch(Batch):
if batch.encoder_last_hidden_state is None: if batch.encoder_last_hidden_state is None:
raise ValueError("Batch encoder_last_hidden_state cannot be None") raise ValueError("Batch encoder_last_hidden_state cannot be None")
# Create padded tensor
if input_ids is None:
input_ids = batch.input_ids.new_zeros(
(total_batch_size, max_input_length),
)
# Copy to correct indices
input_ids[
start_index:end_index, -batch.max_input_length :
] = batch.input_ids[:, -batch.max_input_length :]
# Create padded tensor # Create padded tensor
if attention_mask is None: if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros( attention_mask = batch.attention_mask.new_zeros(
@ -189,19 +196,30 @@ class Seq2SeqLMBatch(Batch):
if decoder_attention_mask is None: if decoder_attention_mask is None:
# As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
decoder_attention_mask = batch.attention_mask.new_zeros( decoder_attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_decoder_input_length), (total_batch_size, max_decoder_input_length + padding_right_offset),
) )
# If the decoder mask does not exist yet, all generations started at the same time and we never concatenated # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
# this batch. All generations are of length `batch.max_decoder_input_length`. # this batch. All generations are of length `batch.max_decoder_input_length`.
left_offset = max_decoder_input_length - batch.max_decoder_input_length
if batch.decoder_attention_mask is None: if batch.decoder_attention_mask is None:
decoder_attention_mask[ decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length : start_index:end_index,
left_offset:-padding_right_offset,
] = 1 ] = 1
# If it exists, we need to index # If it exists, we need to index
else: else:
batch_left_offset = (
batch.decoder_attention_mask.shape[1]
- batch.max_decoder_input_length
- batch.padding_right_offset
)
decoder_attention_mask[ decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length : start_index:end_index,
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :] left_offset:-padding_right_offset,
] = batch.decoder_attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
# Create padded tensor # Create padded tensor
if encoder_last_hidden_state is None: if encoder_last_hidden_state is None:
@ -273,7 +291,7 @@ class Seq2SeqLMBatch(Batch):
return cls( return cls(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
requests=requests, requests=requests,
input_ids=input_ids, input_ids=None,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
@ -286,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
size=total_batch_size, size=total_batch_size,
max_input_length=max_input_length, max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length, max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
) )
def __len__(self): def __len__(self):
@ -326,7 +345,9 @@ class Seq2SeqLM(Model):
return Seq2SeqLMBatch return Seq2SeqLMBatch
def decode(self, decoder_ids: List[int]) -> str: def decode(self, decoder_ids: List[int]) -> str:
return self.tokenizer.decode(decoder_ids, skip_special_tokens=True) return self.tokenizer.decode(
decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward( def forward(
self, self,
@ -342,14 +363,6 @@ class Seq2SeqLM(Model):
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]: ]:
# Model Forward # Model Forward
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally...
if encoder_last_hidden_state is not None:
encoder_last_hidden_state = [encoder_last_hidden_state]
outputs = self.model.forward( outputs = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
@ -369,12 +382,34 @@ class Seq2SeqLM(Model):
def generate_token( def generate_token(
self, batch: Seq2SeqLMBatch self, batch: Seq2SeqLMBatch
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]: ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
if batch.decoder_attention_mask is not None:
# slice to the correct shape
decoder_attention_mask = batch.decoder_attention_mask[
:, : -batch.padding_right_offset
]
else:
decoder_attention_mask = None
# check if first forward or not
if batch.past_key_values is not None:
# Only take the last token
decoder_input_ids = batch.decoder_input_ids[:, -1].unsqueeze(-1)
else:
decoder_input_ids = batch.decoder_input_ids
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally...
if batch.encoder_last_hidden_state is not None:
encoder_last_hidden_state = [batch.encoder_last_hidden_state]
else:
encoder_last_hidden_state = batch.encoder_last_hidden_state
logits, encoder_last_hidden_state, past = self.forward( logits, encoder_last_hidden_state, past = self.forward(
batch.input_ids, batch.input_ids,
batch.attention_mask, batch.attention_mask,
batch.decoder_input_ids, decoder_input_ids,
batch.decoder_attention_mask, decoder_attention_mask,
batch.encoder_last_hidden_state, encoder_last_hidden_state,
batch.past_key_values, batch.past_key_values,
) )
@ -402,7 +437,6 @@ class Seq2SeqLM(Model):
logits, logits,
batch.next_token_choosers, batch.next_token_choosers,
batch.stopping_criterias, batch.stopping_criterias,
batch.input_ids,
batch.decoder_input_ids, batch.decoder_input_ids,
) )
@ -414,7 +448,6 @@ class Seq2SeqLM(Model):
logits, logits,
next_token_chooser, next_token_chooser,
stopping_criteria, stopping_criteria,
input_tokens,
decoder_input_ids, decoder_input_ids,
) in enumerate(iterator): ) in enumerate(iterator):
# Select next token # Select next token
@ -429,10 +462,8 @@ class Seq2SeqLM(Model):
# Generated token # Generated token
next_token_logprob = logprobs[-1, next_token_id] next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze() next_token_id_squeezed = next_token_id.squeeze()
next_token_text = self.tokenizer.decode( next_token_text = self.decode_token(
next_token_id_squeezed, next_token_id_squeezed,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
) )
# Evaluate stopping criteria # Evaluate stopping criteria
@ -469,14 +500,10 @@ class Seq2SeqLM(Model):
# Prefill # Prefill
if stopping_criteria.current_tokens == 1: if stopping_criteria.current_tokens == 1:
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens( prefill_tokens = PrefillTokens(
prefill_token_ids, [float("nan")], prefill_texts [self.tokenizer.bos_token_id],
[float("nan")],
[self.tokenizer.bos_token],
) )
else: else:
prefill_tokens = None prefill_tokens = None
@ -487,6 +514,7 @@ class Seq2SeqLM(Model):
next_token_id_squeezed, next_token_id_squeezed,
next_token_logprob, next_token_logprob,
next_token_text, next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text, generated_text,
) )
@ -500,10 +528,8 @@ class Seq2SeqLM(Model):
# If we finished at least one generation, we need to evict the indices of the generations that finished # If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch # from the values of the next batch
if len(next_batch_keep_indices) != len(batch): if len(next_batch_keep_indices) != len(batch):
# Apply indices to attention mask, past key values and other items that need to be cached # Apply indices to decoder_attention mask, past key values and other items that need to be cached
next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices] next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
if batch.decoder_attention_mask is not None: if batch.decoder_attention_mask is not None:
next_batch_decoder_attention_mask = batch.decoder_attention_mask[ next_batch_decoder_attention_mask = batch.decoder_attention_mask[
next_batch_keep_indices next_batch_keep_indices
@ -526,7 +552,6 @@ class Seq2SeqLM(Model):
batch.stopping_criterias[i] for i in next_batch_keep_indices batch.stopping_criterias[i] for i in next_batch_keep_indices
] ]
else: else:
next_batch_input_ids = batch.input_ids
next_batch_attention_mask = batch.attention_mask next_batch_attention_mask = batch.attention_mask
next_batch_decoder_attention_mask = batch.decoder_attention_mask next_batch_decoder_attention_mask = batch.decoder_attention_mask
next_batch_encoder_last_hidden_state = encoder_last_hidden_state next_batch_encoder_last_hidden_state = encoder_last_hidden_state
@ -536,20 +561,14 @@ class Seq2SeqLM(Model):
next_batch_next_token_choosers = batch.next_token_choosers next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias next_batch_stopping_criterias = batch.stopping_criterias
# Update decoder_attention_mask with padding as we added a new token to input_ids # Update decoder_attention_mask as we added a new token to input_ids
if next_batch_decoder_attention_mask is not None: if next_batch_decoder_attention_mask is not None:
next_batch_decoder_attention_mask = torch.cat( next_batch_decoder_attention_mask[:, -batch.padding_right_offset] = 1
[
next_batch_decoder_attention_mask,
next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
],
dim=1,
)
next_batch = Seq2SeqLMBatch( next_batch = Seq2SeqLMBatch(
batch_id=batch.batch_id, batch_id=batch.batch_id,
requests=next_batch_requests, requests=next_batch_requests,
input_ids=next_batch_input_ids, input_ids=None,
attention_mask=next_batch_attention_mask, attention_mask=next_batch_attention_mask,
decoder_input_ids=next_batch_decoder_input_ids, decoder_input_ids=next_batch_decoder_input_ids,
decoder_attention_mask=next_batch_decoder_attention_mask, decoder_attention_mask=next_batch_decoder_attention_mask,
@ -562,5 +581,6 @@ class Seq2SeqLM(Model):
size=next_batch_size, size=next_batch_size,
max_input_length=next_batch_max_input_length, max_input_length=next_batch_max_input_length,
max_decoder_input_length=next_batch_max_decoder_input_length, max_decoder_input_length=next_batch_max_decoder_input_length,
padding_right_offset=batch.padding_right_offset - 1,
) )
return generations, next_batch return generations, next_batch
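Two of the per-step decisions added to generate_token above, pulled out into a small standalone helper for clarity: after the prefill only the last decoder token is fed (the rest lives in the KV cache), and the pre-allocated decoder mask is sliced down to its currently valid columns. Toy tensors, not the real Seq2SeqLMBatch:

import torch

def pick_decoder_inputs(decoder_input_ids, past_key_values, decoder_attention_mask, padding_right_offset):
    if past_key_values is not None:
        # Past steps are cached, so only the newest decoder token goes into the forward pass.
        decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
    if decoder_attention_mask is not None:
        # Drop the right padding that has not been claimed by generated tokens yet.
        decoder_attention_mask = decoder_attention_mask[:, :-padding_right_offset]
    return decoder_input_ids, decoder_attention_mask

ids = torch.tensor([[0, 17, 42]])
mask = torch.tensor([[1, 1, 1, 0, 0]])          # two right-padding slots still unused
print(pick_decoder_inputs(ids, object(), mask, padding_right_offset=2))
# (tensor([[42]]), tensor([[1, 1, 1]]))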

View File

@ -16,11 +16,10 @@ from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation.models import Seq2SeqLM from text_generation_server.models import Seq2SeqLM
from text_generation.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
download_weights,
) )
HAS_BITS_AND_BYTES = True HAS_BITS_AND_BYTES = True
@ -53,14 +52,8 @@ class T5Sharded(Seq2SeqLM):
) )
tokenizer.bos_token_id = config.decoder_start_token_id tokenizer.bos_token_id = config.decoder_start_token_id
# Only master download weights
if self.master:
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights(): with init_empty_weights():
model = AutoModelForSeq2SeqLM.from_config(config) model = AutoModelForSeq2SeqLM.from_config(config)
@ -228,14 +221,6 @@ class T5Sharded(Seq2SeqLM):
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]: ]:
# Model Forward # Model Forward
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally...
if encoder_last_hidden_state is not None:
encoder_last_hidden_state = [encoder_last_hidden_state]
outputs = self.model.forward( outputs = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,

View File

@ -6,8 +6,8 @@ from typing import List, Optional
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from text_generation.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason from text_generation_server.pb.generate_pb2 import FinishReason
class Batch(ABC): class Batch(ABC):
@ -73,6 +73,7 @@ class Generation:
token_id: int token_id: int
token_logprob: float token_logprob: float
token_text: str token_text: str
token_is_special: bool
generated_text: Optional[GeneratedText] generated_text: Optional[GeneratedText]
def to_pb(self) -> generate_pb2.Generation: def to_pb(self) -> generate_pb2.Generation:
@ -84,6 +85,7 @@ class Generation:
token_id=self.token_id, token_id=self.token_id,
token_logprob=self.token_logprob, token_logprob=self.token_logprob,
token_text=self.token_text, token_text=self.token_text,
token_is_special=self.token_is_special,
generated_text=self.generated_text.to_pb() generated_text=self.generated_text.to_pb()
if self.generated_text is not None if self.generated_text is not None
else None, else None,

View File

@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
from text_generation.cache import Cache from text_generation_server.cache import Cache
from text_generation.interceptor import ExceptionInterceptor from text_generation_server.interceptor import ExceptionInterceptor
from text_generation.models import Model, get_model from text_generation_server.models import Model, get_model
from text_generation.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

View File

@ -0,0 +1,36 @@
from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.hub import (
weight_files,
weight_hub_files,
download_weights,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
from text_generation_server.utils.tokens import (
Greedy,
NextTokenChooser,
Sampling,
StoppingCriteria,
StopSequenceCriteria,
FinishReason,
)
__all__ = [
"convert_file",
"convert_files",
"initialize_torch_distributed",
"weight_files",
"weight_hub_files",
"download_weights",
"EntryNotFoundError",
"LocalEntryNotFoundError",
"RevisionNotFoundError",
"Greedy",
"NextTokenChooser",
"Sampling",
"StoppingCriteria",
"StopSequenceCriteria",
"FinishReason",
]

View File

@ -0,0 +1,94 @@
import concurrent
import time
import torch
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from datetime import timedelta
from loguru import logger
from pathlib import Path
from safetensors.torch import load_file, save_file
from typing import Dict, List
def check_file_size(source_file: Path, target_file: Path):
"""
Check that two files are close in size
"""
source_file_size = source_file.stat().st_size
target_file_size = target_file.stat().st_size
if (source_file_size - target_file_size) / source_file_size > 0.01:
raise RuntimeError(
f"""The file size different is more than 1%:
- {source_file}: {source_file_size}
- {target_file}: {target_file_size}
"""
)
def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
"""
For a Dict of tensors, check if two or more tensors point to the same underlying memory and
remove them
"""
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
# Iterate over all found memory addresses
for ptr, names in ptrs.items():
if len(names) > 1:
# Multiple tensors point to the same memory
# Only keep the first tensor
for name in names[1:]:
tensors.pop(name)
def convert_file(pt_file: Path, st_file: Path):
"""
Convert a pytorch file to a safetensors file
"""
logger.info(f"Convert {pt_file} to {st_file}.")
pt_state = torch.load(pt_file, map_location="cpu")
if "state_dict" in pt_state:
pt_state = pt_state["state_dict"]
remove_shared_pointers(pt_state)
# Tensors need to be contiguous
pt_state = {k: v.contiguous() for k, v in pt_state.items()}
st_file.parent.mkdir(parents=True, exist_ok=True)
save_file(pt_state, str(st_file), metadata={"format": "pt"})
# Check that both files are close in size
check_file_size(pt_file, st_file)
# Load safetensors state
st_state = load_file(str(st_file))
for k in st_state:
pt_tensor = pt_state[k]
st_tensor = st_state[k]
if not torch.equal(pt_tensor, st_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")
def convert_files(pt_files: List[Path], st_files: List[Path]):
assert len(pt_files) == len(st_files)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
for pt_file, st_file in zip(pt_files, st_files)
]
# We do this instead of using tqdm because we want to parse the logs with the launcher
start_time = time.time()
for i, future in enumerate(concurrent.futures.as_completed(futures)):
elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1)
eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")

View File

@ -0,0 +1,35 @@
import os
import torch
from datetime import timedelta
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size
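A single-process smoke test of initialize_torch_distributed, assuming a CPU-only machine (so the gloo branch runs) and that the helper is importable from text_generation_server.utils; MASTER_ADDR and MASTER_PORT must be set even for a world size of 1:

import os
import torch.distributed as dist
from text_generation_server.utils import initialize_torch_distributed

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")   # required by torch.distributed even for one process
os.environ.setdefault("MASTER_PORT", "29500")

process_group, rank, world_size = initialize_torch_distributed()
print(rank, world_size, dist.get_backend())         # 0 1 gloo
dist.destroy_process_group()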

View File

@ -0,0 +1,165 @@
import time
import os
from datetime import timedelta
from loguru import logger
from pathlib import Path
from typing import Optional, List
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import (
LocalEntryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError, # Import here to ease try/except in other part of the lib
)
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
def weight_hub_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
"""Get the weights filenames on the hub"""
api = HfApi()
info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
if not filenames:
raise EntryNotFoundError(
f"No {extension} weights found for model {model_id} and revision {revision}.",
None,
)
return filenames
def try_to_load_from_cache(
model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
"""Try to load a file from the Hugging Face cache"""
if revision is None:
revision = "main"
object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir():
# No cache for this model
return None
refs_dir = repo_cache / "refs"
snapshots_dir = repo_cache / "snapshots"
no_exist_dir = repo_cache / ".no_exist"
# Resolve refs (for instance to convert main to the associated commit sha)
if refs_dir.is_dir():
revision_file = refs_dir / revision
if revision_file.exists():
with revision_file.open() as f:
revision = f.read()
# Check if file is cached as "no_exist"
if (no_exist_dir / revision / filename).is_file():
return None
# Check if revision folder exists
if not snapshots_dir.exists():
return None
cached_shas = os.listdir(snapshots_dir)
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
# Check if file exists in cache
cached_file = snapshots_dir / revision / filename
return cached_file if cached_file.is_file() else None
def weight_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[Path]:
"""Get the local files"""
# Local model
if Path(model_id).exists() and Path(model_id).is_dir():
return list(Path(model_id).glob(f"*{extension}"))
try:
filenames = weight_hub_files(model_id, revision, extension)
except EntryNotFoundError as e:
if extension != ".safetensors":
raise e
# Try to see if there are pytorch weights
pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
# Change pytorch extension to safetensors extension
# It is possible that we have safetensors weights locally even though they are not on the
# hub if we converted weights locally without pushing them
filenames = [
f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
]
if WEIGHTS_CACHE_OVERRIDE is not None:
files = []
for filename in filenames:
p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
if not p.exists():
raise LocalEntryNotFoundError(
f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
)
files.append(p)
return files
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(
model_id, revision=revision, filename=filename
)
if cache_file is None:
raise LocalEntryNotFoundError(
f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `text-generation-server download-weights {model_id}` first."
)
files.append(cache_file)
return files
def download_weights(
filenames: List[str], model_id: str, revision: Optional[str] = None
) -> List[Path]:
"""Download the safetensors files from the hub"""
def download_file(filename):
local_file = try_to_load_from_cache(model_id, revision, filename)
if local_file is not None:
logger.info(f"File {filename} already present in cache.")
return Path(local_file)
logger.info(f"Download file: {filename}")
start_time = time.time()
local_file = hf_hub_download(
filename=filename,
repo_id=model_id,
revision=revision,
local_files_only=False,
)
logger.info(
f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
)
return Path(local_file)
# We do this instead of using tqdm because we want to parse the logs with the launcher
start_time = time.time()
files = []
for i, filename in enumerate(filenames):
file = download_file(filename)
elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(filenames) - (i + 1)
eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
files.append(file)
return files
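How these helpers fit together from the caller's side, assuming network access and that they are importable from text_generation_server.utils; the model id is only an example, and the real download-weights command additionally converts .bin checkpoints when no .safetensors exist:

from huggingface_hub.utils import LocalEntryNotFoundError
from text_generation_server.utils import weight_files, weight_hub_files, download_weights

model_id = "bigscience/bloom-560m"
try:
    files = weight_files(model_id)                 # resolves files already sitting in the local cache
except LocalEntryNotFoundError:
    filenames = weight_hub_files(model_id)         # .safetensors filenames listed on the hub
    files = download_weights(filenames, model_id)  # pulls them into the Hugging Face cache
print([f.name for f in files])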

View File

@ -0,0 +1,160 @@
import re
import torch
from transformers import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
RepetitionPenaltyLogitsProcessor,
PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device)
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens
class Greedy:
def __call__(self, logits):
return logits.argmax()
class NextTokenChooser:
def __init__(
self,
watermark=False,
temperature=1.0,
repetition_penalty=1.0,
top_k=None,
top_p=None,
typical_p=None,
do_sample=False,
seed=0,
device="cpu",
):
warpers = LogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
sampling = do_sample
if watermark:
warpers.append(WatermarkLogitsProcessor(device=device))
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
if temperature is not None and temperature != 1.0:
temperature = float(temperature)
warpers.append(TemperatureLogitsWarper(temperature))
sampling = True
if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k))
sampling = True
if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p))
sampling = True
if typical_p is not None and typical_p < 1.0:
warpers.append(TypicalLogitsWarper(mass=typical_p))
sampling = True
self.warpers = warpers
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
# Warp logits
if scores.shape[0] > 1:
# only warp the last token logits
scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
else:
scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)
# Choose tokens
next_id = self.choice(scores[-1])
return next_id.view(1, 1), logprobs
@classmethod
def from_pb(
cls,
pb: generate_pb2.NextTokenChooserParameters,
device: torch.device,
) -> "NextTokenChooser":
return NextTokenChooser(
watermark=pb.watermark,
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
top_k=pb.top_k,
top_p=pb.top_p,
typical_p=pb.typical_p,
do_sample=pb.do_sample,
seed=pb.seed,
device=device,
)
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
self.regex = re.compile(f".*{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
)
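A toy walk-through of the two classes above outside the gRPC path (fake logits, arbitrary parameter values), assuming they are importable from text_generation_server.utils:

import torch
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, StopSequenceCriteria

chooser = NextTokenChooser(temperature=0.7, top_k=50, seed=42, device="cpu")
criteria = StoppingCriteria(
    eos_token_id=0,
    stop_sequence_criterias=[StopSequenceCriteria("\n\n")],
    max_new_tokens=8,
)

input_ids = torch.tensor([[3, 7, 11]])
scores = torch.randn(1, 32000)                   # fake logits over a 32k vocabulary
next_id, logprobs = chooser(input_ids, scores)   # warps the logits, samples, returns log-probabilities
stop, reason = criteria(next_id.item(), "decoded text so far")
print(next_id.shape, logprobs.shape, stop, reason)   # e.g. torch.Size([1, 1]) torch.Size([1, 32000]) False None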

View File

@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from transformers import LogitsProcessor
GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
DELTA = os.getenv("WATERMARK_DELTA", 2.0)
class WatermarkLogitsProcessor(LogitsProcessor):
def __init__(
self,
gamma: float = GAMMA,
delta: float = DELTA,
hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
device: str = "cpu",
):
# watermarking parameters
self.gamma = gamma
self.delta = delta
self.rng = torch.Generator(device=device)
self.hash_key = hash_key
def _seed_rng(self, input_ids: torch.LongTensor) -> None:
assert (
input_ids.shape[-1] >= 1
), "requires at least a 1 token prefix sequence to seed rng"
prev_token = input_ids[-1].item()
self.rng.manual_seed(self.hash_key * prev_token)
def _get_greenlist_ids(
self, input_ids: torch.LongTensor, max_value: int
) -> list[int]:
# seed the rng using the previous tokens/prefix
self._seed_rng(input_ids)
greenlist_size = int(max_value * self.gamma)
vocab_permutation = torch.randperm(
max_value, device=input_ids.device, generator=self.rng
)
greenlist_ids = vocab_permutation[:greenlist_size]
return greenlist_ids
@staticmethod
def _calc_greenlist_mask(
scores: torch.FloatTensor, greenlist_token_ids
) -> torch.BoolTensor:
green_tokens_mask = torch.zeros_like(scores)
green_tokens_mask[-1, greenlist_token_ids] = 1
final_mask = green_tokens_mask.bool()
return final_mask
@staticmethod
def _bias_greenlist_logits(
scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
) -> torch.Tensor:
scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
return scores
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
assert len(input_ids) == 1
greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1])
green_tokens_mask = self._calc_greenlist_mask(
scores=scores, greenlist_token_ids=greenlist_ids
)
scores = self._bias_greenlist_logits(
scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
)
return scores
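A quick demonstration of the watermark processor on fake logits (toy vocabulary size; gamma and delta are simply the defaults spelled out explicitly):

import torch
from text_generation_server.utils.watermark import WatermarkLogitsProcessor

processor = WatermarkLogitsProcessor(gamma=0.5, delta=2.0, device="cpu")
input_ids = torch.tensor([[11, 42, 97]])   # the last prefix token seeds the green-list RNG
scores = torch.zeros(1, 1000)              # uniform fake logits over a 1k vocabulary
biased = processor(input_ids, scores)
print((biased > 0).float().mean().item())  # 0.5: a gamma-sized slice of the vocab received the +delta bias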

supported_models.json (new file, +9 lines)
View File

@ -0,0 +1,9 @@
[
"bigscience/bloom",
"bigscience/bloomz",
"EleutherAI/gpt-neox-20b",
"google/flan-ul2",
"google/flan-t5-xxl",
"OpenAssistant/oasst-sft-1-pythia-12b",
"olivierdehaene/optimized-santacoder"
]