diff --git a/.dockerignore b/.dockerignore
index 38e8f8243..5aa1aa3a4 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,2 +1,2 @@
aml
-router/target
\ No newline at end of file
+target
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 723ef36f4..ec376bb82 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
-.idea
\ No newline at end of file
+.idea
+target
\ No newline at end of file
diff --git a/router/Cargo.lock b/Cargo.lock
similarity index 91%
rename from router/Cargo.lock
rename to Cargo.lock
index eda429087..551f7aebc 100644
--- a/router/Cargo.lock
+++ b/Cargo.lock
@@ -55,9 +55,9 @@ dependencies = [
[[package]]
name = "async-trait"
-version = "0.1.57"
+version = "0.1.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f"
+checksum = "1e805d94e6b5001b651426cf4cd446b1ab5f319d27bab5c644f61de0a804360c"
dependencies = [
"proc-macro2",
"quote",
@@ -166,9 +166,9 @@ dependencies = [
[[package]]
name = "bumpalo"
-version = "3.11.0"
+version = "3.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d"
+checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba"
[[package]]
name = "byteorder"
@@ -255,9 +255,9 @@ dependencies = [
[[package]]
name = "clap"
-version = "4.0.15"
+version = "4.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6bf8832993da70a4c6d13c581f4463c2bdda27b9bf1c5498dc4365543abe6d6f"
+checksum = "06badb543e734a2d6568e19a40af66ed5364360b9226184926f89d229b4b4267"
dependencies = [
"atty",
"bitflags",
@@ -391,6 +391,16 @@ dependencies = [
"typenum",
]
+[[package]]
+name = "ctrlc"
+version = "3.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1d91974fbbe88ec1df0c24a4f00f99583667a7e2e6272b2b92d294d81e462173"
+dependencies = [
+ "nix",
+ "winapi",
+]
+
[[package]]
name = "darling"
version = "0.10.2"
@@ -529,7 +539,7 @@ dependencies = [
"cfg-if",
"libc",
"redox_syscall",
- "windows-sys",
+ "windows-sys 0.36.1",
]
[[package]]
@@ -936,9 +946,9 @@ dependencies = [
[[package]]
name = "itoa"
-version = "1.0.3"
+version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6c8af84674fe1f223a982c933a0ee1086ac4d4052aa0fb8060c12c6ad838e754"
+checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc"
[[package]]
name = "js-sys"
@@ -957,9 +967,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
-version = "0.2.134"
+version = "0.2.135"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb"
+checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c"
[[package]]
name = "lock_api"
@@ -1047,7 +1057,7 @@ dependencies = [
"libc",
"log",
"wasi 0.11.0+wasi-snapshot-preview1",
- "windows-sys",
+ "windows-sys 0.36.1",
]
[[package]]
@@ -1074,6 +1084,18 @@ dependencies = [
"tempfile",
]
+[[package]]
+name = "nix"
+version = "0.25.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e322c04a9e3440c327fca7b6c8a63e6890a32fa2ad689db972425f07e0d22abb"
+dependencies = [
+ "autocfg",
+ "bitflags",
+ "cfg-if",
+ "libc",
+]
+
[[package]]
name = "nom"
version = "7.1.1"
@@ -1084,6 +1106,16 @@ dependencies = [
"minimal-lexical",
]
+[[package]]
+name = "nu-ansi-term"
+version = "0.46.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
+dependencies = [
+ "overload",
+ "winapi",
+]
+
[[package]]
name = "num_cpus"
version = "1.13.1"
@@ -1185,6 +1217,12 @@ version = "6.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff"
+[[package]]
+name = "overload"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
+
[[package]]
name = "parking_lot"
version = "0.12.1"
@@ -1197,15 +1235,15 @@ dependencies = [
[[package]]
name = "parking_lot_core"
-version = "0.9.3"
+version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929"
+checksum = "4dc9e0dc2adc1c69d09143aff38d3d30c5c3f0df0dad82e6d25547af174ebec0"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
- "windows-sys",
+ "windows-sys 0.42.0",
]
[[package]]
@@ -1300,9 +1338,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
-version = "1.0.46"
+version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b"
+checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725"
dependencies = [
"unicode-ident",
]
@@ -1530,7 +1568,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88d6731146462ea25d9244b2ed5fd1d716d25c52e4d54aa4fb0f3c4e9854dbe2"
dependencies = [
"lazy_static",
- "windows-sys",
+ "windows-sys 0.36.1",
]
[[package]]
@@ -1584,9 +1622,9 @@ dependencies = [
[[package]]
name = "serde_json"
-version = "1.0.85"
+version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e55a28e3aaef9d5ce0506d0a14dbba8054ddc7e499ef522dd8b26859ec9d4a44"
+checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074"
dependencies = [
"itoa",
"ryu",
@@ -1625,6 +1663,15 @@ dependencies = [
"lazy_static",
]
+[[package]]
+name = "signal-hook-registry"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0"
+dependencies = [
+ "libc",
+]
+
[[package]]
name = "slab"
version = "0.4.7"
@@ -1681,10 +1728,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
-name = "syn"
-version = "1.0.101"
+name = "subprocess"
+version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e90cde112c4b9690b8cbe810cba9ddd8bc1d7472e2cae317b69e9438c1cba7d2"
+checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086"
+dependencies = [
+ "libc",
+ "winapi",
+]
+
+[[package]]
+name = "syn"
+version = "1.0.102"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1"
dependencies = [
"proc-macro2",
"quote",
@@ -1741,13 +1798,24 @@ dependencies = [
"winapi",
]
+[[package]]
+name = "text-generation-launcher"
+version = "0.1.0"
+dependencies = [
+ "clap 4.0.17",
+ "ctrlc",
+ "subprocess",
+ "tracing",
+ "tracing-subscriber",
+]
+
[[package]]
name = "text-generation-router"
version = "0.1.0"
dependencies = [
"axum",
"bloom-inference-client",
- "clap 4.0.15",
+ "clap 4.0.17",
"futures",
"parking_lot",
"serde",
@@ -1872,6 +1940,7 @@ dependencies = [
"num_cpus",
"parking_lot",
"pin-project-lite",
+ "signal-hook-registry",
"socket2",
"tokio-macros",
"winapi",
@@ -1910,9 +1979,9 @@ dependencies = [
[[package]]
name = "tokio-stream"
-version = "0.1.10"
+version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f6edf2d6bc038a43d31353570e27270603f4648d18f5ed10c0e179abe43255af"
+checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce"
dependencies = [
"futures-core",
"pin-project-lite",
@@ -2031,9 +2100,9 @@ dependencies = [
[[package]]
name = "tower-layer"
-version = "0.3.1"
+version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62"
+checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
[[package]]
name = "tower-service"
@@ -2043,9 +2112,9 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
[[package]]
name = "tracing"
-version = "0.1.36"
+version = "0.1.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2fce9567bd60a67d08a16488756721ba392f24f29006402881e43b19aac64307"
+checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [
"cfg-if",
"log",
@@ -2056,9 +2125,9 @@ dependencies = [
[[package]]
name = "tracing-attributes"
-version = "0.1.22"
+version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2"
+checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a"
dependencies = [
"proc-macro2",
"quote",
@@ -2067,9 +2136,9 @@ dependencies = [
[[package]]
name = "tracing-core"
-version = "0.1.29"
+version = "0.1.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5aeea4303076558a00714b823f9ad67d58a3bbda1df83d8827d21193156e22f7"
+checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a"
dependencies = [
"once_cell",
"valuable",
@@ -2108,11 +2177,11 @@ dependencies = [
[[package]]
name = "tracing-subscriber"
-version = "0.3.15"
+version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "60db860322da191b40952ad9affe65ea23e7dd6a5c442c2c42865810c6ab8e6b"
+checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70"
dependencies = [
- "ansi_term",
+ "nu-ansi-term",
"sharded-slab",
"smallvec",
"thread_local",
@@ -2140,9 +2209,9 @@ checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992"
[[package]]
name = "unicode-ident"
-version = "1.0.4"
+version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dcc811dc4066ac62f84f11307873c4850cb653bfa9b1719cee2bd2204a4bc5dd"
+checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3"
[[package]]
name = "unicode-normalization"
@@ -2361,43 +2430,100 @@ version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2"
dependencies = [
- "windows_aarch64_msvc",
- "windows_i686_gnu",
- "windows_i686_msvc",
- "windows_x86_64_gnu",
- "windows_x86_64_msvc",
+ "windows_aarch64_msvc 0.36.1",
+ "windows_i686_gnu 0.36.1",
+ "windows_i686_msvc 0.36.1",
+ "windows_x86_64_gnu 0.36.1",
+ "windows_x86_64_msvc 0.36.1",
]
+[[package]]
+name = "windows-sys"
+version = "0.42.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
+dependencies = [
+ "windows_aarch64_gnullvm",
+ "windows_aarch64_msvc 0.42.0",
+ "windows_i686_gnu 0.42.0",
+ "windows_i686_msvc 0.42.0",
+ "windows_x86_64_gnu 0.42.0",
+ "windows_x86_64_gnullvm",
+ "windows_x86_64_msvc 0.42.0",
+]
+
+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.42.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e"
+
[[package]]
name = "windows_aarch64_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47"
+[[package]]
+name = "windows_aarch64_msvc"
+version = "0.42.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4"
+
[[package]]
name = "windows_i686_gnu"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6"
+[[package]]
+name = "windows_i686_gnu"
+version = "0.42.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7"
+
[[package]]
name = "windows_i686_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024"
+[[package]]
+name = "windows_i686_msvc"
+version = "0.42.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246"
+
[[package]]
name = "windows_x86_64_gnu"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1"
+[[package]]
+name = "windows_x86_64_gnu"
+version = "0.42.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed"
+
+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.42.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028"
+
[[package]]
name = "windows_x86_64_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680"
+[[package]]
+name = "windows_x86_64_msvc"
+version = "0.42.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5"
+
[[package]]
name = "winreg"
version = "0.10.1"
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 000000000..d3f7dfb09
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,11 @@
+[workspace]
+members = [
+ "router",
+ "router/client",
+ "launcher"
+]
+
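+# debug = 1 keeps line-level debug info in release builds (useful for backtraces);
+# incremental = true and lto = "off" trade some runtime speed for faster rebuilds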
+[profile.release]
+debug = 1
+incremental = true
+lto = "off"
diff --git a/Dockerfile b/Dockerfile
index 9b8a20549..70b91bc21 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM rust:1.64 as builder
+FROM rust:1.64 as router-builder
WORKDIR /usr/src
@@ -9,7 +9,17 @@ WORKDIR /usr/src/router
RUN cargo install --path .
-FROM nvidia/cuda:11.6.1-devel-ubuntu18.04
+FROM rust:1.64 as launcher-builder
+
+WORKDIR /usr/src
+
+COPY launcher launcher
+
+WORKDIR /usr/src/launcher
+
+RUN cargo install --path .
+
+FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
@@ -34,17 +44,15 @@ RUN cd ~ && \
bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
conda create -n text-generation python=3.9 -y
+WORKDIR /usr/src
+
+COPY server/Makefile server/Makefile
+
# Install specific version of torch
-RUN /opt/miniconda/envs/text-generation/bin/pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
+RUN cd server && make install-torch
# Install specific version of transformers
-RUN wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
- unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
- rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
- cd transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 && \
- /opt/miniconda/envs/text-generation/bin/python setup.py install
-
-WORKDIR /usr/src
+RUN cd server && make install-transformers
# Install server
COPY server server
@@ -52,9 +60,7 @@ RUN cd server && \
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
# Install router
-COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
-COPY run.sh .
-RUN chmod +x run.sh
-
-CMD ["./run.sh"]
\ No newline at end of file
+CMD text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS --shard-directory $MODEL_BASE_PATH
\ No newline at end of file
diff --git a/Makefile b/Makefile
new file mode 100644
index 000000000..3a80a12db
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,19 @@
+install-server:
+ cd server && make pip-install
+
+install-router:
+ cd router && cargo install --path .
+
+install-launcher:
+ cd launcher && cargo install --path .
+
+install:
+ make install-server
+ make install-router
+ make install-launcher
+
+run-bloom-560m:
+ text-generation-launcher --model-name bigscience/bloom-560m --shard-directory /tmp/models --num-shard 2
+
+run-bloom:
+ text-generation-launcher --model-name bigscience/bloom --shard-directory /tmp/models --num-shard 8
diff --git a/README.md b/README.md
index 6d23d9c5b..18e9865de 100644
--- a/README.md
+++ b/README.md
@@ -1,50 +1,51 @@
-# Text Generation Inference
+# LLM Text Generation Inference
-A Rust and gRPC server for text generation inference.
+
-## Load Tests
+![architecture](assets/architecture.jpg)
+
+A Rust and gRPC server for text generation inference on large language models.
+
+## Load Tests for BLOOM
See `k6/load_test.js`
-We send the default examples with a 1 second delay between each request.
+We send the default examples with a 1 second delay between requests.
Stages:
-- Ramp up to 50 concurrent requests per second in 1min
-- Ramp up from 50 to 100 concurrent requests per second in 2min
-- Ramp down to 0 concurrent requests per second in 1min
+- Ramp up to 50 VUs (virtual users) in 1min
+- Ramp up from 50 to 100 VUs in 2min
+- Ramp down to 0 VUs in 1min
-| | avg | min | med | max | p(90) | p(95) | RPS |
-|------------------------|-----------|-----------|-----------|------------|-----------|-----------|----------|
-| Original code | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
-| ISO with original code | 8.88s | 959.53ms | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 |
-| New batching logic | **5.44s** | **1.27s** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
+| | avg | min | med | max | p(90) | p(95) | RPS |
+|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
+| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
+| ISO with original code | 8.88s | **959.53ms** | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 |
+| New batching logic | **5.44s** | 1.27s | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
## Install
```shell
-cd server
-pip install .
+make install
```
-```
-cd router
-cargo build --release
-```
-
-## Run
+## Run
```shell
-python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-directory /dev/shm/models
+make run-bloom-560m
```
+## Test
+
```shell
-./router/target/release/router
+curl 127.0.0.1:3000/generate \
+ -X POST \
+ -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+ -H 'Content-Type: application/json'
```
## TODO:
-- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
-- [ ] Add tests
-- [ ] Add shutdown logic in router and server
-- [ ] Improve multi-processing logic in server
-- [ ] Improve past key layer indexing?
\ No newline at end of file
+- [ ] Add tests for the `server/model` logic
\ No newline at end of file
diff --git a/aml/deployment.yaml b/aml/deployment.yaml
index be28ceef2..35d19006f 100644
--- a/aml/deployment.yaml
+++ b/aml/deployment.yaml
@@ -8,7 +8,7 @@ environment_variables:
MODEL_NAME: bigscience/bloom
NUM_GPUS: 8
environment:
- image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3
+ image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2
inference_config:
liveness_route:
port: 3000
diff --git a/assets/architecture.jpg b/assets/architecture.jpg
new file mode 100644
index 000000000..e0a5f7c6f
Binary files /dev/null and b/assets/architecture.jpg differ
diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml
new file mode 100644
index 000000000..ae3558b1c
--- /dev/null
+++ b/launcher/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "text-generation-launcher"
+version = "0.1.0"
+edition = "2021"
+authors = ["Olivier Dehaene"]
+description = "Text Generation Launcher"
+
+[dependencies]
+clap = { version = "4.0.15", features = ["derive", "env"] }
+ctrlc = "3.2.3"
+subprocess = "0.2.9"
+tracing = "0.1.37"
+tracing-subscriber = "0.3.16"
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
new file mode 100644
index 000000000..0821a456a
--- /dev/null
+++ b/launcher/src/main.rs
@@ -0,0 +1,358 @@
+use clap::Parser;
+use std::io::{BufRead, BufReader, Read};
+use std::path::Path;
+use std::process::ExitCode;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::mpsc::TryRecvError;
+use std::sync::Arc;
+use std::sync::{mpsc, Mutex};
+use std::thread;
+use std::thread::sleep;
+use std::time::{Duration, Instant};
+use std::{fs, io};
+use subprocess::{Popen, PopenConfig, PopenError, Redirection};
+
+/// App Configuration
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+ #[clap(default_value = "bigscience/bloom-560m", long, env)]
+ model_name: String,
+ #[clap(long, env)]
+ num_shard: Option<usize>,
+ #[clap(long, env)]
+ shard_directory: Option<String>,
+ #[clap(default_value = "128", long, env)]
+ max_concurrent_requests: usize,
+ #[clap(default_value = "1000", long, env)]
+ max_input_length: usize,
+ #[clap(default_value = "32", long, env)]
+ max_batch_size: usize,
+ #[clap(default_value = "5", long, env)]
+ max_waiting_time: u64,
+ #[clap(default_value = "3000", long, short, env)]
+ port: u16,
+ #[clap(default_value = "/tmp/text-generation-server", long, env)]
+ shard_uds_path: String,
+ #[clap(default_value = "localhost", long, env)]
+ master_addr: String,
+ #[clap(default_value = "29500", long, env)]
+ master_port: usize,
+}
+
+fn main() -> ExitCode {
+ tracing_subscriber::fmt::init();
+
+ // Pattern match configuration
+ let Args {
+ model_name,
+ num_shard,
+ shard_directory,
+ max_concurrent_requests,
+ max_input_length,
+ max_batch_size,
+ max_waiting_time,
+ port,
+ shard_uds_path,
+ master_addr,
+ master_port,
+ } = Args::parse();
+
+ // By default we only have one master shard
+ let num_shard = num_shard.unwrap_or(1);
+
+ // Signal handler
+ let running = Arc::new(AtomicBool::new(true));
+ let r = running.clone();
+ ctrlc::set_handler(move || {
+ r.store(false, Ordering::SeqCst);
+ })
+ .expect("Error setting Ctrl-C handler");
+
+ // Shared shutdown bool
+ let shutdown = Arc::new(Mutex::new(false));
+ // Shared shutdown channel
+ // When shutting down, the main thread will wait for all senders to be dropped
+ let (shutdown_sender, shutdown_receiver) = mpsc::channel();
+
+ // Shared channel to track shard status
+ let (status_sender, status_receiver) = mpsc::channel();
+
+ // Start shard processes
+ for rank in 0..num_shard {
+ let model_name = model_name.clone();
+ let uds_path = shard_uds_path.clone();
+ let shard_directory = shard_directory.clone();
+ let master_addr = master_addr.clone();
+ let status_sender = status_sender.clone();
+ let shutdown = shutdown.clone();
+ let shutdown_sender = shutdown_sender.clone();
+ thread::spawn(move || {
+ shard_manager(
+ model_name,
+ uds_path,
+ shard_directory,
+ rank,
+ num_shard,
+ master_addr,
+ master_port,
+ status_sender,
+ shutdown,
+ shutdown_sender,
+ )
+ });
+ }
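+ // Drop the main thread's clone of the shutdown sender so that waiting on the
+ // shutdown channel only blocks until every shard manager thread has exited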
+ drop(shutdown_sender);
+
+ // Wait for the shards to start
+ let mut shard_ready = 0;
+ while running.load(Ordering::SeqCst) {
+ match status_receiver.try_recv() {
+ Ok(ShardStatus::Ready) => {
+ shard_ready += 1;
+ if shard_ready == num_shard {
+ break;
+ }
+ }
+ Err(TryRecvError::Empty) => {
+ sleep(Duration::from_millis(100));
+ }
+ Ok(ShardStatus::Failed((rank, err))) => {
+ tracing::error!("Shard {} failed to start:\n{}", rank, err);
+ shutdown_shards(shutdown, &shutdown_receiver);
+ return ExitCode::FAILURE;
+ }
+ Err(TryRecvError::Disconnected) => {
+ tracing::error!("Shard status channel disconnected");
+ shutdown_shards(shutdown, &shutdown_receiver);
+ return ExitCode::FAILURE;
+ }
+ }
+ }
+
+ // We might have received a termination signal
+ if !running.load(Ordering::SeqCst) {
+ shutdown_shards(shutdown, &shutdown_receiver);
+ return ExitCode::SUCCESS;
+ }
+
+ // All shards started
+ // Start webserver
+ tracing::info!("Starting Webserver");
+ let mut webserver = match Popen::create(
+ &[
+ "text-generation-router",
+ "--max-concurrent-requests",
+ &max_concurrent_requests.to_string(),
+ "--max-input-length",
+ &max_input_length.to_string(),
+ "--max-batch-size",
+ &max_batch_size.to_string(),
+ "--max-waiting-time",
+ &max_waiting_time.to_string(),
+ "--port",
+ &port.to_string(),
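+ // The router only needs to talk to the master shard (rank 0); its unix
+ // socket is suffixed with `-0` by the naming scheme used in `shard_manager`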
+ "--master-shard-uds-path",
+ &format!("{}-0", shard_uds_path),
+ "--tokenizer-name",
+ &model_name,
+ ],
+ PopenConfig {
+ stdout: Redirection::Pipe,
+ stderr: Redirection::Pipe,
+ // Needed for the shutdown procedure
+ setpgid: true,
+ ..Default::default()
+ },
+ ) {
+ Ok(p) => p,
+ Err(err) => {
+ tracing::error!("Failed to start webserver: {}", err);
+ if let PopenError::IoError(err) = err {
+ if err.kind() == io::ErrorKind::NotFound {
+ tracing::error!("text-generation-router not found in PATH");
+ tracing::error!("Please install it with `make install-router`")
+ }
+ }
+
+ shutdown_shards(shutdown, &shutdown_receiver);
+ return ExitCode::FAILURE;
+ }
+ };
+
+ // Redirect STDOUT and STDERR to the console
+ let webserver_stdout = webserver.stdout.take().unwrap();
+ let webserver_stderr = webserver.stderr.take().unwrap();
+
+ thread::spawn(move || {
+ let stdout = BufReader::new(webserver_stdout);
+ let stderr = BufReader::new(webserver_stderr);
+ for line in stdout.lines() {
+ println!("{}", line.unwrap());
+ }
+ for line in stderr.lines() {
+ println!("{}", line.unwrap());
+ }
+ });
+
+ // Default exit code
+ let mut exit_code = ExitCode::SUCCESS;
+
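+ // Supervision loop: poll the shard status channel and the webserver process
+ // until a shard fails, the webserver exits or we receive a termination signal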
+ while running.load(Ordering::SeqCst) {
+ if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
+ tracing::error!("Shard {} failed:\n{}", rank, err);
+ exit_code = ExitCode::FAILURE;
+ break;
+ };
+
+ match webserver.poll() {
+ Some(_) => {
+ tracing::error!("Webserver Crashed");
+ shutdown_shards(shutdown, &shutdown_receiver);
+ return ExitCode::FAILURE;
+ }
+ None => {
+ sleep(Duration::from_millis(100));
+ }
+ };
+ }
+
+ // Graceful termination
+ webserver.terminate().unwrap();
+ tracing::info!("Waiting for webserver to gracefully shutdown");
+ webserver.wait_timeout(Duration::from_secs(90)).unwrap();
+ tracing::info!("Webserver terminated");
+ shutdown_shards(shutdown, &shutdown_receiver);
+
+ exit_code
+}
+
+#[derive(Debug)]
+enum ShardStatus {
+ Ready,
+ Failed((usize, String)),
+}
+
+#[allow(clippy::too_many_arguments)]
+fn shard_manager(
+ model_name: String,
+ uds_path: String,
+ shard_directory: Option<String>,
+ rank: usize,
+ world_size: usize,
+ master_addr: String,
+ master_port: usize,
+ status_sender: mpsc::Sender<ShardStatus>,
+ shutdown: Arc<Mutex<bool>>,
+ _shutdown_sender: mpsc::Sender<()>,
+) {
+ // Get UDS path
+ let uds_string = format!("{}-{}", uds_path, rank);
+ let uds = Path::new(&uds_string);
+ // Clean previous runs
+ fs::remove_file(uds).unwrap_or_default();
+
+ // Process args
+ let mut shard_argv = vec![
+ "bloom-inference-server".to_string(),
+ "serve".to_string(),
+ model_name,
+ "--uds-path".to_string(),
+ uds_path,
+ ];
+
+ if world_size > 1 {
+ shard_argv.push("--sharded".to_string());
+ }
+
+ if let Some(shard_directory) = shard_directory {
+ shard_argv.push("--shard-directory".to_string());
+ shard_argv.push(shard_directory);
+ }
+
+ // Start process
+ tracing::info!("Starting shard {}", rank);
+ let mut p = match Popen::create(
+ &shard_argv,
+ PopenConfig {
+ stdout: Redirection::Pipe,
+ stderr: Redirection::Pipe,
+ // Needed for the shutdown procedure
+ setpgid: true,
+ // NCCL env vars
+ env: Some(vec![
+ ("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
+ (
+ "WORLD_SIZE".parse().unwrap(),
+ world_size.to_string().parse().unwrap(),
+ ),
+ ("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
+ (
+ "MASTER_PORT".parse().unwrap(),
+ master_port.to_string().parse().unwrap(),
+ ),
+ ]),
+ ..Default::default()
+ },
+ ) {
+ Ok(p) => p,
+ Err(err) => {
+ if let PopenError::IoError(ref err) = err {
+ if err.kind() == io::ErrorKind::NotFound {
+ tracing::error!("bloom-inference-server not found in PATH");
+ tracing::error!("Please install it with `make install-server`")
+ }
+ }
+ status_sender
+ .send(ShardStatus::Failed((rank, err.to_string())))
+ .unwrap();
+ return;
+ }
+ };
+
+ let mut ready = false;
+ let start_time = Instant::now();
+ loop {
+ // Process exited
+ if p.poll().is_some() {
+ let mut err = String::new();
+ p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
+ status_sender
+ .send(ShardStatus::Failed((rank, err)))
+ .unwrap();
+ return;
+ }
+
+ // We received a shutdown signal
+ if *shutdown.lock().unwrap() {
+ p.terminate().unwrap();
+ let _ = p.wait_timeout(Duration::from_secs(90));
+ tracing::info!("Shard {} terminated", rank);
+ return;
+ }
+
+ // Shard is ready
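+ // (the unix socket created by the shard process doubles as the readiness probe)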
+ if uds.exists() && !ready {
+ tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
+ status_sender.send(ShardStatus::Ready).unwrap();
+ ready = true;
+ } else if !ready {
+ tracing::info!("Waiting for shard {} to be ready...", rank);
+ }
+ sleep(Duration::from_secs(5));
+ }
+}
+
+fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
+ tracing::info!("Shutting down shards");
+ // Update shutdown value to true
+ // This will be picked up by the shard manager
+ {
+ let mut shutdown = shutdown.lock().unwrap();
+ *shutdown = true;
+ }
+
+ // Wait for shards to shutdown
+ // This will block till all shutdown_sender are dropped
+ let _ = shutdown_receiver.recv();
+}
diff --git a/proto/generate.proto b/proto/generate.proto
index 8c5221b48..45afca82a 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -11,10 +11,6 @@ service TextGenerationService {
rpc Generate (GenerateRequest) returns (GenerateResponse);
/// Generate tokens for a list of cached batches
rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse);
- /// Generate tokens until the text of at least one request of the batch is generated
- rpc GenerateUntilFinished (GenerateUntilFinishedRequest) returns (GenerateUntilFinishedResponse);
- /// Generate tokens until the text of at least one request of the cached batches is finished
- rpc GenerateUntilFinishedWithCache (GenerateUntilFinishedWithCacheRequest) returns (GenerateUntilFinishedWithCacheResponse);
}
/// Empty request
@@ -92,27 +88,3 @@ message GenerateWithCacheResponse {
/// Next batch (cached)
optional Batch batch = 2;
}
-
-message GenerateUntilFinishedRequest {
- /// Batch
- Batch batch = 1;
-}
-
-message GenerateUntilFinishedResponse {
- /// Finished requests
- repeated GeneratedText generated_texts = 1;
- /// Next batch (cached)
- optional Batch batch = 2;
-}
-
-message GenerateUntilFinishedWithCacheRequest {
- /// Cached batches
- repeated Batch batches = 1;
-}
-
-message GenerateUntilFinishedWithCacheResponse {
- /// Finished requests
- repeated GeneratedText generated_texts = 1;
- /// Next batch (cached)
- optional Batch batch = 2;
-}
diff --git a/router/.gitignore b/router/.gitignore
deleted file mode 100644
index ea8c4bf7f..000000000
--- a/router/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-/target
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 37f319e99..5820c1383 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -22,16 +22,7 @@ serde = "1.0.145"
serde_json = "1.0.85"
thiserror = "1.0.37"
tokenizers = "0.13.0"
-tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] }
+tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tracing = "0.1.36"
tracing-subscriber = "0.3.15"
-[workspace]
-members = [
- "client",
-]
-
-[profile.release]
-debug = 1
-incremental = true
-lto = "off"
diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml
index 7760c8cbe..633f82a9d 100644
--- a/router/client/Cargo.toml
+++ b/router/client/Cargo.toml
@@ -5,8 +5,6 @@ edition = "2021"
[dependencies]
futures = "^0.3"
-#grpc-error-details = { path = "../../grpc-error-details" }
-#grpc-metadata = { path = "../../grpc-metadata" }
prost = "^0.9"
thiserror = "^1.0"
tokio = { version = "^1.21", features = ["sync"] }
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index e7189b892..172d0bf73 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -1,10 +1,11 @@
+/// Single shard Client
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use tonic::transport::{Channel, Uri};
use tracing::*;
-/// BLOOM Inference gRPC client
+/// Text Generation Inference gRPC client
#[derive(Clone)]
pub struct Client {
stub: TextGenerationServiceClient<Channel>,
@@ -34,6 +35,7 @@ impl Client {
})
}
+ /// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {});
@@ -46,6 +48,7 @@ impl Client {
.into_inner()
.urls
.into_iter()
+ // Remove unix socket prefix
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
@@ -54,6 +57,7 @@ impl Client {
Ok(urls)
}
+ /// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest {});
@@ -64,6 +68,10 @@ impl Client {
Ok(())
}
+ /// Generate one token for each request in the given batch
+ ///
+ /// Returns a list of generated texts of requests that met their stopping criteria
+ /// and the next cached batch
#[instrument(skip(self))]
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
@@ -76,6 +84,10 @@ impl Client {
Ok((response.generated_texts, response.batch))
}
+ /// Generate one token for each request in the given cached batch
+ ///
+ /// Returns a list of generated texts of requests that met their stopping criteria
+ /// and the next cached batch
#[instrument(skip(self))]
pub async fn generate_with_cache(
&mut self,
@@ -90,34 +102,4 @@ impl Client {
.into_inner();
Ok((response.generated_texts, response.batch))
}
-
- #[instrument(skip(self))]
- pub async fn generate_until_finished(
- &mut self,
- batch: Batch,
- ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
- let request = tonic::Request::new(GenerateUntilFinishedRequest { batch: Some(batch) });
- let response = self
- .stub
- .generate_until_finished(request)
- .instrument(info_span!("generate_until_finished"))
- .await?
- .into_inner();
- Ok((response.generated_texts, response.batch))
- }
-
- #[instrument(skip(self))]
- pub async fn generate_until_finished_with_cache(
- &mut self,
- batches: Vec<Batch>,
- ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
- let request = tonic::Request::new(GenerateUntilFinishedWithCacheRequest { batches });
- let response = self
- .stub
- .generate_until_finished_with_cache(request)
- .instrument(info_span!("generate_until_finished_with_cache"))
- .await?
- .into_inner();
- Ok((response.generated_texts, response.batch))
- }
}
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index 48b2650d0..0f1f96bca 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -1,6 +1,7 @@
-//! BLOOM Inference gRPC client library
+//! Text Generation gRPC client library
mod client;
+#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod sharded_client;
@@ -8,7 +9,7 @@ pub use client::Client;
pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
pub use sharded_client::ShardedClient;
use thiserror::Error;
-pub use tonic::transport;
+use tonic::transport;
use tonic::Status;
#[derive(Error, Debug, Clone)]
@@ -21,7 +22,7 @@ pub enum ClientError {
impl From<Status> for ClientError {
fn from(err: Status) -> Self {
- Self::Generation(err.to_string())
+ Self::Generation(err.message().to_string())
}
}
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 7134551e6..916e72b48 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,9 +1,11 @@
+/// Multi shard Client
use crate::Result;
use crate::{Batch, Client, GeneratedText};
use futures::future::join_all;
use tokio::sync::{broadcast, mpsc};
use tonic::transport::Uri;
+/// List of all available commands that can be sent through the command channel
#[derive(Clone, Debug)]
enum Command {
Generate(
@@ -14,36 +16,32 @@ enum Command {
Vec<Batch>,
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
),
- GenerateUntilFinished(
- Batch,
- mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
- ),
- GenerateUntilFinishedWithCache(
- Vec<Batch>,
- mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
- ),
ClearCache(mpsc::Sender<Result<()>>),
}
+/// Tokio task that handles the communication with a single shard
+///
+/// We subscribe to a broadcast channel to receive commands that will be sent by
+/// the ShardedClient.
+///
+/// Each command is fanned out to all shards.
+///
+/// The result of the command is sent back to the ShardedClient through an mpsc channel (multi
+/// producer = the shards, single consumer = the ShardedClient).
async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
while let Ok(message) = request_subscriber.recv().await {
match message {
Command::Generate(batch, response_tx) => {
let result = client.generate(batch).await;
+ // We can unwrap_or(()) here because the only error that can happen is if the
+ // receiver is dropped, which means that the ShardedClient already received a
+ // response from another shard
response_tx.try_send(result).unwrap_or(());
}
Command::GenerateWithCache(batches, response_tx) => {
let result = client.generate_with_cache(batches).await;
response_tx.try_send(result).unwrap_or(());
}
- Command::GenerateUntilFinished(batch, response_tx) => {
- let result = client.generate_until_finished(batch).await;
- response_tx.try_send(result).unwrap_or(());
- }
- Command::GenerateUntilFinishedWithCache(batches, response_tx) => {
- let result = client.generate_until_finished_with_cache(batches).await;
- response_tx.try_send(result).unwrap_or(());
- }
Command::ClearCache(response_tx) => {
let result = client.clear_cache().await;
response_tx.try_send(result).unwrap_or(());
@@ -52,30 +50,42 @@ async fn client_task(mut client: Client, mut request_subscriber: broadcast::Rece
}
}
+/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
+ _clients: Vec<Client>,
request_tx: broadcast::Sender<Command>,
}
impl ShardedClient {
- fn new(mut clients: Vec<Client>) -> Self {
+ fn new(clients: Vec<Client>) -> Self {
+ // The broadcast channel to communicate with the shards
+ // We use a capacity of one as the shards are not asynchronous and can only process one
+ // command at a time
let (request_tx, _) = broadcast::channel(1);
- for client in clients.drain(..) {
+ // Spawn client tasks
+ for client in clients.iter() {
let request_subscriber = request_tx.subscribe();
- tokio::spawn(client_task(client, request_subscriber));
+ tokio::spawn(client_task(client.clone(), request_subscriber));
}
- Self { request_tx }
+ Self {
+ _clients: clients,
+ request_tx,
+ }
}
+ /// Create a new ShardedClient from a master client. The master client will communicate with
+ /// the other shards and return all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> {
+ // Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await.unwrap();
- let futures = uris.into_iter().map(|path| Client::connect_uds(path));
+ let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
- /// Returns a client connected to the given url
+ /// Returns a client connected to the given uri
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
@@ -87,51 +97,43 @@ impl ShardedClient {
Self::from_master_client(master_client).await
}
+ /// Generate one token for each request in the given batch
+ ///
+ /// Returns a list of generated texts of requests that met their stopping criteria
+ /// and the next cached batch
pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+ // Create a channel to receive the response from the shards
+ // We will only ever receive one message on this channel
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::Generate(batch, response_tx))
.unwrap();
+ // As soon as we receive one response, we can return as all shards will return the same response
response_rx.recv().await.unwrap()
}
+ /// Generate one token for each request in the given cached batch
+ ///
+ /// Returns a list of generated texts of requests that met their stopping criteria
+ /// and the next cached batch
pub async fn generate_with_cache(
&self,
batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+ // Create a channel to receive the response from the shards
+ // We will only ever receive one message on this channel
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::GenerateWithCache(batches, response_tx))
.unwrap();
+ // As soon as we receive one response, we can return as all shards will return the same response
response_rx.recv().await.unwrap()
}
- pub async fn generate_until_finished(
- &self,
- batch: Batch,
- ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
- let (response_tx, mut response_rx) = mpsc::channel(1);
- self.request_tx
- .send(Command::GenerateUntilFinished(batch, response_tx))
- .unwrap();
- response_rx.recv().await.unwrap()
- }
-
- pub async fn generate_until_finished_with_cache(
- &self,
- batches: Vec<Batch>,
- ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
- let (response_tx, mut response_rx) = mpsc::channel(1);
- self.request_tx
- .send(Command::GenerateUntilFinishedWithCache(
- batches,
- response_tx,
- ))
- .unwrap();
- response_rx.recv().await.unwrap()
- }
-
+ /// Clear the past generations cache
pub async fn clear_cache(&self) -> Result<()> {
+ // Create a channel to receive the response from the shards
+ // We will only ever receive one message on this channel
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::ClearCache(response_tx))
diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index ebd817300..4523dbfff 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -1,129 +1,158 @@
-use crate::server::GenerateRequest;
+/// Batching and inference logic
+use crate::GenerateRequest;
use crate::{Db, Entry};
use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
+use std::time::Duration;
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
+use tokio::time::Instant;
+use tracing::instrument;
-const MAX_LENGTH: usize = 128;
-
-#[derive(Debug, Error)]
-pub enum InferError {
- #[error("Request failed during generation: {0}")]
- GenerationError(String),
- #[error("Model is overloaded")]
- Overloaded,
-}
-
-impl From<InferError> for (StatusCode, String) {
- fn from(err: InferError) -> Self {
- match err {
- InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
- InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
- }
- }
-}
-
+/// Batcher
#[derive(Clone)]
pub struct Batcher {
+ /// Request database
db: Db,
+ /// Shared state
shared: Arc<Shared>,
}
+/// Batcher shared state
struct Shared {
+ /// Batching background Tokio task notifier
batching_task: Notify,
}
impl Batcher {
- pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self {
+ pub(crate) fn new(
+ client: ShardedClient,
+ max_batch_size: usize,
+ max_waiting_time: Duration,
+ ) -> Self {
+ // Batcher shared state
let db = Db::new();
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
- tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone()));
+ // Spawn batching background task that contains all the inference logic
+ tokio::spawn(batching_task(
+ max_batch_size,
+ max_waiting_time,
+ client,
+ db.clone(),
+ shared.clone(),
+ ));
Self { db, shared }
}
+ /// Add a new request to the database and return a future that will generate the text
pub(crate) async fn infer(
&self,
input_length: usize,
request: GenerateRequest,
) -> Result<String, InferError> {
- if self.db.len() > MAX_LENGTH {
- return Err(InferError::Overloaded);
- }
- let (request_tx, request_rx) = oneshot::channel();
+ // One shot channel to communicate with the background batching task
+ let (response_tx, response_rx) = oneshot::channel();
+
+ // Try to append the request to the database
self.db.append(Entry {
request,
- response_tx: request_tx,
+ response_tx,
input_length,
+ time: Instant::now(),
});
+
+ // Notify the background task that we have a new entry in the database that needs
+ // to be batched
self.shared.batching_task.notify_waiters();
- match request_rx.await.unwrap() {
+
+ // Await on the response from the background task
+ // We can safely unwrap as the background task will never drop the sender
+ match response_rx.await.unwrap() {
Ok(output) => Ok(output),
Err(err) => Err(InferError::GenerationError(err.to_string())),
}
}
}
-async fn batching_task(max_batch_size: usize,
- client: ShardedClient,
- db: Db,
- shared: Arc<Shared>) {
+/// Batching logic
+/// Will be launched in a background Tokio task
+///
+/// Batches requests and sends them to the inference server
+#[instrument(skip(client, db, shared))]
+async fn batching_task(
+ max_batch_size: usize,
+ max_waiting_time: Duration,
+ client: ShardedClient,
+ db: Db,
+ shared: Arc<Shared>,
+) {
+ // Batch size threshold at or below which we try to add more requests to the batch
let limit_min_batch_size = (max_batch_size / 2) as u32;
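+ // e.g. with max_batch_size = 32, we try to extend the running batch as soon as
+ // it drops to 16 requests or fewer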
+ // Infinite loop
loop {
+ // Wait for a notification from the Batcher struct
shared.batching_task.notified().await;
- if let Some(batch) = db.next_batch(max_batch_size) {
- let request_ids = batch.requests.iter().map(|req| req.id).collect();
- let mut cached_batch = match batch.size {
- size if size > limit_min_batch_size => {
- wrap_future(client.generate_until_finished(batch), request_ids, &db).await
- }
- _ => wrap_future(client.generate(batch), request_ids, &db).await,
- };
+ // Get the next batch from the DB
+ // This batch might be smaller than the maximum batch size if there are not enough requests
+ // waiting in the DB
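+ // next_batch arguments: (minimum batch size, maximum batch size, maximum waiting time)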
+ if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
+ let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
+ // We loop until we do not receive any cached batch from the inference server (== until
+ // all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
- let mut current_batch_size = batch.size;
+ // Get current batch info
+ let batch_size = batch.size;
let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
let mut batches = vec![batch];
- if current_batch_size <= limit_min_batch_size {
- if let Some(new_batch) = db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size) {
- let new_batch_request_ids =
- new_batch.requests.iter().map(|req| req.id).collect();
+ // If the current batch is too small, we try to add more requests to it
+ if batch_size <= limit_min_batch_size {
+ // Get the next batch from the DB that meets our minimum size criteria
+ if let Some((new_request_ids, new_batch)) =
+ db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
+ {
+ // Generate one token for this new batch so that its attention past is already in cache
let new_cached_batch =
- wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
- .await;
+ wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+ // Extend current batch with the new batch
+ if let Some(new_cached_batch) = new_cached_batch {
+ request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
+ batches.push(new_cached_batch);
+ }
+ }
+ // If we don't have enough requests to meet the minimum size criteria, we
+ // try to get the next batch from the DB whose requests have been waiting
+ // longer than max_waiting_time
+ else if let Some((new_request_ids, new_batch)) =
+ db.next_batch(None, max_batch_size, Some(max_waiting_time))
+ {
+ let new_cached_batch =
+ wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+ // Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
- current_batch_size += new_cached_batch.size;
request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
batches.push(new_cached_batch);
}
}
}
- cached_batch = match current_batch_size {
- size if size > limit_min_batch_size => {
- wrap_future(
- client.generate_until_finished_with_cache(batches),
- request_ids,
- &db,
- )
- .await
- }
- _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
- };
+ cached_batch =
+ wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
}
}
}
}
+/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
async fn wrap_future(
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,