diff --git a/.dockerignore b/.dockerignore index 38e8f8243..5aa1aa3a4 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,2 +1,2 @@ aml -router/target \ No newline at end of file +target \ No newline at end of file diff --git a/.gitignore b/.gitignore index 723ef36f4..ec376bb82 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -.idea \ No newline at end of file +.idea +target \ No newline at end of file diff --git a/router/Cargo.lock b/Cargo.lock similarity index 91% rename from router/Cargo.lock rename to Cargo.lock index eda429087..551f7aebc 100644 --- a/router/Cargo.lock +++ b/Cargo.lock @@ -55,9 +55,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.57" +version = "0.1.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f" +checksum = "1e805d94e6b5001b651426cf4cd446b1ab5f319d27bab5c644f61de0a804360c" dependencies = [ "proc-macro2", "quote", @@ -166,9 +166,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.11.0" +version = "3.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" +checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" [[package]] name = "byteorder" @@ -255,9 +255,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.0.15" +version = "4.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bf8832993da70a4c6d13c581f4463c2bdda27b9bf1c5498dc4365543abe6d6f" +checksum = "06badb543e734a2d6568e19a40af66ed5364360b9226184926f89d229b4b4267" dependencies = [ "atty", "bitflags", @@ -391,6 +391,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "ctrlc" +version = "3.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d91974fbbe88ec1df0c24a4f00f99583667a7e2e6272b2b92d294d81e462173" +dependencies = [ + "nix", + "winapi", +] + [[package]] name = "darling" version = "0.10.2" @@ -529,7 +539,7 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -936,9 +946,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8af84674fe1f223a982c933a0ee1086ac4d4052aa0fb8060c12c6ad838e754" +checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc" [[package]] name = "js-sys" @@ -957,9 +967,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.134" +version = "0.2.135" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb" +checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c" [[package]] name = "lock_api" @@ -1047,7 +1057,7 @@ dependencies = [ "libc", "log", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -1074,6 +1084,18 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nix" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e322c04a9e3440c327fca7b6c8a63e6890a32fa2ad689db972425f07e0d22abb" +dependencies = [ + "autocfg", + "bitflags", + "cfg-if", + "libc", +] + [[package]] name = "nom" version = "7.1.1" @@ -1084,6 +1106,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nu-ansi-term" +version = 
"0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num_cpus" version = "1.13.1" @@ -1185,6 +1217,12 @@ version = "6.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.12.1" @@ -1197,15 +1235,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929" +checksum = "4dc9e0dc2adc1c69d09143aff38d3d30c5c3f0df0dad82e6d25547af174ebec0" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -1300,9 +1338,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.46" +version = "1.0.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" +checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725" dependencies = [ "unicode-ident", ] @@ -1530,7 +1568,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88d6731146462ea25d9244b2ed5fd1d716d25c52e4d54aa4fb0f3c4e9854dbe2" dependencies = [ "lazy_static", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -1584,9 +1622,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.85" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e55a28e3aaef9d5ce0506d0a14dbba8054ddc7e499ef522dd8b26859ec9d4a44" +checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074" dependencies = [ "itoa", "ryu", @@ -1625,6 +1663,15 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +dependencies = [ + "libc", +] + [[package]] name = "slab" version = "0.4.7" @@ -1681,10 +1728,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] -name = "syn" -version = "1.0.101" +name = "subprocess" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e90cde112c4b9690b8cbe810cba9ddd8bc1d7472e2cae317b69e9438c1cba7d2" +checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "syn" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1" dependencies = [ "proc-macro2", "quote", @@ -1741,13 +1798,24 @@ dependencies = [ "winapi", ] +[[package]] +name = "text-generation-launcher" +version = "0.1.0" +dependencies = [ + "clap 4.0.17", + "ctrlc", + "subprocess", + "tracing", + "tracing-subscriber", +] + [[package]] name = "text-generation-router" version = "0.1.0" dependencies = [ "axum", 
"bloom-inference-client", - "clap 4.0.15", + "clap 4.0.17", "futures", "parking_lot", "serde", @@ -1872,6 +1940,7 @@ dependencies = [ "num_cpus", "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "winapi", @@ -1910,9 +1979,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6edf2d6bc038a43d31353570e27270603f4648d18f5ed10c0e179abe43255af" +checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce" dependencies = [ "futures-core", "pin-project-lite", @@ -2031,9 +2100,9 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" [[package]] name = "tower-service" @@ -2043,9 +2112,9 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.36" +version = "0.1.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fce9567bd60a67d08a16488756721ba392f24f29006402881e43b19aac64307" +checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if", "log", @@ -2056,9 +2125,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2" +checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ "proc-macro2", "quote", @@ -2067,9 +2136,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.29" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aeea4303076558a00714b823f9ad67d58a3bbda1df83d8827d21193156e22f7" +checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a" dependencies = [ "once_cell", "valuable", @@ -2108,11 +2177,11 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60db860322da191b40952ad9affe65ea23e7dd6a5c442c2c42865810c6ab8e6b" +checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70" dependencies = [ - "ansi_term", + "nu-ansi-term", "sharded-slab", "smallvec", "thread_local", @@ -2140,9 +2209,9 @@ checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992" [[package]] name = "unicode-ident" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcc811dc4066ac62f84f11307873c4850cb653bfa9b1719cee2bd2204a4bc5dd" +checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3" [[package]] name = "unicode-normalization" @@ -2361,43 +2430,100 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" dependencies = [ - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_msvc", + "windows_aarch64_msvc 0.36.1", + "windows_i686_gnu 0.36.1", + "windows_i686_msvc 0.36.1", + "windows_x86_64_gnu 0.36.1", + "windows_x86_64_msvc 0.36.1", ] 
+[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc 0.42.0", + "windows_i686_gnu 0.42.0", + "windows_i686_msvc 0.42.0", + "windows_x86_64_gnu 0.42.0", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc 0.42.0", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e" + [[package]] name = "windows_aarch64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4" + [[package]] name = "windows_i686_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" +[[package]] +name = "windows_i686_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7" + [[package]] name = "windows_i686_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" +[[package]] +name = "windows_i686_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246" + [[package]] name = "windows_x86_64_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028" + [[package]] name = "windows_x86_64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" + [[package]] name = "winreg" version = "0.10.1" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 000000000..d3f7dfb09 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,11 @@ +[workspace] +members = [ + "router", + "router/client", + "launcher" +] + +[profile.release] +debug = 1 +incremental = true +lto = "off" diff --git a/Dockerfile b/Dockerfile index 9b8a20549..70b91bc21 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM rust:1.64 as builder +FROM rust:1.64 as router-builder WORKDIR /usr/src @@ -9,7 +9,17 @@ WORKDIR /usr/src/router RUN cargo install --path . 
-FROM nvidia/cuda:11.6.1-devel-ubuntu18.04 +FROM rust:1.64 as launcher-builder + +WORKDIR /usr/src + +COPY launcher launcher + +WORKDIR /usr/src/launcher + +RUN cargo install --path . + +FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 ENV LANG=C.UTF-8 \ LC_ALL=C.UTF-8 \ @@ -34,17 +44,15 @@ RUN cd ~ && \ bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \ conda create -n text-generation python=3.9 -y +WORKDIR /usr/src + +COPY server/Makefile server/Makefile + # Install specific version of torch -RUN /opt/miniconda/envs/text-generation/bin/pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir +RUN cd server && make install-torch # Install specific version of transformers -RUN wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \ - unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \ - rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \ - cd transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 && \ - /opt/miniconda/envs/text-generation/bin/python setup.py install - -WORKDIR /usr/src +RUN cd server && make install-transformers # Install server COPY server server @@ -52,9 +60,7 @@ RUN cd server && \ /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir # Install router -COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router +COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router +COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher -COPY run.sh . -RUN chmod +x run.sh - -CMD ["./run.sh"] \ No newline at end of file +CMD text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS --shard-directory $MODEL_BASE_PATH \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..3a80a12db --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +install-server: + cd server && make pip-install + +install-router: + cd router && cargo install --path . + +install-launcher: + cd launcher && cargo install --path . + +install: + make install-server + make install-router + make install-launcher + +run-bloom-560m: + text-generation-launcher --model-name bigscience/bloom-560m --shard-directory /tmp/models --num-shard 2 + +run-bloom: + text-generation-launcher --model-name bigscience/bloom --shard-directory /tmp/models --num-shard 8 diff --git a/README.md b/README.md index 6d23d9c5b..18e9865de 100644 --- a/README.md +++ b/README.md @@ -1,50 +1,51 @@ -# Text Generation Inference +# LLM Text Generation Inference -A Rust and gRPC server for text generation inference. +
-## Load Tests +![architecture](assets/architecture.jpg) + +
+ +A Rust and gRPC server for large language models text generation inference. + +## Load Tests for BLOOM See `k6/load_test.js` -We send the default examples with a 1 second delay between each request. +We send the default examples with a 1 second delay between requests. Stages: -- Ramp up to 50 concurrent requests per second in 1min -- Ramp up from 50 to 100 concurrent requests per second in 2min -- Ramp down to 0 concurrent requests per second in 1min +- Ramp up to 50 vus in 1min +- Ramp up from 50 to 100 vus in 2min +- Ramp down to 0 vus in 1min -| | avg | min | med | max | p(90) | p(95) | RPS | -|------------------------|-----------|-----------|-----------|------------|-----------|-----------|----------| -| Original code | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 | -| ISO with original code | 8.88s | 959.53ms | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 | -| New batching logic | **5.44s** | **1.27s** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** | +| | avg | min | med | max | p(90) | p(95) | RPS | +|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------| +| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 | +| ISO with original code | 8.88s | **959.53ms** | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 | +| New batching logic | **5.44s** | 1.27s | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** | ## Install ```shell -cd server -pip install . +make install ``` -``` -cd router -cargo build --release -``` - -## Run +## Run ```shell -python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-directory /dev/shm/models +make run-bloom-560m ``` +## Test + ```shell -./router/target/release/router +curl 127.0.0.1:3000/generate \ + -X POST \ + -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \ + -H 'Content-Type: application/json' ``` ## TODO: -- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated -- [ ] Add tests -- [ ] Add shutdown logic in router and server -- [ ] Improve multi-processing logic in server -- [ ] Improve past key layer indexing? 
\ No newline at end of file +- [ ] Add tests for the `server/model` logic \ No newline at end of file diff --git a/aml/deployment.yaml b/aml/deployment.yaml index be28ceef2..35d19006f 100644 --- a/aml/deployment.yaml +++ b/aml/deployment.yaml @@ -8,7 +8,7 @@ environment_variables: MODEL_NAME: bigscience/bloom NUM_GPUS: 8 environment: - image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3 + image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2 inference_config: liveness_route: port: 3000 diff --git a/assets/architecture.jpg b/assets/architecture.jpg new file mode 100644 index 000000000..e0a5f7c6f Binary files /dev/null and b/assets/architecture.jpg differ diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml new file mode 100644 index 000000000..ae3558b1c --- /dev/null +++ b/launcher/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "text-generation-launcher" +version = "0.1.0" +edition = "2021" +authors = ["Olivier Dehaene"] +description = "Text Generation Launcher" + +[dependencies] +clap = { version = "4.0.15", features = ["derive", "env"] } +ctrlc = "3.2.3" +subprocess = "0.2.9" +tracing = "0.1.37" +tracing-subscriber = "0.3.16" diff --git a/launcher/src/main.rs b/launcher/src/main.rs new file mode 100644 index 000000000..0821a456a --- /dev/null +++ b/launcher/src/main.rs @@ -0,0 +1,358 @@ +use clap::Parser; +use std::io::{BufRead, BufReader, Read}; +use std::path::Path; +use std::process::ExitCode; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::mpsc::TryRecvError; +use std::sync::Arc; +use std::sync::{mpsc, Mutex}; +use std::thread; +use std::thread::sleep; +use std::time::{Duration, Instant}; +use std::{fs, io}; +use subprocess::{Popen, PopenConfig, PopenError, Redirection}; + +/// App Configuration +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(default_value = "bigscience/bloom-560m", long, env)] + model_name: String, + #[clap(long, env)] + num_shard: Option, + #[clap(long, env)] + shard_directory: Option, + #[clap(default_value = "128", long, env)] + max_concurrent_requests: usize, + #[clap(default_value = "1000", long, env)] + max_input_length: usize, + #[clap(default_value = "32", long, env)] + max_batch_size: usize, + #[clap(default_value = "5", long, env)] + max_waiting_time: u64, + #[clap(default_value = "3000", long, short, env)] + port: u16, + #[clap(default_value = "/tmp/text-generation-server", long, env)] + shard_uds_path: String, + #[clap(default_value = "localhost", long, env)] + master_addr: String, + #[clap(default_value = "29500", long, env)] + master_port: usize, +} + +fn main() -> ExitCode { + tracing_subscriber::fmt::init(); + + // Pattern match configuration + let Args { + model_name, + num_shard, + shard_directory, + max_concurrent_requests, + max_input_length, + max_batch_size, + max_waiting_time, + port, + shard_uds_path, + master_addr, + master_port, + } = Args::parse(); + + // By default we only have one master shard + let num_shard = num_shard.unwrap_or(1); + + // Signal handler + let running = Arc::new(AtomicBool::new(true)); + let r = running.clone(); + ctrlc::set_handler(move || { + r.store(false, Ordering::SeqCst); + }) + .expect("Error setting Ctrl-C handler"); + + // Shared shutdown bool + let shutdown = Arc::new(Mutex::new(false)); + // Shared shutdown channel + // When shutting down, the main thread will wait for all senders to be dropped + let (shutdown_sender, shutdown_receiver) = mpsc::channel(); + + // Shared channel to track shard 
status + let (status_sender, status_receiver) = mpsc::channel(); + + // Start shard processes + for rank in 0..num_shard { + let model_name = model_name.clone(); + let uds_path = shard_uds_path.clone(); + let shard_directory = shard_directory.clone(); + let master_addr = master_addr.clone(); + let status_sender = status_sender.clone(); + let shutdown = shutdown.clone(); + let shutdown_sender = shutdown_sender.clone(); + thread::spawn(move || { + shard_manager( + model_name, + uds_path, + shard_directory, + rank, + num_shard, + master_addr, + master_port, + status_sender, + shutdown, + shutdown_sender, + ) + }); + } + drop(shutdown_sender); + + // Wait for shard to start + let mut shard_ready = 0; + while running.load(Ordering::SeqCst) { + match status_receiver.try_recv() { + Ok(ShardStatus::Ready) => { + shard_ready += 1; + if shard_ready == num_shard { + break; + } + } + Err(TryRecvError::Empty) => { + sleep(Duration::from_millis(100)); + } + Ok(ShardStatus::Failed((rank, err))) => { + tracing::error!("Shard {} failed to start:\n{}", rank, err); + shutdown_shards(shutdown, &shutdown_receiver); + return ExitCode::FAILURE; + } + Err(TryRecvError::Disconnected) => { + tracing::error!("Shard status channel disconnected"); + shutdown_shards(shutdown, &shutdown_receiver); + return ExitCode::FAILURE; + } + } + } + + // We might have received a termination signal + if !running.load(Ordering::SeqCst) { + shutdown_shards(shutdown, &shutdown_receiver); + return ExitCode::SUCCESS; + } + + // All shard started + // Start webserver + tracing::info!("Starting Webserver"); + let mut webserver = match Popen::create( + &[ + "text-generation-router", + "--max-concurrent-requests", + &max_concurrent_requests.to_string(), + "--max-input-length", + &max_input_length.to_string(), + "--max-batch-size", + &max_batch_size.to_string(), + "--max-waiting-time", + &max_waiting_time.to_string(), + "--port", + &port.to_string(), + "--master-shard-uds-path", + &format!("{}-0", shard_uds_path), + "--tokenizer-name", + &model_name, + ], + PopenConfig { + stdout: Redirection::Pipe, + stderr: Redirection::Pipe, + // Needed for the shutdown procedure + setpgid: true, + ..Default::default() + }, + ) { + Ok(p) => p, + Err(err) => { + tracing::error!("Failed to start webserver: {}", err); + if let PopenError::IoError(err) = err { + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("text-generation-router not found in PATH"); + tracing::error!("Please install it with `make install-router`") + } + } + + shutdown_shards(shutdown, &shutdown_receiver); + return ExitCode::FAILURE; + } + }; + + // Redirect STDOUT and STDERR to the console + let webserver_stdout = webserver.stdout.take().unwrap(); + let webserver_stderr = webserver.stderr.take().unwrap(); + + thread::spawn(move || { + let stdout = BufReader::new(webserver_stdout); + let stderr = BufReader::new(webserver_stderr); + for line in stdout.lines() { + println!("{}", line.unwrap()); + } + for line in stderr.lines() { + println!("{}", line.unwrap()); + } + }); + + // Default exit code + let mut exit_code = ExitCode::SUCCESS; + + while running.load(Ordering::SeqCst) { + if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { + tracing::error!("Shard {} failed:\n{}", rank, err); + exit_code = ExitCode::FAILURE; + break; + }; + + match webserver.poll() { + Some(_) => { + tracing::error!("Webserver Crashed"); + shutdown_shards(shutdown, &shutdown_receiver); + return ExitCode::FAILURE; + } + None => { + sleep(Duration::from_millis(100)); + } + }; + } + + 
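The launcher coordinates shutdown through two pieces of shared state: a mutex-guarded boolean that each `shard_manager` thread polls, and an mpsc channel whose senders are owned by those threads, so that `shutdown_shards` can block until every shard thread has actually exited. A minimal standalone sketch of that pattern, using only the standard library and a hypothetical worker loop in place of `shard_manager`:

```rust
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;

fn main() {
    // Shared flag telling workers to stop, and a channel used purely for its
    // "all senders dropped" signal.
    let shutdown = Arc::new(Mutex::new(false));
    let (shutdown_sender, shutdown_receiver) = mpsc::channel::<()>();

    for worker_id in 0..2 {
        let shutdown = shutdown.clone();
        // Each worker holds a clone of the sender; it is dropped when the thread returns.
        let _shutdown_sender = shutdown_sender.clone();
        thread::spawn(move || loop {
            if *shutdown.lock().unwrap() {
                println!("worker {} exiting", worker_id);
                return;
            }
            thread::sleep(Duration::from_millis(50));
        });
    }
    // Drop the original sender so only the workers keep the channel open.
    drop(shutdown_sender);

    // Ask the workers to stop...
    *shutdown.lock().unwrap() = true;
    // ...then block until every worker has dropped its sender, i.e. has exited.
    let _ = shutdown_receiver.recv();
    println!("all workers terminated");
}
```

Because `recv()` only errors once every `Sender` clone is gone, the main thread cannot return before a shard thread that is still tearing down its subprocess.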
// Graceful termination + webserver.terminate().unwrap(); + tracing::info!("Waiting for webserver to gracefully shutdown"); + webserver.wait_timeout(Duration::from_secs(90)).unwrap(); + tracing::info!("Webserver terminated"); + shutdown_shards(shutdown, &shutdown_receiver); + + exit_code +} + +#[derive(Debug)] +enum ShardStatus { + Ready, + Failed((usize, String)), +} + +#[allow(clippy::too_many_arguments)] +fn shard_manager( + model_name: String, + uds_path: String, + shard_directory: Option, + rank: usize, + world_size: usize, + master_addr: String, + master_port: usize, + status_sender: mpsc::Sender, + shutdown: Arc>, + _shutdown_sender: mpsc::Sender<()>, +) { + // Get UDS path + let uds_string = format!("{}-{}", uds_path, rank); + let uds = Path::new(&uds_string); + // Clean previous runs + fs::remove_file(uds).unwrap_or_default(); + + // Process args + let mut shard_argv = vec![ + "bloom-inference-server".to_string(), + "serve".to_string(), + model_name, + "--uds-path".to_string(), + uds_path, + ]; + + if world_size > 1 { + shard_argv.push("--sharded".to_string()); + } + + if let Some(shard_directory) = shard_directory { + shard_argv.push("--shard-directory".to_string()); + shard_argv.push(shard_directory); + } + + // Start process + tracing::info!("Starting shard {}", rank); + let mut p = match Popen::create( + &shard_argv, + PopenConfig { + stdout: Redirection::Pipe, + stderr: Redirection::Pipe, + // Needed for the shutdown procedure + setpgid: true, + // NCCL env vars + env: Some(vec![ + ("RANK".parse().unwrap(), rank.to_string().parse().unwrap()), + ( + "WORLD_SIZE".parse().unwrap(), + world_size.to_string().parse().unwrap(), + ), + ("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()), + ( + "MASTER_PORT".parse().unwrap(), + master_port.to_string().parse().unwrap(), + ), + ]), + ..Default::default() + }, + ) { + Ok(p) => p, + Err(err) => { + if let PopenError::IoError(ref err) = err { + if err.kind() == io::ErrorKind::NotFound { + tracing::error!("bloom-inference-server not found in PATH"); + tracing::error!("Please install it with `make install-server`") + } + } + status_sender + .send(ShardStatus::Failed((rank, err.to_string()))) + .unwrap(); + return; + } + }; + + let mut ready = false; + let start_time = Instant::now(); + loop { + // Process exited + if p.poll().is_some() { + let mut err = String::new(); + p.stderr.take().unwrap().read_to_string(&mut err).unwrap(); + status_sender + .send(ShardStatus::Failed((rank, err))) + .unwrap(); + return; + } + + // We received a shutdown signal + if *shutdown.lock().unwrap() { + p.terminate().unwrap(); + let _ = p.wait_timeout(Duration::from_secs(90)); + tracing::info!("Shard {} terminated", rank); + return; + } + + // Shard is ready + if uds.exists() && !ready { + tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed()); + status_sender.send(ShardStatus::Ready).unwrap(); + ready = true; + } else if !ready { + tracing::info!("Waiting for shard {} to be ready...", rank); + } + sleep(Duration::from_secs(5)); + } +} + +fn shutdown_shards(shutdown: Arc>, shutdown_receiver: &mpsc::Receiver<()>) { + tracing::info!("Shutting down shards"); + // Update shutdown value to true + // This will be picked up by the shard manager + { + let mut shutdown = shutdown.lock().unwrap(); + *shutdown = true; + } + + // Wait for shards to shutdown + // This will block till all shutdown_sender are dropped + let _ = shutdown_receiver.recv(); +} diff --git a/proto/generate.proto b/proto/generate.proto index 8c5221b48..45afca82a 100644 --- 
a/proto/generate.proto +++ b/proto/generate.proto @@ -11,10 +11,6 @@ service TextGenerationService { rpc Generate (GenerateRequest) returns (GenerateResponse); /// Generate tokens for a list of cached batches rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse); - /// Generate tokens until the text of at least one request of the batch is generated - rpc GenerateUntilFinished (GenerateUntilFinishedRequest) returns (GenerateUntilFinishedResponse); - /// Generate tokens until the text of at least one request of the cached batches i finished - rpc GenerateUntilFinishedWithCache (GenerateUntilFinishedWithCacheRequest) returns (GenerateUntilFinishedWithCacheResponse); } /// Empty request @@ -92,27 +88,3 @@ message GenerateWithCacheResponse { /// Next batch (cached) optional Batch batch = 2; } - -message GenerateUntilFinishedRequest { - /// Batch - Batch batch = 1; -} - -message GenerateUntilFinishedResponse { - /// Finished requests - repeated GeneratedText generated_texts = 1; - /// Next batch (cached) - optional Batch batch = 2; -} - -message GenerateUntilFinishedWithCacheRequest { - /// Cached batches - repeated Batch batches = 1; -} - -message GenerateUntilFinishedWithCacheResponse { - /// Finished requests - repeated GeneratedText generated_texts = 1; - /// Next batch (cached) - optional Batch batch = 2; -} diff --git a/router/.gitignore b/router/.gitignore deleted file mode 100644 index ea8c4bf7f..000000000 --- a/router/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/target diff --git a/router/Cargo.toml b/router/Cargo.toml index 37f319e99..5820c1383 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -22,16 +22,7 @@ serde = "1.0.145" serde_json = "1.0.85" thiserror = "1.0.37" tokenizers = "0.13.0" -tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] } +tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tracing = "0.1.36" tracing-subscriber = "0.3.15" -[workspace] -members = [ - "client", -] - -[profile.release] -debug = 1 -incremental = true -lto = "off" diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml index 7760c8cbe..633f82a9d 100644 --- a/router/client/Cargo.toml +++ b/router/client/Cargo.toml @@ -5,8 +5,6 @@ edition = "2021" [dependencies] futures = "^0.3" -#grpc-error-details = { path = "../../grpc-error-details" } -#grpc-metadata = { path = "../../grpc-metadata" } prost = "^0.9" thiserror = "^1.0" tokio = { version = "^1.21", features = ["sync"] } diff --git a/router/client/src/client.rs b/router/client/src/client.rs index e7189b892..172d0bf73 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -1,10 +1,11 @@ +/// Single shard Client use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient; use crate::pb::generate::v1::*; use crate::Result; use tonic::transport::{Channel, Uri}; use tracing::*; -/// BLOOM Inference gRPC client +/// Text Generation Inference gRPC client #[derive(Clone)] pub struct Client { stub: TextGenerationServiceClient, @@ -34,6 +35,7 @@ impl Client { }) } + /// Returns a list of uris or unix sockets of all shards #[instrument(skip(self))] pub async fn service_discovery(&mut self) -> Result> { let request = tonic::Request::new(ServiceDiscoveryRequest {}); @@ -46,6 +48,7 @@ impl Client { .into_inner() .urls .into_iter() + // Remove unix socket prefix .map(|url| match url.strip_prefix("unix://") { None => url, Some(stripped_url) => stripped_url.to_string(), @@ -54,6 +57,7 @@ impl 
Client { Ok(urls) } + /// Clear the past generations cache #[instrument(skip(self))] pub async fn clear_cache(&mut self) -> Result<()> { let request = tonic::Request::new(ClearCacheRequest {}); @@ -64,6 +68,10 @@ impl Client { Ok(()) } + /// Generate one token for each request in the given batch + /// + /// Returns a list of generated texts of request that met their stopping criteria + /// and the next cached batch #[instrument(skip(self))] pub async fn generate(&mut self, batch: Batch) -> Result<(Vec, Option)> { let request = tonic::Request::new(GenerateRequest { batch: Some(batch) }); @@ -76,6 +84,10 @@ impl Client { Ok((response.generated_texts, response.batch)) } + /// Generate one token for each request in the given cached batch + /// + /// Returns a list of generated texts of request that met their stopping criteria + /// and the next cached batch #[instrument(skip(self))] pub async fn generate_with_cache( &mut self, @@ -90,34 +102,4 @@ impl Client { .into_inner(); Ok((response.generated_texts, response.batch)) } - - #[instrument(skip(self))] - pub async fn generate_until_finished( - &mut self, - batch: Batch, - ) -> Result<(Vec, Option)> { - let request = tonic::Request::new(GenerateUntilFinishedRequest { batch: Some(batch) }); - let response = self - .stub - .generate_until_finished(request) - .instrument(info_span!("generate_until_finished")) - .await? - .into_inner(); - Ok((response.generated_texts, response.batch)) - } - - #[instrument(skip(self))] - pub async fn generate_until_finished_with_cache( - &mut self, - batches: Vec, - ) -> Result<(Vec, Option)> { - let request = tonic::Request::new(GenerateUntilFinishedWithCacheRequest { batches }); - let response = self - .stub - .generate_until_finished_with_cache(request) - .instrument(info_span!("generate_until_finished_with_cache")) - .await? - .into_inner(); - Ok((response.generated_texts, response.batch)) - } } diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 48b2650d0..0f1f96bca 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -1,6 +1,7 @@ -//! BLOOM Inference gRPC client library +//! 
Text Generation gRPC client library mod client; +#[allow(clippy::derive_partial_eq_without_eq)] mod pb; mod sharded_client; @@ -8,7 +9,7 @@ pub use client::Client; pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request}; pub use sharded_client::ShardedClient; use thiserror::Error; -pub use tonic::transport; +use tonic::transport; use tonic::Status; #[derive(Error, Debug, Clone)] @@ -21,7 +22,7 @@ pub enum ClientError { impl From for ClientError { fn from(err: Status) -> Self { - Self::Generation(err.to_string()) + Self::Generation(err.message().to_string()) } } diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 7134551e6..916e72b48 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,9 +1,11 @@ +/// Multi shard Client use crate::Result; use crate::{Batch, Client, GeneratedText}; use futures::future::join_all; use tokio::sync::{broadcast, mpsc}; use tonic::transport::Uri; +/// List of all available commands that can be sent through the command channel #[derive(Clone, Debug)] enum Command { Generate( @@ -14,36 +16,32 @@ enum Command { Vec, mpsc::Sender, Option)>>, ), - GenerateUntilFinished( - Batch, - mpsc::Sender, Option)>>, - ), - GenerateUntilFinishedWithCache( - Vec, - mpsc::Sender, Option)>>, - ), ClearCache(mpsc::Sender>), } +/// Tokio task that handles the communication with a single shard +/// +/// We subscribe on a broadcast channel to receive commands that will be sent by +/// the ShardedClient. +/// +/// Each command is fan out to all shards. +/// +/// The result of the command is sent back to the ShardedClient through a mpsc channel (multi +/// producer = the shards, single consumer = the ShardedClient). async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver) { while let Ok(message) = request_subscriber.recv().await { match message { Command::Generate(batch, response_tx) => { let result = client.generate(batch).await; + // We can unwrap_or(()) here because the only error that can happen is if the + // receiver is dropped, which means that the ShardedClient already received a + // response from another shard response_tx.try_send(result).unwrap_or(()); } Command::GenerateWithCache(batches, response_tx) => { let result = client.generate_with_cache(batches).await; response_tx.try_send(result).unwrap_or(()); } - Command::GenerateUntilFinished(batch, response_tx) => { - let result = client.generate_until_finished(batch).await; - response_tx.try_send(result).unwrap_or(()); - } - Command::GenerateUntilFinishedWithCache(batches, response_tx) => { - let result = client.generate_until_finished_with_cache(batches).await; - response_tx.try_send(result).unwrap_or(()); - } Command::ClearCache(response_tx) => { let result = client.clear_cache().await; response_tx.try_send(result).unwrap_or(()); @@ -52,30 +50,42 @@ async fn client_task(mut client: Client, mut request_subscriber: broadcast::Rece } } +/// Text Generation Inference gRPC multi client pub struct ShardedClient { + _clients: Vec, request_tx: broadcast::Sender, } impl ShardedClient { - fn new(mut clients: Vec) -> Self { + fn new(clients: Vec) -> Self { + // The broadcast channel to communicate with the shards + // We use a capacity of one as the shards are not asynchronous and can only process one + // command at a time let (request_tx, _) = broadcast::channel(1); - for client in clients.drain(..) 
{ + // Spawn client tasks + for client in clients.iter() { let request_subscriber = request_tx.subscribe(); - tokio::spawn(client_task(client, request_subscriber)); + tokio::spawn(client_task(client.clone(), request_subscriber)); } - Self { request_tx } + Self { + _clients: clients, + request_tx, + } } + /// Create a new ShardedClient from a master client. The master client will communicate with + /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method. async fn from_master_client(mut master_client: Client) -> Result { + // Get all uris/unix sockets from the master client let uris = master_client.service_discovery().await.unwrap(); - let futures = uris.into_iter().map(|path| Client::connect_uds(path)); + let futures = uris.into_iter().map(Client::connect_uds); let clients: Result> = join_all(futures).await.into_iter().collect(); Ok(Self::new(clients?)) } - /// Returns a client connected to the given url + /// Returns a client connected to the given uri pub async fn connect(uri: Uri) -> Result { let master_client = Client::connect(uri).await?; Self::from_master_client(master_client).await @@ -87,51 +97,43 @@ impl ShardedClient { Self::from_master_client(master_client).await } + /// Generate one token for each request in the given batch + /// + /// Returns a list of generated texts of request that met their stopping criteria + /// and the next cached batch pub async fn generate(&self, batch: Batch) -> Result<(Vec, Option)> { + // Create a channel to receive the response from the shards + // We will only ever receive one message on this channel let (response_tx, mut response_rx) = mpsc::channel(1); self.request_tx .send(Command::Generate(batch, response_tx)) .unwrap(); + // As soon as we receive one response, we can return as all shards will return the same response_rx.recv().await.unwrap() } + /// Generate one token for each request in the given cached batch + /// + /// Returns a list of generated texts of request that met their stopping criteria + /// and the next cached batch pub async fn generate_with_cache( &self, batches: Vec, ) -> Result<(Vec, Option)> { + // Create a channel to receive the response from the shards + // We will only ever receive one message on this channel let (response_tx, mut response_rx) = mpsc::channel(1); self.request_tx .send(Command::GenerateWithCache(batches, response_tx)) .unwrap(); + // As soon as we receive one response, we can return as all shards will return the same response_rx.recv().await.unwrap() } - pub async fn generate_until_finished( - &self, - batch: Batch, - ) -> Result<(Vec, Option)> { - let (response_tx, mut response_rx) = mpsc::channel(1); - self.request_tx - .send(Command::GenerateUntilFinished(batch, response_tx)) - .unwrap(); - response_rx.recv().await.unwrap() - } - - pub async fn generate_until_finished_with_cache( - &self, - batches: Vec, - ) -> Result<(Vec, Option)> { - let (response_tx, mut response_rx) = mpsc::channel(1); - self.request_tx - .send(Command::GenerateUntilFinishedWithCache( - batches, - response_tx, - )) - .unwrap(); - response_rx.recv().await.unwrap() - } - + /// Clear the past generations cache pub async fn clear_cache(&self) -> Result<()> { + // Create a channel to receive the response from the shards + // We will only ever receive one message on this channel let (response_tx, mut response_rx) = mpsc::channel(1); self.request_tx .send(Command::ClearCache(response_tx)) diff --git a/router/src/batcher.rs b/router/src/batcher.rs index ebd817300..4523dbfff 100644 --- 
a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -1,129 +1,158 @@ -use crate::server::GenerateRequest; +/// Batching and inference logic +use crate::GenerateRequest; use crate::{Db, Entry}; use axum::http::StatusCode; use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient}; use std::future::Future; use std::sync::Arc; +use std::time::Duration; use thiserror::Error; use tokio::sync::{oneshot, Notify}; +use tokio::time::Instant; +use tracing::instrument; -const MAX_LENGTH: usize = 128; - -#[derive(Debug, Error)] -pub enum InferError { - #[error("Request failed during generation: {0}")] - GenerationError(String), - #[error("Model is overloaded")] - Overloaded, -} - -impl From for (StatusCode, String) { - fn from(err: InferError) -> Self { - match err { - InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()), - InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()), - } - } -} - +/// Batcher #[derive(Clone)] pub struct Batcher { + /// Request database db: Db, + /// Shared state shared: Arc, } +/// Batcher shared state struct Shared { + /// Batching background Tokio task notifier batching_task: Notify, } impl Batcher { - pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self { + pub(crate) fn new( + client: ShardedClient, + max_batch_size: usize, + max_waiting_time: Duration, + ) -> Self { + // Batcher shared state let db = Db::new(); let shared = Arc::new(Shared { batching_task: Notify::new(), }); - tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone())); + // Spawn batching background task that contains all the inference logic + tokio::spawn(batching_task( + max_batch_size, + max_waiting_time, + client, + db.clone(), + shared.clone(), + )); Self { db, shared } } + /// Add a new request to the database and return a future that will generate the text pub(crate) async fn infer( &self, input_length: usize, request: GenerateRequest, ) -> Result { - if self.db.len() > MAX_LENGTH { - return Err(InferError::Overloaded); - } - let (request_tx, request_rx) = oneshot::channel(); + // One shot channel to communicate with the background batching task + let (response_tx, response_rx) = oneshot::channel(); + + // Try to append the request to the database self.db.append(Entry { request, - response_tx: request_tx, + response_tx, input_length, + time: Instant::now(), }); + + // Notify the background task that we have a new entry in the database that needs + // to be batched self.shared.batching_task.notify_waiters(); - match request_rx.await.unwrap() { + + // Await on the response from the background task + // We can safely unwrap as the background task will never drop the sender + match response_rx.await.unwrap() { Ok(output) => Ok(output), Err(err) => Err(InferError::GenerationError(err.to_string())), } } } -async fn batching_task(max_batch_size: usize, - client: ShardedClient, - db: Db, - shared: Arc) { +/// Batching logic +/// Will be launched in a background Tokio task +/// +/// Batches requests and sends them to the inference server +#[instrument(skip(client, db, shared))] +async fn batching_task( + max_batch_size: usize, + max_waiting_time: Duration, + client: ShardedClient, + db: Db, + shared: Arc, +) { + // Minimum batch size after which we try to add more requests let limit_min_batch_size = (max_batch_size / 2) as u32; + // Infinite loop loop { + // Wait for a notification from the Batcher struct shared.batching_task.notified().await; - if let Some(batch) = 
db.next_batch(max_batch_size) { - let request_ids = batch.requests.iter().map(|req| req.id).collect(); - let mut cached_batch = match batch.size { - size if size > limit_min_batch_size => { - wrap_future(client.generate_until_finished(batch), request_ids, &db).await - } - _ => wrap_future(client.generate(batch), request_ids, &db).await, - }; + // Get the next batch from the DB + // This batch might be smaller than the maximum batch size if there are not enough requests + // waiting in the DB + if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) { + let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await; + // We loop until we do not receive any cached batch from the inference server (== until + // all requests have met their stopping criteria) while let Some(batch) = cached_batch { - let mut current_batch_size = batch.size; + // Get current batch info + let batch_size = batch.size; let mut request_ids: Vec = batch.requests.iter().map(|req| req.id).collect(); let mut batches = vec![batch]; - if current_batch_size <= limit_min_batch_size { - if let Some(new_batch) = db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size) { - let new_batch_request_ids = - new_batch.requests.iter().map(|req| req.id).collect(); + // If the current batch is too small, we try to add more requests to it + if batch_size <= limit_min_batch_size { + // Get the next batch from the DB that meet our minimum size criteria + if let Some((new_request_ids, new_batch)) = + db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None) + { + // Generate one token for this new batch to have the attention past in cache let new_cached_batch = - wrap_future(client.generate(new_batch), new_batch_request_ids, &db) - .await; + wrap_future(client.generate(new_batch), new_request_ids, &db).await; + // Extend current batch with the new batch + if let Some(new_cached_batch) = new_cached_batch { + request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id)); + batches.push(new_cached_batch); + } + } + // If we don't have enough requests to meet the minimum size criteria, we + // try to get the next batch from the DB that have been waiting over + // the max_waiting_time + else if let Some((new_request_ids, new_batch)) = + db.next_batch(None, max_batch_size, Some(max_waiting_time)) + { + let new_cached_batch = + wrap_future(client.generate(new_batch), new_request_ids, &db).await; + // Extend current batch with the new batch if let Some(new_cached_batch) = new_cached_batch { - current_batch_size += new_cached_batch.size; request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id)); batches.push(new_cached_batch); } } } - cached_batch = match current_batch_size { - size if size > limit_min_batch_size => { - wrap_future( - client.generate_until_finished_with_cache(batches), - request_ids, - &db, - ) - .await - } - _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await, - }; + cached_batch = + wrap_future(client.generate_with_cache(batches), request_ids, &db).await; } } } } +/// Wrap a future inside a match statement to handle errors and send the response to the Batcher async fn wrap_future( future: impl Future, Option), ClientError>>, request_ids: Vec, @@ -134,6 +163,7 @@ async fn wrap_future( send_generated(generated_texts, db); next_batch } + // If we have an error, we discard the whole batch Err(err) => { send_error(err, request_ids, db); None @@ -141,16 +171,20 @@ async fn wrap_future( } } +/// Send errors to the 
Batcher for all `request_ids` fn send_error(error: ClientError, request_ids: Vec, db: &Db) { request_ids.into_iter().for_each(|id| { + // We can `expect` here as the request id should always be in the DB let entry = db.remove(&id).expect("ID not found in db. This is a bug."); // unwrap_or is valid here as we don't care if the receiver is gone. entry.response_tx.send(Err(error.clone())).unwrap_or(()); }); } +/// Send `generated_text` to the Batcher for all `finished` fn send_generated(finished: Vec, db: &Db) { finished.into_iter().for_each(|output| { + // We can `expect` here as the request id should always be in the DB let entry = db .remove(&output.request.unwrap().id) .expect("ID not found in db. This is a bug."); @@ -158,3 +192,18 @@ fn send_generated(finished: Vec, db: &Db) { entry.response_tx.send(Ok(output.output)).unwrap_or(()); }); } + +#[derive(Debug, Error)] +pub enum InferError { + #[error("Request failed during generation: {0}")] + GenerationError(String), +} + +/// Convert to Axum supported format +impl From for (StatusCode, String) { + fn from(err: InferError) -> Self { + match err { + InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()), + } + } +} diff --git a/router/src/db.rs b/router/src/db.rs index 03593fc07..9518fa1de 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -1,16 +1,173 @@ /// This code is massively inspired by Tokio mini-redis -use crate::server::{GenerateParameters, GenerateRequest}; +use crate::{GenerateParameters, GenerateRequest}; use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request}; -use parking_lot::RwLock; +use parking_lot::Mutex; use std::collections::BTreeMap; use std::sync::Arc; +use std::time::Duration; use tokio::sync::oneshot::Sender; +use tokio::time::Instant; +/// Database entry #[derive(Debug)] pub(crate) struct Entry { + /// Request pub request: GenerateRequest, + /// Response sender to communicate between the Batcher and the batching_task pub response_tx: Sender>, + /// Number of tokens in the input pub input_length: usize, + /// Instant when this entry was created + pub time: Instant, +} + +/// Request Database +#[derive(Debug, Clone)] +pub(crate) struct Db { + pub shared: Arc, +} + +/// Shared state +#[derive(Debug)] +pub struct Shared { + state: Mutex, +} + +/// Database State +#[derive(Debug)] +struct State { + /// Database entries organized in a BTreeMap to be able to iterate over them in order + entries: BTreeMap, + + /// Id of the next entry + next_id: u64, + + /// Id of the next batch + next_batch_id: u64, + + /// Start ID of the next batch. Used to iterate inside the entries BTreeMap + next_batch_start_id: u64, +} + +impl State { + /// Get the next requests + fn next_requests( + &self, + max_size: usize, + min_waiting_time: Option, + ) -> Option<(Vec, Vec)> { + // Iterates for max_size over the BTreemap starting from next_batch_start_id + let mut requests = Vec::new(); + let mut ids = Vec::new(); + + for (id, entry) in self + .entries + // Start from next_batch_start_id + .range(self.next_batch_start_id..) 
+ // Take max_size + .take(max_size) + { + if let Some(min_waiting_time) = min_waiting_time { + // Only take entries that waited for at least min_waiting_time + if entry.time.elapsed() < min_waiting_time { + // Since entries are ordered, we already know that all following entries won't + // satisfy the condition + break; + } + } + + requests.push(Request { + id: *id, + inputs: entry.request.inputs.clone(), + input_length: entry.input_length as u32, + parameters: Some(LogitsWarperParameters::from( + entry.request.parameters.clone(), + )), + max_new_tokens: entry.request.parameters.max_new_tokens, + }); + + ids.push(*id); + } + + if requests.is_empty() { + None + } else { + Some((ids, requests)) + } + } +} + +impl Db { + pub(crate) fn new() -> Self { + // Shared state + let shared = Arc::new(Shared { + state: Mutex::new(State { + entries: BTreeMap::new(), + next_id: 0, + next_batch_id: 0, + next_batch_start_id: 0, + }), + }); + + Self { shared } + } + + /// Append an entry to the database + pub(crate) fn append(&self, entry: Entry) { + // Acquire lock + let mut state = self.shared.state.lock(); + + // Insert entry + let id = state.next_id; + state.next_id += 1; + state.entries.insert(id, entry); + } + + /// Remove an entry from the database if it exists + pub(crate) fn remove(&self, id: &u64) -> Option { + let mut state = self.shared.state.lock(); + state.entries.remove(id) + } + + // Get the next batch + pub(crate) fn next_batch( + &self, + min_size: Option, + max_size: usize, + min_waiting_time: Option, + ) -> Option<(Vec, Batch)> { + // Acquire lock + let mut state = self.shared.state.lock(); + + // Get requests from the database + if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) { + if let Some(min_size) = min_size { + // If min_size is set, only return a batch if there are enough requests + if requests.len() < min_size { + return None; + } + } + + // Batch size + let size = requests.len(); + // Longest input length for all requests in batch size + // Used for padding inside the inference server + let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap(); + let batch = Batch { + id: state.next_batch_id, + requests, + size: size as u32, + max_sequence_length, + }; + // Update next_batch_start_id to the last id in the batch + 1 + state.next_batch_start_id = ids.last().unwrap() + 1; + // Increment batch id + state.next_batch_id += 1; + + return Some((ids, batch)); + } + None + } } impl From for LogitsWarperParameters { @@ -23,129 +180,3 @@ impl From for LogitsWarperParameters { } } } - -#[derive(Debug, Clone)] -pub(crate) struct Db { - pub shared: Arc, -} - -#[derive(Debug)] -pub struct Shared { - state: RwLock, -} - -#[derive(Debug)] -struct State { - entries: BTreeMap, - - /// Identifier to use for the next expiration. Each expiration is associated - /// with a unique identifier. See above for why. 
- next_id: u64, - - next_batch_id: u64, - - /// Current batch id - next_batch_start_id: u64, -} - -impl Db { - pub(crate) fn new() -> Self { - let shared = Arc::new(Shared { - state: RwLock::new(State { - entries: BTreeMap::new(), - next_id: 0, - next_batch_id: 0, - next_batch_start_id: 0, - }), - }); - - Self { shared } - } - - pub(crate) fn append(&self, entry: Entry) { - let mut state = self.shared.state.write(); - - let id = state.next_id; - state.next_id += 1; - - state.entries.insert(id, entry); - } - - pub(crate) fn remove(&self, id: &u64) -> Option { - let mut state = self.shared.state.write(); - state.entries.remove(id) - } - - pub(crate) fn len(&self) -> usize { - let state = self.shared.state.read(); - state.entries.len() - } - - fn next_requests(&self, max_size: usize) -> Option<(u64, Vec)> { - let state = self.shared.state.read(); - - let requests: Vec = state - .entries - .range(state.next_batch_start_id..) - .take(max_size) - .map(|(id, entry)| Request { - id: *id, - inputs: entry.request.inputs.clone(), - input_length: entry.input_length as u32, - parameters: Some(LogitsWarperParameters::from( - entry.request.parameters.clone(), - )), - max_new_tokens: entry.request.parameters.max_new_tokens, - }) - .collect(); - - if requests.is_empty() { - None - } else { - let last_id = requests.last().unwrap().id; - Some((last_id, requests)) - } - } - - pub(crate) fn next_batch(&self, max_size: usize) -> Option { - if let Some((last_id, requests)) = self.next_requests(max_size) { - let mut state = self.shared.state.write(); - let size = requests.len(); - let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap(); - let batch = Batch { - id: state.next_batch_id, - requests, - size: size as u32, - max_sequence_length, - }; - state.next_batch_start_id = last_id + 1; - state.next_batch_id += 1; - return Some(batch); - } - None - } - - pub(crate) fn next_batch_minimum_size( - &self, - min_size: usize, - max_size: usize, - ) -> Option { - if let Some((last_id, requests)) = self.next_requests(max_size) { - if requests.len() >= min_size { - let mut state = self.shared.state.write(); - let size = requests.len(); - let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap(); - let batch = Batch { - id: state.next_batch_id, - requests, - size: size as u32, - max_sequence_length, - }; - state.next_batch_start_id = last_id + 1; - state.next_batch_id += 1; - return Some(batch); - } - } - None - } -} diff --git a/router/src/lib.rs b/router/src/lib.rs index 14dc57249..02b912a3e 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1,8 +1,68 @@ +/// Text Generation Inference Webserver mod batcher; mod db; -mod validation; pub mod server; +mod validation; -use db::{Db, Entry}; use batcher::Batcher; +use db::{Db, Entry}; +use serde::{Deserialize, Serialize}; use validation::Validation; + +#[derive(Clone, Debug, Deserialize)] +pub(crate) struct GenerateParameters { + #[serde(default = "default_temperature")] + pub temperature: f32, + #[serde(default = "default_top_k")] + pub top_k: i32, + #[serde(default = "default_top_p")] + pub top_p: f32, + #[serde(default = "default_do_sample")] + pub do_sample: bool, + #[serde(default = "default_max_new_tokens")] + pub max_new_tokens: u32, +} + +fn default_temperature() -> f32 { + 1.0 +} + +fn default_top_k() -> i32 { + 0 +} + +fn default_top_p() -> f32 { + 1.0 +} + +fn default_do_sample() -> bool { + false +} + +fn default_max_new_tokens() -> u32 { + 20 +} + +fn default_parameters() -> GenerateParameters { + 
GenerateParameters {
+        temperature: default_temperature(),
+        top_k: default_top_k(),
+        top_p: default_top_p(),
+        do_sample: default_do_sample(),
+        max_new_tokens: default_max_new_tokens(),
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub(crate) struct GenerateRequest {
+    pub inputs: String,
+    #[serde(default = "default_parameters")]
+    pub parameters: GenerateParameters,
+}
+
+#[derive(Serialize)]
+pub(crate) struct GeneratedText {
+    pub generated_text: String,
+}
+
+pub(crate) type GenerateResponse = Vec<GeneratedText>;
diff --git a/router/src/main.rs b/router/src/main.rs
index 89cd47313..49051b376 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -1,37 +1,61 @@
+/// Text Generation Inference webserver entrypoint
 use bloom_inference_client::ShardedClient;
+use clap::Parser;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::time::Duration;
 use text_generation_router::server;
 use tokenizers::Tokenizer;
-use clap::Parser;
 
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
-    #[clap(default_value = "32", long, short, env)]
+    #[clap(default_value = "128", long, env)]
+    max_concurrent_requests: usize,
+    #[clap(default_value = "1000", long, env)]
+    max_input_length: usize,
+    #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
+    #[clap(default_value = "5", long, env)]
+    max_waiting_time: u64,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
-    shard_uds_path: String,
+    master_shard_uds_path: String,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
+    #[clap(default_value = "2", long, env)]
+    validation_workers: usize,
 }
 
 fn main() -> Result<(), std::io::Error> {
     // Get args
     let args = Args::parse();
-// Pattern match configuration
+    // Pattern match configuration
     let Args {
+        max_concurrent_requests,
+        max_input_length,
         max_batch_size,
+        max_waiting_time,
         port,
-        shard_uds_path,
+        master_shard_uds_path,
         tokenizer_name,
+        validation_workers,
     } = args;
 
+    if validation_workers == 0 {
+        panic!("validation_workers must be > 0");
+    }
+    let max_waiting_time = Duration::from_secs(max_waiting_time);
+
+    // Download and instantiate tokenizer
+    // This will only be used to validate payloads
+    //
+    // We need to download it outside of the Tokio runtime
     let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
 
+    // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
         .build()
@@ -39,18 +63,32 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             tracing_subscriber::fmt::init();
 
-            let sharded_client = ShardedClient::connect_uds(shard_uds_path)
+            // Instantiate sharded client from the master unix socket
+            let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
                 .expect("Could not connect to server");
+            // Clear the cache; useful if the webserver rebooted
             sharded_client
                 .clear_cache()
                 .await
                 .expect("Unable to clear cache");
             tracing::info!("Connected");
 
+            // Binds on 0.0.0.0
             let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
 
-            server::run(max_batch_size, sharded_client, tokenizer, addr).await;
+            // Run server
+            server::run(
+                max_concurrent_requests,
+                max_input_length,
+                max_batch_size,
+                max_waiting_time,
+                sharded_client,
+                tokenizer,
+                validation_workers,
+                addr,
+            )
+            .await;
             Ok(())
         })
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 0fdfd58b9..422583135 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,68
+1,44 @@ -use crate::{Batcher, Validation}; +use crate::{ + Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation, +}; use axum::extract::Extension; use axum::http::StatusCode; use axum::routing::{get, post}; use axum::{Json, Router}; use bloom_inference_client::ShardedClient; -use serde::Deserialize; use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; use tokenizers::Tokenizer; +use tokio::signal; +use tokio::sync::Semaphore; use tokio::time::Instant; use tracing::instrument; -#[derive(Clone, Debug, Deserialize)] -pub(crate) struct GenerateParameters { - #[serde(default = "default_temperature")] - pub temperature: f32, - #[serde(default = "default_top_k")] - pub top_k: i32, - #[serde(default = "default_top_p")] - pub top_p: f32, - #[serde(default = "default_do_sample")] - pub do_sample: bool, - #[serde(default = "default_max_new_tokens")] - pub max_new_tokens: u32, -} - -fn default_temperature() -> f32 { - 1.0 -} - -fn default_top_k() -> i32 { - 0 -} - -fn default_top_p() -> f32 { - 1.0 -} - -fn default_do_sample() -> bool { - false -} - -fn default_max_new_tokens() -> u32 { - 20 -} - -fn default_parameters() -> GenerateParameters { - GenerateParameters { - temperature: default_temperature(), - top_k: default_top_k(), - top_p: default_top_p(), - do_sample: default_do_sample(), - max_new_tokens: default_max_new_tokens(), - } -} - -#[derive(Clone, Debug, Deserialize)] -pub(crate) struct GenerateRequest { - pub inputs: String, - #[serde(default = "default_parameters")] - pub parameters: GenerateParameters, +// Server shared state +#[derive(Clone)] +struct ServerState { + validation: Validation, + batcher: Batcher, + limit_concurrent_requests: Arc, } +/// Health check method #[instrument(skip(state), fields(time, time_per_token))] -async fn liveness(state: Extension) -> Result<(), (StatusCode, String)> { +async fn health(state: Extension) -> Result<(), (StatusCode, String)> { + // TODO: while this is the best health check we can do, it is a bit on the heavy side and might + // be a bit too slow for a health check. + // What we should do instead if check if the gRPC channels are still healthy. + + // Limit concurrent requests by acquiring a permit from the semaphore + let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| { + ( + StatusCode::TOO_MANY_REQUESTS, + "Model is overloaded".to_string(), + ) + })?; + + // Send a small inference request state .batcher .infer( @@ -82,23 +58,35 @@ async fn liveness(state: Extension) -> Result<(), (StatusCode, Stri Ok(()) } +/// Generate method #[instrument(skip(state), fields(time, time_per_token))] async fn generate( state: Extension, req: Json, -) -> Result, (StatusCode, String)> { +) -> Result, (StatusCode, String)> { let start = Instant::now(); + // Limit concurrent requests by acquiring a permit from the semaphore + let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| { + ( + StatusCode::TOO_MANY_REQUESTS, + "Model is overloaded".to_string(), + ) + })?; + // Validate request let (input_length, validated_request) = state .validation + // FIXME: can't we get rid of the cloning here?? 
.validate(GenerateRequest { inputs: req.inputs.clone(), parameters: req.parameters.clone(), }) .await?; + // Inference let generated_text = state.batcher.infer(input_length, validated_request).await?; + // Tracing metadata tracing::Span::current().record("time", format!("{:?}", start.elapsed())); tracing::Span::current().record( "time_per_token", @@ -106,31 +94,71 @@ async fn generate( ); tracing::info!("response: {}", generated_text); - Ok(Json(serde_json::json!({ - "generated_text": generated_text, - }))) + // Send response + let response = vec![GeneratedText { generated_text }]; + Ok(Json(response)) } -#[derive(Clone)] -struct ServerState { - validation: Validation, - batcher: Batcher, -} - -pub async fn run(max_batch_size: usize, client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) { - let batcher = Batcher::new(client, max_batch_size); - let validation = Validation::new(tokenizer); - - let shared_state = ServerState { validation, batcher }; +/// Serving method +#[allow(clippy::too_many_arguments)] +pub async fn run( + max_concurrent_requests: usize, + max_input_length: usize, + max_batch_size: usize, + max_waiting_time: Duration, + client: ShardedClient, + tokenizer: Tokenizer, + validation_workers: usize, + addr: SocketAddr, +) { + // Create state + let batcher = Batcher::new(client, max_batch_size, max_waiting_time); + let validation = Validation::new(validation_workers, tokenizer, max_input_length); + let shared_state = ServerState { + validation, + batcher, + limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)), + }; + // Create router let app = Router::new() .route("/generate", post(generate)) .layer(Extension(shared_state.clone())) - .route("/health", get(liveness)) + .route("/health", get(health)) .layer(Extension(shared_state.clone())); + // Run server axum::Server::bind(&addr) .serve(app.into_make_service()) + // Wait until all requests are finished to shut down + .with_graceful_shutdown(shutdown_signal()) .await .unwrap(); } + +/// Shutdown signal handler +async fn shutdown_signal() { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } + + tracing::info!("signal received, starting graceful shutdown"); +} diff --git a/router/src/validation.rs b/router/src/validation.rs index 45b108fd4..49a46b624 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,62 +1,105 @@ -use crate::server::GenerateRequest; +/// Payload validation logic +use crate::GenerateRequest; use axum::http::StatusCode; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; +use tokenizers::{ + DecoderWrapper, ModelWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper, + TokenizerImpl, +}; use tokio::sync::{mpsc, oneshot}; -#[derive(Error, Debug)] -pub enum ValidationError { - #[error("Temperature must be strictly positive")] - Temperature, - #[error("Top p must be <= 0.0 or > 1.0")] - TopP, - #[error("Top k must be strictly positive")] - TopK, - #[error("Max New Tokens must be < 512")] - MaxNewTokens, - #[error("Inputs must have less than 1000 tokens. 
Given: {0}")] - InputLength(usize), -} - -impl From for (StatusCode, String) { - fn from(err: ValidationError) -> Self { - (StatusCode::BAD_REQUEST, err.to_string()) - } -} - -type ValidationRequest = ( - GenerateRequest, - oneshot::Sender>, -); - +/// Validation #[derive(Debug, Clone)] pub struct Validation { + /// Channel to communicate with the background validation task sender: mpsc::Sender, } impl Validation { - pub(crate) fn new(tokenizer: Tokenizer) -> Self { + pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self { + // Crate channel let (validation_sender, validation_receiver) = mpsc::channel(128); - tokio::spawn(validation_task(tokenizer, validation_receiver)); + // Launch background validation task + tokio::spawn(validation_task( + workers, + tokenizer, + max_input_length, + validation_receiver, + )); Self { sender: validation_sender, } } + /// Validate a payload and get the number of tokens in the input pub(crate) async fn validate( &self, request: GenerateRequest, ) -> Result<(usize, GenerateRequest), ValidationError> { + // Create response channel let (sender, receiver) = oneshot::channel(); + // Send request to the background validation task + // Unwrap is safe here self.sender.send((request, sender)).await.unwrap(); + // Await on response channel + // Unwrap is safe here receiver.await.unwrap() } } -async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver) { - while let Some((request, response_tx)) = receiver.recv().await { +/// Validation task +/// Load balance the validation requests between multiple validation workers +async fn validation_task( + workers: usize, + tokenizer: Tokenizer, + max_input_length: usize, + mut receiver: mpsc::Receiver, +) { + let mut workers_senders = Vec::with_capacity(workers); + + // Create workers + for _ in 0..workers { + let tokenizer_clone = tokenizer.clone(); + // Create channel to communicate with worker + let (worker_sender, worker_receiver) = mpsc::channel(workers); + workers_senders.push(worker_sender); + + // Spawn worker + tokio::task::spawn_blocking(move || { + validation_worker(tokenizer_clone, max_input_length, worker_receiver) + }); + } + + loop { + // Load balance requests between workers + for sender in workers_senders.iter() { + if let Some(validation_request) = receiver.recv().await { + sender.send(validation_request).await.unwrap(); + } else { + return; + } + } + } +} + +/// Check the parameters inside the payload and get the number of tokens inside the input using +/// the tokenizer +fn validation_worker( + tokenizer: TokenizerImpl< + ModelWrapper, + NormalizerWrapper, + PreTokenizerWrapper, + PostProcessorWrapper, + DecoderWrapper, + >, + max_input_length: usize, + mut receiver: mpsc::Receiver, +) { + // Loop over requests + while let Some((request, response_tx)) = receiver.blocking_recv() { if request.parameters.temperature < 0.0 { response_tx .send(Err(ValidationError::Temperature)) @@ -78,10 +121,11 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver 1000 { + if input_length > max_input_length { response_tx .send(Err(ValidationError::InputLength(input_length))) .unwrap_or(()); @@ -91,3 +135,28 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver>, +); + +#[derive(Error, Debug)] +pub enum ValidationError { + #[error("Temperature must be strictly positive")] + Temperature, + #[error("Top p must be <= 0.0 or > 1.0")] + TopP, + #[error("Top k must be strictly positive")] + TopK, + #[error("Max New Tokens must be < 
512")] + MaxNewTokens, + #[error("Inputs must have less than 1000 tokens. Given: {0}")] + InputLength(usize), +} + +impl From for (StatusCode, String) { + fn from(err: ValidationError) -> Self { + (StatusCode::BAD_REQUEST, err.to_string()) + } +} diff --git a/run.sh b/run.sh deleted file mode 100644 index 303035017..000000000 --- a/run.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash - -server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH" - -# Run in background -$server_cmd 2>&1 > /dev/null & - -# Check if server is running by checking if the unix socket is created -FILE=/tmp/bloom-inference-0 -while : - do - if test -S "$FILE"; then - echo "Text Generation Python gRPC server started" - break - else - echo "Waiting for Text Generation Python gRPC server to start" - sleep 5 - fi - done - -sleep 1 - -# Run in background -text-generation-router & - -# Wait for any process to exit -wait -n - -# Exit with status of process that exited first -exit $? \ No newline at end of file diff --git a/router/rust-toolchain.toml b/rust-toolchain.toml similarity index 100% rename from router/rust-toolchain.toml rename to rust-toolchain.toml diff --git a/server/.gitignore b/server/.gitignore new file mode 100644 index 000000000..0ebf3670b --- /dev/null +++ b/server/.gitignore @@ -0,0 +1,155 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +bloom_inference/__pycache__/ +bloom_inference/pb/__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
+#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ diff --git a/server/Makefile b/server/Makefile index d959dbb38..52b4d4059 100644 --- a/server/Makefile +++ b/server/Makefile @@ -4,17 +4,28 @@ gen-server: find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; touch bloom_inference/pb/__init__.py -unit-tests: - python -m pytest --cov=bloom_inference tests +install-transformers: + # Install specific version of transformers + rm transformers || true + wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip + unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip + rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip + mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers + cd transformers && python setup.py install -unit-tests-reporting: - python -m pytest --junitxml=report.xml --cov=bloom_inference tests +install-torch: + # Install specific version of torch + pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir pip-install: pip install grpcio-tools make gen-server + make install-torch + make install-transformers pip install . 
install: poetry install - make gen-server \ No newline at end of file + make gen-server + make install-torch + make install-transformers diff --git a/server/bloom_inference/cli.py b/server/bloom_inference/cli.py index a5f84e77f..751485bb0 100644 --- a/server/bloom_inference/cli.py +++ b/server/bloom_inference/cli.py @@ -1,41 +1,51 @@ +import os import typer from pathlib import Path -from torch.distributed.launcher import launch_agent, LaunchConfig from typing import Optional -from bloom_inference import server +from bloom_inference import prepare_weights, server app = typer.Typer() @app.command() -def launcher( - model_name: str, - num_gpus: int = 1, - shard_directory: Optional[Path] = None, +def serve( + model_name: str, + sharded: bool = False, + shard_directory: Optional[Path] = None, + uds_path: Path = "/tmp/bloom-inference", ): - if num_gpus == 1: - serve(model_name, False, shard_directory) + if sharded: + assert ( + shard_directory is not None + ), "shard_directory must be set when sharded is True" + assert ( + os.getenv("RANK", None) is not None + ), "RANK must be set when sharded is True" + assert ( + os.getenv("WORLD_SIZE", None) is not None + ), "WORLD_SIZE must be set when sharded is True" + assert ( + os.getenv("MASTER_ADDR", None) is not None + ), "MASTER_ADDR must be set when sharded is True" + assert ( + os.getenv("MASTER_PORT", None) is not None + ), "MASTER_PORT must be set when sharded is True" - else: - config = LaunchConfig( - min_nodes=1, - max_nodes=1, - nproc_per_node=num_gpus, - rdzv_backend="c10d", - max_restarts=0, - ) - launch_agent(config, server.serve, [model_name, True, shard_directory]) + server.serve(model_name, sharded, uds_path, shard_directory) @app.command() -def serve( - model_name: str, - sharded: bool = False, - shard_directory: Optional[Path] = None, +def prepare_weights( + model_name: str, + shard_directory: Path, + cache_directory: Path, + num_shard: int = 1, ): - server.serve(model_name, sharded, shard_directory) + prepare_weights.prepare_weights( + model_name, cache_directory, shard_directory, num_shard + ) if __name__ == "__main__": diff --git a/server/bloom_inference/model.py b/server/bloom_inference/model.py index 8b0e7ab0b..0ba90cee3 100644 --- a/server/bloom_inference/model.py +++ b/server/bloom_inference/model.py @@ -24,6 +24,7 @@ torch.manual_seed(0) class Batch: batch_id: int requests: List[generate_pb2.Request] + all_input_lengths: List[int] input_ids: Dict[str, torch.Tensor] all_input_ids: List[torch.Tensor] next_token_choosers: List[NextTokenChooser] @@ -46,12 +47,12 @@ class Batch: inputs = [] next_token_choosers = [] stopping_criterias = [] - input_lengths = [] + all_input_lengths = [] # Parse batch for r in pb.requests: inputs.append(r.inputs) - input_lengths.append(r.input_length) + all_input_lengths.append(r.input_length) next_token_choosers.append( NextTokenChooser( temperature=r.parameters.temperature, @@ -63,17 +64,12 @@ class Batch: stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens)) input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device) - # Remove padding from all_input_ids - all_input_ids = [ - input_ids.squeeze(0)[-length:].unsqueeze(-1) - for length, input_ids in zip( - input_lengths, input_ids["input_ids"].split(1, dim=0) - ) - ] + all_input_ids = input_ids["input_ids"].unsqueeze(-1) return cls( batch_id=pb.id, requests=pb.requests, + all_input_lengths=all_input_lengths, input_ids=input_ids, all_input_ids=all_input_ids, next_token_choosers=next_token_choosers, @@ -91,6 +87,7 
@@ class Batch: # Batch attributes input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} requests = [] + all_input_lengths = [] all_input_ids = [] next_token_choosers = [] stopping_criterias = [] @@ -100,6 +97,7 @@ class Batch: start_index = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) + all_input_lengths.extend(batch.all_input_lengths) all_input_ids.extend(batch.all_input_ids) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -198,6 +196,7 @@ class Batch: return cls( batch_id=batches[0].batch_id, requests=requests, + all_input_lengths=all_input_lengths, input_ids=input_ids, all_input_ids=all_input_ids, next_token_choosers=next_token_choosers, @@ -227,7 +226,10 @@ class BLOOM: self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.model = ( - AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device).to(dtype) + AutoModelForCausalLM.from_pretrained(model_name) + .eval() + .to(self.device) + .to(dtype) ) self.num_heads = self.model.base_model.num_heads @@ -253,6 +255,7 @@ class BLOOM: # New input_ids for next forward next_batch_input_ids = [] next_batch_all_input_ids = [] + next_all_input_lengths = [] next_batch_size = 0 next_batch_max_sequence_length = 0 @@ -263,6 +266,7 @@ class BLOOM: # Zipped iterator iterator = zip( batch.requests, + batch.all_input_lengths, outputs.logits, batch.next_token_choosers, batch.stopping_criterias, @@ -272,6 +276,7 @@ class BLOOM: # For each member of the batch for i, ( request, + input_length, logits, next_token_chooser, stopping_criteria, @@ -302,8 +307,10 @@ class BLOOM: next_batch_input_ids.append(next_token) next_batch_all_input_ids.append(all_tokens) next_batch_size += 1 + new_input_length = input_length + 1 + next_all_input_lengths.append(new_input_length) next_batch_max_sequence_length = max( - next_batch_max_sequence_length, len(all_tokens) + next_batch_max_sequence_length, new_input_length ) # We finished all generations in the batch; there is no next batch @@ -350,6 +357,7 @@ class BLOOM: next_batch = Batch( batch_id=batch.batch_id, requests=next_batch_requests, + all_input_lengths=next_all_input_lengths, input_ids=next_batch_input_ids, all_input_ids=next_batch_all_input_ids, next_token_choosers=next_batch_next_token_choosers, @@ -378,7 +386,10 @@ class BLOOMSharded(BLOOM): if self.master: # TODO @thomasw21 do some caching shard_state_dict_paths = prepare_weights( - model_name, shard_directory / "cache", shard_directory, tp_world_size=self.world_size + model_name, + shard_directory / "cache", + shard_directory, + tp_world_size=self.world_size, ) shard_state_dict_paths = [ str(path.absolute()) for path in shard_state_dict_paths @@ -443,6 +454,7 @@ class BLOOMSharded(BLOOM): use_cache=True, ) + # Logits are sharded, so we need to gather them logits_shard = outputs.logits[:, -1, :].contiguous() batch_size, vocab_shard_size = logits_shard.shape diff --git a/server/bloom_inference/pb/.gitignore b/server/bloom_inference/pb/.gitignore index a9feac816..8527ad136 100644 --- a/server/bloom_inference/pb/.gitignore +++ b/server/bloom_inference/pb/.gitignore @@ -1,2 +1,2 @@ *.py -*.py-e +*.py-e \ No newline at end of file diff --git a/server/bloom_inference/prepare_weights.py b/server/bloom_inference/prepare_weights.py index 7cf3dbb5c..5594998d0 100644 --- a/server/bloom_inference/prepare_weights.py +++ b/server/bloom_inference/prepare_weights.py @@ -14,15 +14,15 @@ from huggingface_hub.file_download import 
_request_wrapper, hf_raise_for_status def match_suffix(text, suffix): - return text[-len(suffix):] == suffix + return text[-len(suffix) :] == suffix def http_get( - url: str, - temp_file: BinaryIO, - *, - timeout=10.0, - max_retries=0, + url: str, + temp_file: BinaryIO, + *, + timeout=10.0, + max_retries=0, ): """ Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub. @@ -54,7 +54,9 @@ def cache_download_url(url: str, root_dir: Path): return filename -def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world_size: int): +def prepare_weights( + model_name: str, cache_path: Path, save_path: Path, tp_world_size: int +): save_paths = [ save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty" for tp_rank in range(tp_world_size) @@ -68,6 +70,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world if model_name == "bigscience/bloom-560m": url = hf_hub_url(model_name, filename="pytorch_model.bin") cache_download_url(url, cache_path) + elif model_name == "bigscience/bloom": url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json") index_path = cache_download_url(url, cache_path) @@ -75,10 +78,14 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world index = json.load(f) # Get unique file names - weight_files = list(set([filename for filename in index["weight_map"].values()])) + weight_files = list( + set([filename for filename in index["weight_map"].values()]) + ) urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files] - Parallel(n_jobs=5)(delayed(cache_download_url)(url, cache_path) for url in tqdm(urls)) + Parallel(n_jobs=5)( + delayed(cache_download_url)(url, cache_path) for url in tqdm(urls) + ) else: raise ValueError(f"Unknown model name: {model_name}") @@ -91,14 +98,14 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world for state_name in keys: state = state_dict[state_name] if any( - match_suffix(state_name, candidate) - for candidate in [ - "self_attention.query_key_value.weight", - "self_attention.query_key_value.bias", - "mlp.dense_h_to_4h.weight", - "mlp.dense_h_to_4h.bias", - "word_embeddings.weight", - ] + match_suffix(state_name, candidate) + for candidate in [ + "self_attention.query_key_value.weight", + "self_attention.query_key_value.bias", + "mlp.dense_h_to_4h.weight", + "mlp.dense_h_to_4h.bias", + "word_embeddings.weight", + ] ): output_size = state.shape[0] assert output_size % tp_world_size == 0 @@ -107,7 +114,9 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world assert len(sharded_weights) == tp_world_size for tp_rank, shard in enumerate(sharded_weights): - shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone() + shards_state_dicts[tp_rank][ + "transformer." 
+ state_name + ] = shard.detach().clone() elif match_suffix(state_name, "lm_head.weight"): output_size = state.shape[0] @@ -120,11 +129,11 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world shards_state_dicts[tp_rank][state_name] = shard.detach().clone() elif any( - match_suffix(state_name, candidate) - for candidate in [ - "self_attention.dense.weight", - "mlp.dense_4h_to_h.weight", - ] + match_suffix(state_name, candidate) + for candidate in [ + "self_attention.dense.weight", + "mlp.dense_4h_to_h.weight", + ] ): input_size = state.shape[1] assert input_size % tp_world_size == 0 @@ -132,23 +141,31 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world sharded_weights = torch.split(state, block_size, dim=1) assert len(sharded_weights) == tp_world_size for tp_rank, shard in enumerate(sharded_weights): - shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone() + shards_state_dicts[tp_rank][ + "transformer." + state_name + ] = shard.detach().clone() elif any( - match_suffix(state_name, candidate) - for candidate in [ - "self_attention.dense.bias", - "mlp.dense_4h_to_h.bias", - ] + match_suffix(state_name, candidate) + for candidate in [ + "self_attention.dense.bias", + "mlp.dense_4h_to_h.bias", + ] ): - shards_state_dicts[0]["transformer." + state_name] = state.detach().clone() + shards_state_dicts[0][ + "transformer." + state_name + ] = state.detach().clone() for tp_rank in range(1, tp_world_size): - shards_state_dicts[tp_rank]["transformer." + state_name] = torch.zeros_like(state) + shards_state_dicts[tp_rank][ + "transformer." + state_name + ] = torch.zeros_like(state) else: # We duplicate parameters across tp ranks for tp_rank in range(tp_world_size): - shards_state_dicts[tp_rank]["transformer." + state_name] = state.detach().clone() + shards_state_dicts[tp_rank][ + "transformer." 
+ state_name + ] = state.detach().clone() del state_dict[state_name] # delete key from state_dict del state # delete tensor @@ -156,7 +173,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world # we save state_dict for tp_rank, (save_path, shard_state_dict) in enumerate( - zip(save_paths, shards_state_dicts) + zip(save_paths, shards_state_dicts) ): save_paths.append(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) @@ -166,17 +183,3 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world torch.save(shard_state_dict, save_path) return save_paths - - -if __name__ == "__main__": - from argparse import ArgumentParser - - parser = ArgumentParser() - - parser.add_argument("--model-name", required=True, type=str) - parser.add_argument("--cache-path", required=True, type=str) - parser.add_argument("--save-path", required=True, type=str) - parser.add_argument("--world-size", required=True, type=int) - args = parser.parse_args() - - prepare_weights(args.model_name, Path(args.cache_path), Path(args.save_path), args.world_size) diff --git a/server/bloom_inference/server.py b/server/bloom_inference/server.py index e89706e0a..734aba439 100644 --- a/server/bloom_inference/server.py +++ b/server/bloom_inference/server.py @@ -64,70 +64,31 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): batch=next_batch.to_pb() if next_batch else None, ) - async def GenerateUntilFinished(self, request, context): - batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device) - generated_texts = [] - while not generated_texts: - generated_texts, next_batch = self.model.generate_token(batch) - batch = next_batch - self.cache.set(next_batch) - - return generate_pb2.GenerateUntilFinishedResponse( - generated_texts=[ - generated_text.to_pb() for generated_text in generated_texts - ], - batch=next_batch.to_pb() if next_batch else None, - ) - - async def GenerateUntilFinishedWithCache(self, request, context): - if len(request.batches) == 0: - raise ValueError("Must provide at least one batch") - - batches = [] - for batch_pb in request.batches: - batch = self.cache.pop(batch_pb.id) - if batch is None: - raise ValueError(f"Batch ID {batch_pb.id} not found in cache.") - batches.append(batch) - - if len(batches) > 1: - batch = Batch.concatenate(batches) - else: - batch = batches[0] - - generated_texts = [] - while not generated_texts: - generated_texts, next_batch = self.model.generate_token(batch) - batch = next_batch - self.cache.set(next_batch) - - return generate_pb2.GenerateUntilFinishedWithCacheResponse( - generated_texts=[ - generated_text.to_pb() for generated_text in generated_texts - ], - batch=next_batch.to_pb() if next_batch else None, - ) - - -def serve(model_name, sharded, shard_directory): +def serve( + model_name: str, + sharded: bool, + uds_path: Path, + shard_directory: Optional[Path] = None, +): async def serve_inner( model_name: str, sharded: bool = False, shard_directory: Optional[Path] = None, ): - unix_socket_template = "unix:///tmp/bloom-inference-{}" + unix_socket_template = "unix://{}-{}" if sharded: if shard_directory is None: raise ValueError("shard_directory must be set when sharded is True") model = BLOOMSharded(model_name, shard_directory) server_urls = [ - unix_socket_template.format(rank) for rank in range(model.world_size) + unix_socket_template.format(uds_path, rank) + for rank in range(model.world_size) ] - local_url = unix_socket_template.format(model.rank) + local_url = 
server_urls[model.rank] else: model = BLOOM(model_name) - local_url = unix_socket_template.format(0) + local_url = unix_socket_template.format(uds_path, 0) server_urls = [local_url] server = aio.server() @@ -142,6 +103,10 @@ def serve(model_name, sharded, shard_directory): server.add_insecure_port(local_url) await server.start() print("Server started at {}".format(local_url)) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except KeyboardInterrupt: + print("Signal received. Shutting down") + await server.stop(0) asyncio.run(serve_inner(model_name, sharded, shard_directory)) diff --git a/server/bloom_inference/utils.py b/server/bloom_inference/utils.py index fe2c913e8..c351806ab 100644 --- a/server/bloom_inference/utils.py +++ b/server/bloom_inference/utils.py @@ -82,7 +82,6 @@ def initialize_torch_distributed(): world_size=world_size, rank=rank, timeout=timedelta(seconds=60), - init_method="tcp://localhost:6000", ) return torch.distributed.distributed_c10d._get_default_group(), rank, world_size diff --git a/server/poetry.lock b/server/poetry.lock index 8100c2008..3feee60dc 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -205,7 +205,7 @@ python-versions = ">=3.7" [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "f3dc5b2420183f2e7e9257e372489409d7bd26d1dcc535fc2558ebca50c988c2" +content-hash = "a4eef5f52e8d046aa883082c865b0865047f611a3240b18250487d4b6e831496" [metadata.files] accelerate = [ diff --git a/server/pyproject.toml b/server/pyproject.toml index 1dd8ae277..3d38f512d 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -11,7 +11,6 @@ bloom-inference-server = 'bloom_inference.cli:app' python = "^3.9" protobuf = "^4.21.7" grpcio = "^1.49.1" -torch = "^1.12.1" typer = "^0.6.1" grpcio-reflection = "^1.49.1" accelerate = "^0.12.0"
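Usage sketch for the new router API, not taken from the patch itself: with the router running on the CLI defaults above (bound to 0.0.0.0:3000), the /generate route accepts a JSON body shaped like GenerateRequest, every GenerateParameters field falls back to the defaults in router/src/lib.rs when omitted, and the reply is a JSON array of GeneratedText objects. The host, port and prompt below are assumptions.

import json
from urllib.request import Request, urlopen

# Assumed router address: the CLI defaults bind port 3000 on 0.0.0.0.
URL = "http://127.0.0.1:3000/generate"

# Body mirrors GenerateRequest: `inputs` is required, `parameters` falls back to
# default_parameters() (temperature=1.0, top_k=0, top_p=1.0, do_sample=false,
# max_new_tokens=20) for any field left out.
payload = {
    "inputs": "Hello, I am a language model and",
    "parameters": {"do_sample": True, "top_k": 50, "max_new_tokens": 32},
}

request = Request(
    URL,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)

with urlopen(request) as response:
    # The router replies with a JSON array of GeneratedText objects,
    # e.g. [{"generated_text": "..."}].
    print(json.loads(response.read()))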
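Launch sketch for the reworked `serve` command, illustrative only: cli.py now asserts that RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are set for every shard process when sharding is enabled, and server.py binds one unix socket per rank under the given uds path, with the rank 0 socket matching the router's default master_shard_uds_path. The model name, shard directory, master address/port and option spellings (typer's usual underscore-to-dash conversion) are assumptions.

import os
import subprocess

# Assumptions: two GPUs on one host, weights already prepared in /data/shards,
# and the `bloom-inference-server` console script from pyproject.toml on PATH.
WORLD_SIZE = 2

processes = []
for rank in range(WORLD_SIZE):
    env = dict(
        os.environ,
        # cli.py asserts that all four variables are set when --sharded is used.
        RANK=str(rank),
        WORLD_SIZE=str(WORLD_SIZE),
        MASTER_ADDR="127.0.0.1",
        MASTER_PORT="29500",  # arbitrary free port, assumed
    )
    processes.append(
        subprocess.Popen(
            [
                "bloom-inference-server",
                "serve",
                "bigscience/bloom",
                "--sharded",
                "--shard-directory", "/data/shards",
                "--uds-path", "/tmp/bloom-inference",
            ],
            env=env,
        )
    )

# The router then connects to the rank 0 socket, /tmp/bloom-inference-0,
# which matches the --master-shard-uds-path default in router/src/main.rs.
for process in processes:
    process.wait()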
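The reformatted prepare_weights.py keeps the same tensor-parallel split rule: column-parallel tensors (query_key_value, dense_h_to_4h, word_embeddings, lm_head) are cut along dim 0, row-parallel weights (self_attention.dense, mlp.dense_4h_to_h) along dim 1, and the remaining parameters are duplicated or kept on rank 0 only. A minimal sketch of that split invariant follows; the tensor shapes are made up for illustration.

import torch

TP_WORLD_SIZE = 2

# Toy stand-ins for the two kinds of sharded parameters handled in prepare_weights():
# a column-parallel weight (split on dim 0) and a row-parallel weight (split on dim 1).
column_parallel = torch.randn(8, 4)  # e.g. self_attention.query_key_value.weight
row_parallel = torch.randn(4, 8)     # e.g. mlp.dense_4h_to_h.weight

def shard(state: torch.Tensor, dim: int, tp_world_size: int):
    # Same invariant as prepare_weights(): the sharded dimension must divide evenly.
    size = state.shape[dim]
    assert size % tp_world_size == 0
    block_size = size // tp_world_size
    shards = torch.split(state, block_size, dim=dim)
    assert len(shards) == tp_world_size
    return [s.detach().clone() for s in shards]

column_shards = shard(column_parallel, dim=0, tp_world_size=TP_WORLD_SIZE)
row_shards = shard(row_parallel, dim=1, tp_world_size=TP_WORLD_SIZE)

print([tuple(s.shape) for s in column_shards])  # [(4, 4), (4, 4)]
print([tuple(s.shape) for s in row_shards])     # [(4, 4), (4, 4)]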