This commit is contained in:
OlivierDehaene 2023-03-12 10:05:33 +01:00
parent c9bdaa8b73
commit 4dfa6fbb62
9 changed files with 839 additions and 14 deletions

514
Cargo.lock generated
View File

@ -28,6 +28,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "ansi_term"
version = "0.12.1"
@ -192,6 +201,12 @@ version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535"
[[package]]
name = "bytemuck"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea"
[[package]]
name = "byteorder"
version = "1.4.3"
@ -260,6 +275,21 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f"
dependencies = [
"iana-time-zone",
"js-sys",
"num-integer",
"num-traits",
"time",
"wasm-bindgen",
"winapi",
]
[[package]]
name = "clap"
version = "2.34.0"
@ -312,6 +342,31 @@ dependencies = [
"os_str_bytes",
]
[[package]]
name = "cmake"
version = "0.1.49"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db34956e100b30725f2eb215f90d4871051239535632f84fea3bc92722c66b7c"
dependencies = [
"cc",
]
[[package]]
name = "codespan-reporting"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
dependencies = [
"termcolor",
"unicode-width",
]
[[package]]
name = "color_quant"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
[[package]]
name = "console"
version = "0.15.5"
@ -325,6 +380,12 @@ dependencies = [
"windows-sys 0.42.0",
]
[[package]]
name = "const-cstr"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed3d0b5ff30645a68f35ece8cea4556ca14ef8a1651455f789a099a0513532a6"
[[package]]
name = "core-foundation"
version = "0.9.3"
@ -341,6 +402,43 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc"
[[package]]
name = "core-graphics"
version = "0.22.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2581bbab3b8ffc6fcbd550bf46c355135d16e9ff2a6ea032ad6b9bf1d7efe4fb"
dependencies = [
"bitflags",
"core-foundation",
"core-graphics-types",
"foreign-types",
"libc",
]
[[package]]
name = "core-graphics-types"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a68b68b3446082644c91ac778bf50cd4104bfb002b5a6a7c44cca5a2c70788b"
dependencies = [
"bitflags",
"core-foundation",
"foreign-types",
"libc",
]
[[package]]
name = "core-text"
version = "19.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d74ada66e07c1cefa18f8abfba765b486f250de2e4a999e5727fc0dd4b4a25"
dependencies = [
"core-foundation",
"core-graphics",
"foreign-types",
"libc",
]
[[package]]
name = "cpufeatures"
version = "0.2.5"
@ -422,6 +520,50 @@ dependencies = [
"windows-sys 0.45.0",
]
[[package]]
name = "cxx"
version = "1.0.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a140f260e6f3f79013b8bfc65e7ce630c9ab4388c6a89c71e07226f49487b72"
dependencies = [
"cc",
"cxxbridge-flags",
"cxxbridge-macro",
"link-cplusplus",
]
[[package]]
name = "cxx-build"
version = "1.0.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da6383f459341ea689374bf0a42979739dc421874f112ff26f829b8040b8e613"
dependencies = [
"cc",
"codespan-reporting",
"once_cell",
"proc-macro2",
"quote",
"scratch",
"syn",
]
[[package]]
name = "cxxbridge-flags"
version = "1.0.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90201c1a650e95ccff1c8c0bb5a343213bdd317c6e600a93075bca2eff54ec97"
[[package]]
name = "cxxbridge-macro"
version = "1.0.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b75aed41bb2e6367cae39e6326ef817a851db13c13e4f3263714ca3cfb8de56"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "darling"
version = "0.10.2"
@ -523,6 +665,16 @@ dependencies = [
"dirs-sys",
]
[[package]]
name = "dirs-next"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1"
dependencies = [
"cfg-if",
"dirs-sys-next",
]
[[package]]
name = "dirs-sys"
version = "0.3.7"
@ -534,6 +686,38 @@ dependencies = [
"winapi",
]
[[package]]
name = "dirs-sys-next"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d"
dependencies = [
"libc",
"redox_users",
"winapi",
]
[[package]]
name = "dlib"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac1b7517328c04c2aa68422fc60a41b92208182142ed04a25879c26c8f878794"
dependencies = [
"libloading",
]
[[package]]
name = "dwrote"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439a1c2ba5611ad3ed731280541d36d2e9c4ac5e7fb818a27b604bdc5a6aa65b"
dependencies = [
"lazy_static",
"libc",
"winapi",
"wio",
]
[[package]]
name = "either"
version = "1.8.1"
@ -622,6 +806,12 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "float-ord"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bad48618fdb549078c333a7a8528acb57af271d0433bdecd523eb620628364e"
[[package]]
name = "float_eq"
version = "1.0.1"
@ -634,6 +824,31 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "font-kit"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21fe28504d371085fae9ac7a3450f0b289ab71e07c8e57baa3fb68b9e57d6ce5"
dependencies = [
"bitflags",
"byteorder",
"core-foundation",
"core-graphics",
"core-text",
"dirs-next",
"dwrote",
"float-ord",
"freetype",
"lazy_static",
"libc",
"log",
"pathfinder_geometry",
"pathfinder_simd",
"walkdir",
"winapi",
"yeslogic-fontconfig-sys",
]
[[package]]
name = "foreign-types"
version = "0.3.2"
@ -658,6 +873,27 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "freetype"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bee38378a9e3db1cc693b4f88d166ae375338a0ff75cb8263e1c601d51f35dc6"
dependencies = [
"freetype-sys",
"libc",
]
[[package]]
name = "freetype-sys"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a37d4011c0cc628dfa766fcc195454f4b068d7afdc2adfd28861191d866e731a"
dependencies = [
"cmake",
"libc",
"pkg-config",
]
[[package]]
name = "fs2"
version = "0.4.3"
@ -778,6 +1014,16 @@ dependencies = [
"wasi 0.11.0+wasi-snapshot-preview1",
]
[[package]]
name = "gif"
version = "0.11.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3edd93c6756b4dfaf2709eafcc345ba2636565295c198a9cfbf75fa5e3e00b06"
dependencies = [
"color_quant",
"weezl",
]
[[package]]
name = "glob"
version = "0.3.1"
@ -941,6 +1187,30 @@ dependencies = [
"tokio-native-tls",
]
[[package]]
name = "iana-time-zone"
version = "0.1.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64c122667b287044802d6ce17ee2ddf13207ed924c712de9a66a5814d5b64765"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"winapi",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0703ae284fc167426161c2e3f1da3ea71d94b21bedbcc9494e92b28e334e3dca"
dependencies = [
"cxx",
"cxx-build",
]
[[package]]
name = "ident_case"
version = "1.0.1"
@ -957,6 +1227,21 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "image"
version = "0.24.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69b7ea949b537b0fd0af141fff8c77690f2ce96f4f41f042ccb6c69c6c965945"
dependencies = [
"bytemuck",
"byteorder",
"color_quant",
"jpeg-decoder",
"num-rational",
"num-traits",
"png",
]
[[package]]
name = "indexmap"
version = "1.9.2"
@ -1062,6 +1347,12 @@ version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440"
[[package]]
name = "jpeg-decoder"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc0000e42512c92e31c2252315bda326620a4e034105e900c98ec492fa077b3e"
[[package]]
name = "js-sys"
version = "0.3.61"
@ -1083,6 +1374,25 @@ version = "0.2.139"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79"
[[package]]
name = "libloading"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f"
dependencies = [
"cfg-if",
"winapi",
]
[[package]]
name = "link-cplusplus"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecd207c9c713c34f95a097a5b029ac2ce6010530c7b49d7fea24d977dede04f5"
dependencies = [
"cc",
]
[[package]]
name = "linux-raw-sys"
version = "0.1.4"
@ -1326,6 +1636,36 @@ dependencies = [
"winapi",
]
[[package]]
name = "num-integer"
version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
"autocfg",
]
[[package]]
name = "num_cpus"
version = "1.15.0"
@ -1542,12 +1882,41 @@ version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d01a5bd0424d00070b0098dd17ebca6f961a959dead1dbcbbbc1d1cd8d3deeba"
[[package]]
name = "pathfinder_geometry"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7b7e7b4ea703700ce73ebf128e1450eb69c3a8329199ffbfb9b2a0418e5ad3"
dependencies = [
"log",
"pathfinder_simd",
]
[[package]]
name = "pathfinder_simd"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39fe46acc5503595e5949c17b818714d26fdf9b4920eacf3b2947f0199f4a6ff"
dependencies = [
"rustc_version",
]
[[package]]
name = "percent-encoding"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
[[package]]
name = "pest"
version = "2.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cbd939b234e95d72bc393d51788aec68aeeb5d51e748ca08ff3aad58cb722f7"
dependencies = [
"thiserror",
"ucd-trie",
]
[[package]]
name = "petgraph"
version = "0.6.3"
@ -1596,6 +1965,64 @@ version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
[[package]]
name = "plotters"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2538b639e642295546c50fcd545198c9d64ee2a38620a628724a3b266d5fbf97"
dependencies = [
"chrono",
"font-kit",
"image",
"lazy_static",
"num-traits",
"pathfinder_geometry",
"plotters-backend",
"plotters-bitmap",
"plotters-svg",
"ttf-parser",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "193228616381fecdc1224c62e96946dfbc73ff4384fba576e052ff8c1bea8142"
[[package]]
name = "plotters-bitmap"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4a1f21490a6cf4a84c272ad20bd7844ed99a3178187a4c5ab7f2051295beef"
dependencies = [
"gif",
"image",
"plotters-backend",
]
[[package]]
name = "plotters-svg"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9a81d2759aae1dae668f783c308bc5c8ebd191ff4184aaa1b37f65a6ae5a56f"
dependencies = [
"plotters-backend",
]
[[package]]
name = "png"
version = "0.17.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d708eaf860a19b19ce538740d2b4bdeeb8337fa53f7738455e706623ad5c638"
dependencies = [
"bitflags",
"crc32fast",
"flate2",
"miniz_oxide",
]
[[package]]
name = "portable-atomic"
version = "0.3.19"
@ -1920,6 +2347,15 @@ dependencies = [
"walkdir",
]
[[package]]
name = "rustc_version"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0dfe2087c51c460008730de8b57e6a320782fbfb312e1f4d520e6c6fae155ee"
dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "0.36.8"
@ -1970,6 +2406,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "scratch"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1792db035ce95be60c3f8853017b3999209281c24e2ba5bc8e59bf97a0c590c1"
[[package]]
name = "security-framework"
version = "2.8.2"
@ -1993,6 +2435,24 @@ dependencies = [
"libc",
]
[[package]]
name = "semver"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6"
dependencies = [
"semver-parser",
]
[[package]]
name = "semver-parser"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0bef5b7f9e0df16536d3961cfb6e84331c065b4066afb39768d0e319411f7"
dependencies = [
"pest",
]
[[package]]
name = "serde"
version = "1.0.152"
@ -2210,6 +2670,21 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "text-generation-benchmark"
version = "0.1.0"
dependencies = [
"clap 4.1.8",
"itertools 0.10.5",
"plotters",
"text-generation-client",
"thiserror",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "text-generation-client"
version = "0.4.1"
@ -2680,12 +3155,24 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
name = "ttf-parser"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b3e06c9b9d80ed6b745c7159c40b311ad2916abb34a49e9be2653b90db0d8dd"
[[package]]
name = "typenum"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba"
[[package]]
name = "ucd-trie"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e79c4d996edb816c91e4308506774452e55e95c3c9de07b6729e17e15a5ef81"
[[package]]
name = "unicase"
version = "2.6.0"
@ -2927,6 +3414,12 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "weezl"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb"
[[package]]
name = "which"
version = "4.4.0"
@ -3059,6 +3552,15 @@ dependencies = [
"winapi",
]
[[package]]
name = "wio"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d129932f4644ac2396cb456385cbf9e63b5b30c6e8dc4820bdca4eb082037a5"
dependencies = [
"winapi",
]
[[package]]
name = "xattr"
version = "0.2.3"
@ -3068,6 +3570,18 @@ dependencies = [
"libc",
]
[[package]]
name = "yeslogic-fontconfig-sys"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2bbd69036d397ebbff671b1b8e4d918610c181c5a16073b96f984a38d08c386"
dependencies = [
"const-cstr",
"dlib",
"once_cell",
"pkg-config",
]
[[package]]
name = "zip"
version = "0.5.13"

View File

@ -1,5 +1,6 @@
[workspace]
members = [
"benchmark",
"router",
"router/client",
"router/grpc-metadata",

View File

@ -8,7 +8,7 @@ environment_variables:
MODEL_ID: bigscience/bloom
NUM_SHARD: 8
environment:
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2.0
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.4.0
inference_config:
liveness_route:
port: 80
@ -18,21 +18,21 @@ environment:
path: /health
scoring_route:
port: 80
path: /generate
path: /
instance_type: Standard_ND96amsr_A100_v4
request_settings:
request_timeout_ms: 90000
request_timeout_ms: 60000
max_concurrent_requests_per_instance: 256
liveness_probe:
initial_delay: 600
timeout: 90
period: 120
initial_delay: 90
timeout: 20
period: 60
success_threshold: 1
failure_threshold: 5
failure_threshold: 2
readiness_probe:
initial_delay: 600
timeout: 90
period: 120
initial_delay: 90
timeout: 60
period: 60
success_threshold: 1
failure_threshold: 5
failure_threshold: 2
instance_count: 1

25
benchmark/Cargo.toml Normal file
View File

@ -0,0 +1,25 @@
[package]
name = "text-generation-benchmark"
version = "0.1.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Benchmarking tool"
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-bench"
path = "src/main.rs"
[dependencies]
clap = { version = "4.1.4", features = ["derive", "env"] }
itertools = "0.10.5"
plotters = "0.3.4"
text-generation-client = { path = "../router/client" }
thiserror = "1.0.38"
tokenizers = "0.13.2"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }

178
benchmark/src/lib.rs Normal file
View File

@ -0,0 +1,178 @@
use std::time::Duration;
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::time;
use text_generation_client::{ShardedClient, Request, Batch, StoppingCriteriaParameters, NextTokenChooserParameters};
use time::Instant;
use plotters::prelude::*;
use itertools::Itertools;
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
const OUT_FILE_NAME: &'static str = "errorbar.png";
enum Step {
Prefill,
Decode,
}
struct Run {
step: Step,
batch_size: u32,
sequence_length: u32,
decode_length: u32,
time: Duration,
}
pub async fn run(
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: Vec<u32>,
decode_length: Vec<u32>,
runs: u32,
mut client: ShardedClient,
) -> Result<(), Box<dyn std::error::Error>> {
// let prefill_runs = benchmark_prefill(&tokenizer, &batch_size, &sequence_length, &decode_length, runs, &mut client).await;
let mut runs: Vec<(f64, f64)> = Vec::new();
for i in 0..10{
for j in 0..10 {
runs.push((i as f64, j as f64));
}
}
let data = runs;
// let down_sampled = down_sample(&data[..]);
let root = BitMapBackend::new(OUT_FILE_NAME, (1024, 768)).into_drawing_area();
root.fill(&WHITE)?;
let mut chart = ChartBuilder::on(&root)
.caption("Linear Function with Noise", ("sans-serif", 60))
.margin(10)
.set_label_area_size(LabelAreaPosition::Left, 40)
.set_label_area_size(LabelAreaPosition::Bottom, 40)
.build_cartesian_2d(-10f64..10f64, -10f64..10f64)?;
chart.configure_mesh().draw()?;
chart
.draw_series(LineSeries::new(data, &GREEN.mix(0.3)))?
.label("Raw Data");
// chart.draw_series(LineSeries::new(
// down_sampled.iter().map(|(x, _, y, _)| (*x, *y)),
// &BLUE,
// ))?;
// chart
// .draw_series(
// down_sampled.iter().map(|(x, yl, ym, yh)| {
// ErrorBar::new_vertical(*x, *yl, *ym, *yh, BLUE.filled(), 20)
// }),
// )?
// .label("Down-sampled")
// .legend(|(x, y)| PathElement::new(vec![(x, y), (x, y)], &BLUE));
chart
.configure_series_labels()
.background_style(WHITE.filled())
.draw()?;
// To avoid the IO failure being ignored silently, we manually call the present function
root.present().expect("Unable to write result to file, please make sure 'plotters-doc-data' dir exists under current dir");
println!("Result has been saved to {}", OUT_FILE_NAME);
Ok(())
}
// fn down_sample(data: &[(f64, f64)]) -> Vec<(f64, f64, f64, f64)> {
// let down_sampled: Vec<_> = data
// .iter()
// .into_iter()
// .map(|(x, g)| {
// let mut g: Vec<_> = g.map(|(_, y)| *y).collect();
// g.sort_by(|a, b| a.partial_cmp(b).unwrap());
// (
// x,
// g[0],
// g.iter().sum::<f64>() / g.len() as f64,
// g[g.len() - 1],
// )
// })
// .collect();
// down_sampled
// }
async fn benchmark_prefill(tokenizer: &Tokenizer,
batch_size: &Vec<u32>,
sequence_length: &Vec<u32>,
decode_length: &Vec<u32>,
runs: u32,
client: &mut ShardedClient) -> Vec<Run> {
let mut results = Vec::new();
let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
for s in sequence_length {
let sequence = create_sequence(s, lorem_ipsum_length, tokenizer);
for b in batch_size {
for d in decode_length {
let requests = (0..*b).map(|id| {
Request {
id: id.into(),
inputs: sequence.clone(),
input_length: *s,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
watermark: false,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: *d,
stop_sequences: vec![],
ignore_eos_token: true,
}),
}
}).collect();
let batch = Batch {
id: 0,
requests,
size: *b,
};
for _ in 0..runs {
let start_time = Instant::now();
client.prefill(batch.clone()).await.unwrap();
let elasped = start_time.elapsed();
client.clear_cache().await.unwrap();
results.push(Run {
step: Step::Prefill,
batch_size: *b,
sequence_length: *s,
decode_length: *d,
time: elasped,
});
}
}
}
}
results
}
fn create_sequence(sequence_length: &u32, lorem_ipsum_length: usize, tokenizer: &Tokenizer) -> String {
// Repeat lorem ipsum to cover sequence length
let string_sequence = LOREM_IPSUM.repeat((0..*sequence_length).step_by(lorem_ipsum_length).len());
// Encode sequence
let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
// Truncate to sequence_length
encoding.truncate(*sequence_length as usize, 0, TruncationDirection::Left);
// Decode
tokenizer.decode(Vec::from(encoding.get_ids()), false).unwrap()
}

101
benchmark/src/main.rs Normal file
View File

@ -0,0 +1,101 @@
/// Text Generation Inference benchmarking tool
use std::path::Path;
use clap::Parser;
use tokenizers::Tokenizer;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use text_generation_client::ShardedClient;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(default_value = "1", long, env)]
batch_size: Vec<u32>,
#[clap(default_value = "128", long, env)]
sequence_length: Vec<u32>,
#[clap(default_value = "100", long, env)]
decode_length: Vec<u32>,
#[clap(default_value = "1", long, env)]
runs: u32,
#[clap(default_value = "/tmp/text-generation-0", long, env)]
master_shard_uds_path: String,
}
fn main() -> Result<(), std::io::Error> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
tokenizer_name,
batch_size,
sequence_length,
decode_length,
runs,
master_shard_uds_path,
} = args;
// Tokenizer instance
// This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
let tokenizer =
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
{
// Load local tokenizer
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
} else {
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
};
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
init_logging();
// Instantiate sharded client from the master unix socket
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.expect("Could not connect to server");
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache()
.await
.expect("Unable to clear cache");
tracing::info!("Connected");
text_generation_benchmark::run(
tokenizer,
batch_size,
sequence_length,
decode_length,
runs,
sharded_client,
).await;
});
Ok(())
}
/// Init logging using LOG_LEVEL
fn init_logging() {
// STDOUT/STDERR layer
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true);
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::registry()
.with(env_filter)
.with(fmt_layer)
.init();
}

View File

@ -53,6 +53,9 @@ message StoppingCriteriaParameters {
uint32 max_new_tokens = 1;
/// Optional stopping sequences
repeated string stop_sequences = 2;
/// Ignore end of sequence token
/// used for benchmarking
bool ignore_eos_token = 3;
}
message Request {

View File

@ -315,6 +315,7 @@ fn validate(
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
stop_sequences,
ignore_eos_token: false
};
metrics::histogram!("tgi_request_input_length", input_length as f64);

View File

@ -123,20 +123,22 @@ class StoppingCriteria:
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
max_new_tokens: int=20,
ignore_eos_token: bool = False,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
self.ignore_eos_token = ignore_eos_token
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
if not self.ignore_eos_token and last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
@ -156,5 +158,5 @@ class StoppingCriteria:
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens, pb.ignore_eos_token
)