Merge branch 'main' into gptq-cuda-kernels

Félix Marty 2023-07-19 16:58:54 +02:00
commit edfbfdfb3f
30 changed files with 968 additions and 563 deletions

Cargo.lock (generated)

@ -148,18 +148,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.68" version = "0.1.71"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" checksum = "a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
@ -410,9 +410,9 @@ dependencies = [
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.3.10" version = "4.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384e169cc618c613d5e3ca6404dda77a8685a63e08660dcc64abaf7da7cb0c7a" checksum = "1640e5cc7fb47dbb8338fd471b105e7ed6c3cb2aeb00c2e067127ffd3764a05d"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
"clap_derive", "clap_derive",
@ -421,9 +421,9 @@ dependencies = [
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.3.10" version = "4.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef137bbe35aab78bdb468ccfba75a5f4d8321ae011d34063770780545176af2d" checksum = "98c59138d527eeaf9b53f35a77fcc1fad9d883116070c63d5de1c7dc7b00c72b"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
@ -440,7 +440,7 @@ dependencies = [
"heck", "heck",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
@ -492,9 +492,9 @@ checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]] [[package]]
name = "cpufeatures" name = "cpufeatures"
version = "0.2.8" version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03e69e28e9f7f77debdedbaafa2866e1de9ba56df55a8bd7cfc724c25a09987c" checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1"
dependencies = [ dependencies = [
"libc", "libc",
] ]
@ -633,12 +633,12 @@ dependencies = [
[[package]] [[package]]
name = "dashmap" name = "dashmap"
version = "5.4.0" version = "5.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" checksum = "6943ae99c34386c84a470c499d3414f66502a41340aa895406e0d2e4a207b91d"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"hashbrown 0.12.3", "hashbrown 0.14.0",
"lock_api", "lock_api",
"once_cell", "once_cell",
"parking_lot_core", "parking_lot_core",
@ -736,6 +736,12 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "equivalent"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]] [[package]]
name = "errno" name = "errno"
version = "0.3.1" version = "0.3.1"
@ -924,7 +930,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
@ -1023,7 +1029,7 @@ dependencies = [
"futures-sink", "futures-sink",
"futures-util", "futures-util",
"http", "http",
"indexmap", "indexmap 1.9.3",
"slab", "slab",
"tokio", "tokio",
"tokio-util", "tokio-util",
@ -1038,13 +1044,19 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.13.2" version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" checksum = "33ff8ae62cd3a9102e5637afc8452c55acf3844001bd5374e0b0bd7b6616c038"
dependencies = [ dependencies = [
"ahash", "ahash",
] ]
[[package]]
name = "hashbrown"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
[[package]] [[package]]
name = "heck" name = "heck"
version = "0.4.1" version = "0.4.1"
@ -1053,9 +1065,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]] [[package]]
name = "hermit-abi" name = "hermit-abi"
version = "0.3.1" version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b"
[[package]] [[package]]
name = "hmac" name = "hmac"
@ -1190,6 +1202,16 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"hashbrown 0.12.3", "hashbrown 0.12.3",
]
[[package]]
name = "indexmap"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d"
dependencies = [
"equivalent",
"hashbrown 0.14.0",
"serde", "serde",
] ]
@ -1254,12 +1276,12 @@ checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6"
[[package]] [[package]]
name = "is-terminal" name = "is-terminal"
version = "0.4.8" version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24fddda5af7e54bf7da53067d6e802dbcc381d0a8eef629df528e3ebf68755cb" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b"
dependencies = [ dependencies = [
"hermit-abi", "hermit-abi",
"rustix 0.38.1", "rustix 0.38.4",
"windows-sys 0.48.0", "windows-sys 0.48.0",
] ]
@ -1292,9 +1314,9 @@ dependencies = [
[[package]] [[package]]
name = "itoa" name = "itoa"
version = "1.0.6" version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a"
[[package]] [[package]]
name = "jobserver" name = "jobserver"
@ -1397,7 +1419,7 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [ dependencies = [
"regex-automata", "regex-automata 0.1.10",
] ]
[[package]] [[package]]
@ -1432,9 +1454,9 @@ dependencies = [
[[package]] [[package]]
name = "metrics" name = "metrics"
version = "0.21.0" version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa8ebbd1a9e57bbab77b9facae7f5136aea44c356943bf9a198f647da64285d6" checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5"
dependencies = [ dependencies = [
"ahash", "ahash",
"metrics-macros", "metrics-macros",
@ -1449,7 +1471,7 @@ checksum = "8a4964177ddfdab1e3a2b37aec7cf320e14169abb0ed73999f558136409178d5"
dependencies = [ dependencies = [
"base64 0.21.2", "base64 0.21.2",
"hyper", "hyper",
"indexmap", "indexmap 1.9.3",
"ipnet", "ipnet",
"metrics", "metrics",
"metrics-util", "metrics-util",
@ -1467,18 +1489,18 @@ checksum = "ddece26afd34c31585c74a4db0630c376df271c285d682d1e55012197830b6df"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
name = "metrics-util" name = "metrics-util"
version = "0.15.0" version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "111cb375987443c3de8d503580b536f77dc8416d32db62d9456db5d93bd7ac47" checksum = "4de2ed6e491ed114b40b732e4d1659a9d53992ebd87490c44a6ffe23739d973e"
dependencies = [ dependencies = [
"crossbeam-epoch", "crossbeam-epoch",
"crossbeam-utils", "crossbeam-utils",
"hashbrown 0.13.2", "hashbrown 0.13.1",
"metrics", "metrics",
"num_cpus", "num_cpus",
"quanta", "quanta",
@ -1530,9 +1552,9 @@ dependencies = [
[[package]] [[package]]
name = "monostate" name = "monostate"
version = "0.1.6" version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a" checksum = "3f3f57a8802842f648026a33c3d2e3bb41bb309a35b1609bd7ef2b060b8b6b1b"
dependencies = [ dependencies = [
"monostate-impl", "monostate-impl",
"serde", "serde",
@ -1540,13 +1562,13 @@ dependencies = [
[[package]] [[package]]
name = "monostate-impl" name = "monostate-impl"
version = "0.1.6" version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b" checksum = "e72f4d2e10fde62a0f2fcb4b44ccbf4f9899dcc30c9193449f8dfb9123d71377"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
@ -1701,6 +1723,15 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "num_threads"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "number_prefix" name = "number_prefix"
version = "0.3.0" version = "0.3.0"
@ -1773,7 +1804,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
@ -1854,7 +1885,7 @@ dependencies = [
"fnv", "fnv",
"futures-channel", "futures-channel",
"futures-util", "futures-util",
"indexmap", "indexmap 1.9.3",
"js-sys", "js-sys",
"once_cell", "once_cell",
"pin-project-lite", "pin-project-lite",
@ -1870,7 +1901,7 @@ dependencies = [
"fnv", "fnv",
"futures-channel", "futures-channel",
"futures-util", "futures-util",
"indexmap", "indexmap 1.9.3",
"once_cell", "once_cell",
"pin-project-lite", "pin-project-lite",
"thiserror", "thiserror",
@ -1974,9 +2005,9 @@ dependencies = [
[[package]] [[package]]
name = "paste" name = "paste"
version = "1.0.12" version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" checksum = "b4b27ab7be369122c218afc2079489cdcb4b517c0a3fc386ff11e1fedfcc2b35"
[[package]] [[package]]
name = "pbkdf2" name = "pbkdf2"
@ -2003,34 +2034,34 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4"
dependencies = [ dependencies = [
"fixedbitset", "fixedbitset",
"indexmap", "indexmap 1.9.3",
] ]
[[package]] [[package]]
name = "pin-project" name = "pin-project"
version = "1.1.1" version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e138fdd8263907a2b0e1b4e80b7e58c721126479b6e6eedfb1b402acea7b9bd" checksum = "030ad2bc4db10a8944cb0d837f158bdfec4d4a4873ab701a95046770d11f8842"
dependencies = [ dependencies = [
"pin-project-internal", "pin-project-internal",
] ]
[[package]] [[package]]
name = "pin-project-internal" name = "pin-project-internal"
version = "1.1.1" version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1fef411b303e3e12d534fb6e7852de82da56edd937d895125821fb7c09436c7" checksum = "ec2e072ecce94ec471b13398d5402c188e76ac03cf74dd1a975161b23a3f6d9c"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.9" version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" checksum = "4c40d25201921e5ff0c862a505c6557ea88568a4e3ace775ab55e93f2f4f9d57"
[[package]] [[package]]
name = "pin-utils" name = "pin-utils"
@ -2046,9 +2077,9 @@ checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964"
[[package]] [[package]]
name = "portable-atomic" name = "portable-atomic"
version = "1.3.3" version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "767eb9f07d4a5ebcb39bbf2d452058a93c011373abf6832e24194a1c3f004794" checksum = "d220334a184db82b31b83f5ff093e3315280fb2b6bbc032022b2304a509aab7a"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
@ -2092,9 +2123,9 @@ dependencies = [
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.63" version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb" checksum = "78803b62cbf1f46fde80d7c0e803111524b9877184cfe7c3033659490ac7a7da"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
@ -2294,13 +2325,14 @@ dependencies = [
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.8.4" version = "1.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" checksum = "b2eae68fc220f7cf2532e4494aded17545fce192d59cd996e0fe7887f4ceb575"
dependencies = [ dependencies = [
"aho-corasick 1.0.2", "aho-corasick 1.0.2",
"memchr", "memchr",
"regex-syntax 0.7.2", "regex-automata 0.3.3",
"regex-syntax 0.7.4",
] ]
[[package]] [[package]]
@ -2312,6 +2344,17 @@ dependencies = [
"regex-syntax 0.6.29", "regex-syntax 0.6.29",
] ]
[[package]]
name = "regex-automata"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39354c10dd07468c2e73926b23bb9c2caca74c5501e38a35da70406f1d923310"
dependencies = [
"aho-corasick 1.0.2",
"memchr",
"regex-syntax 0.7.4",
]
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.6.29" version = "0.6.29"
@ -2320,9 +2363,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.7.2" version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2"
[[package]] [[package]]
name = "reqwest" name = "reqwest"
@ -2397,7 +2440,7 @@ dependencies = [
"quote", "quote",
"rust-embed-utils", "rust-embed-utils",
"shellexpand", "shellexpand",
"syn 2.0.22", "syn 2.0.25",
"walkdir", "walkdir",
] ]
@ -2428,9 +2471,9 @@ dependencies = [
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.37.21" version = "0.37.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62f25693a73057a1b4cb56179dd3c7ea21a7c6c5ee7d85781f5749b46f34b79c" checksum = "4d69718bf81c6127a49dc64e44a742e8bb9213c0ff8869a22c308f84c1d4ab06"
dependencies = [ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
"errno", "errno",
@ -2442,9 +2485,9 @@ dependencies = [
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.38.1" version = "0.38.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc6396159432b5c8490d4e301d8c705f61860b8b6c863bf79942ce5401968f3" checksum = "0a962918ea88d644592894bc6dc55acc6c0956488adcebbfb6e273506b7fd6e5"
dependencies = [ dependencies = [
"bitflags 2.3.3", "bitflags 2.3.3",
"errno", "errno",
@ -2476,15 +2519,15 @@ dependencies = [
[[package]] [[package]]
name = "rustversion" name = "rustversion"
version = "1.0.12" version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06" checksum = "dc31bd9b61a32c31f9650d18add92aa83a49ba979c143eefd27fe7177b05bd5f"
[[package]] [[package]]
name = "ryu" name = "ryu"
version = "1.0.13" version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" checksum = "fe232bdf6be8c8de797b22184ee71118d63780ea42ac85b61d1baa6d3b782ae9"
[[package]] [[package]]
name = "same-file" name = "same-file"
@ -2497,11 +2540,11 @@ dependencies = [
[[package]] [[package]]
name = "schannel" name = "schannel"
version = "0.1.21" version = "0.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88"
dependencies = [ dependencies = [
"windows-sys 0.42.0", "windows-sys 0.48.0",
] ]
[[package]] [[package]]
@ -2551,29 +2594,29 @@ checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.164" version = "1.0.171"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" checksum = "30e27d1e4fd7659406c492fd6cfaf2066ba8773de45ca75e855590f856dc34a9"
dependencies = [ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.164" version = "1.0.171"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" checksum = "389894603bd18c46fa56231694f8d827779c0951a667087194cf9de94ed24682"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.99" version = "1.0.102"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3" checksum = "b5062a995d481b2308b6064e9af76011f2921c35f97b0468811ed9f6cd91dfed"
dependencies = [ dependencies = [
"itoa", "itoa",
"ryu", "ryu",
@ -2582,10 +2625,11 @@ dependencies = [
[[package]] [[package]]
name = "serde_path_to_error" name = "serde_path_to_error"
version = "0.1.11" version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7f05c1d5476066defcdfacce1f52fc3cae3af1d3089727100c02ae92e5abbe0" checksum = "8acc4422959dd87a76cb117c191dcbffc20467f06c9100b76721dab370f24d3a"
dependencies = [ dependencies = [
"itoa",
"serde", "serde",
] ]
@ -2697,9 +2741,9 @@ dependencies = [
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.10.0" version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
[[package]] [[package]]
name = "socket2" name = "socket2"
@ -2769,9 +2813,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.22" version = "2.0.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2efbeae7acf4eabd6bcdcbd11c92f45231ddda7539edc7806bd1a04a03b24616" checksum = "15e3fc8c0c74267e2df136e5e5fb656a464158aa57624053375eb9c8c6e25ae2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -2786,9 +2830,9 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
[[package]] [[package]]
name = "sysinfo" name = "sysinfo"
version = "0.29.3" version = "0.29.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bcd0346f90b6bc83526c7b180039a8acd26a5c848cc556d457f6472eb148122" checksum = "751e810399bba86e9326f5762b7f32ac5a085542df78da6a78d94e07d14d7c11"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"core-foundation-sys", "core-foundation-sys",
@ -2824,9 +2868,9 @@ dependencies = [
[[package]] [[package]]
name = "tar" name = "tar"
version = "0.4.38" version = "0.4.39"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b55807c0344e1e6c04d7c965f5289c39a8d94ae23ed5c0b57aabac549f871c6" checksum = "ec96d2ffad078296368d46ff1cb309be1c23c513b4ab0e22a45de0185275ac96"
dependencies = [ dependencies = [
"filetime", "filetime",
"libc", "libc",
@ -2843,13 +2887,13 @@ dependencies = [
"cfg-if", "cfg-if",
"fastrand", "fastrand",
"redox_syscall 0.3.5", "redox_syscall 0.3.5",
"rustix 0.37.21", "rustix 0.37.23",
"windows-sys 0.48.0", "windows-sys 0.48.0",
] ]
[[package]] [[package]]
name = "text-generation-benchmark" name = "text-generation-benchmark"
version = "0.9.1" version = "0.9.3"
dependencies = [ dependencies = [
"average", "average",
"clap", "clap",
@ -2869,7 +2913,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "0.9.1" version = "0.9.3"
dependencies = [ dependencies = [
"futures", "futures",
"grpc-metadata", "grpc-metadata",
@ -2885,7 +2929,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "0.9.1" version = "0.9.3"
dependencies = [ dependencies = [
"clap", "clap",
"ctrlc", "ctrlc",
@ -2901,7 +2945,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "0.9.1" version = "0.9.3"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"axum", "axum",
@ -2934,22 +2978,22 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.40" version = "1.0.43"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" checksum = "a35fc5b8971143ca348fa6df4f024d4d55264f3468c71ad1c2f365b0a4d58c42"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.40" version = "1.0.43"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
@ -2964,11 +3008,13 @@ dependencies = [
[[package]] [[package]]
name = "time" name = "time"
version = "0.3.22" version = "0.3.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd" checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446"
dependencies = [ dependencies = [
"itoa", "itoa",
"libc",
"num_threads",
"serde", "serde",
"time-core", "time-core",
"time-macros", "time-macros",
@ -2982,9 +3028,9 @@ checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
[[package]] [[package]]
name = "time-macros" name = "time-macros"
version = "0.2.9" version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b" checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4"
dependencies = [ dependencies = [
"time-core", "time-core",
] ]
@ -3078,7 +3124,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
@ -3209,7 +3255,7 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-util", "futures-util",
"indexmap", "indexmap 1.9.3",
"pin-project", "pin-project",
"pin-project-lite", "pin-project-lite",
"rand", "rand",
@ -3291,7 +3337,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
] ]
[[package]] [[package]]
@ -3413,9 +3459,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.9" version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" checksum = "22049a19f4a68748a168c0fc439f9516686aa045927ff767eca0a85101fb6e73"
[[package]] [[package]]
name = "unicode-normalization" name = "unicode-normalization"
@ -3484,11 +3530,11 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]] [[package]]
name = "utoipa" name = "utoipa"
version = "3.3.0" version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68ae74ef183fae36d650f063ae7bde1cacbe1cd7e72b617cbe1e985551878b98" checksum = "520434cac5c98120177d5cc15be032703f6dca7d5ef82e725c798113b375000a"
dependencies = [ dependencies = [
"indexmap", "indexmap 2.0.0",
"serde", "serde",
"serde_json", "serde_json",
"utoipa-gen", "utoipa-gen",
@ -3496,21 +3542,22 @@ dependencies = [
[[package]] [[package]]
name = "utoipa-gen" name = "utoipa-gen"
version = "3.3.0" version = "3.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ea8ac818da7e746a63285594cce8a96f5e00ee31994e655bd827569cb8b137b" checksum = "6e22e88a487b6e0374533871b79b1f5ded05671bd0936bd547eb42f82fb9060d"
dependencies = [ dependencies = [
"proc-macro-error", "proc-macro-error",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "regex",
"syn 2.0.25",
] ]
[[package]] [[package]]
name = "utoipa-swagger-ui" name = "utoipa-swagger-ui"
version = "3.1.3" version = "3.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "062bba5a3568e126ac72049a63254f4cb1da2eb713db0c1ab2a4c76be191db8c" checksum = "4602d7100d3cfd8a086f30494e68532402ab662fa366c9d201d677e33cee138d"
dependencies = [ dependencies = [
"axum", "axum",
"mime_guess", "mime_guess",
@ -3536,9 +3583,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]] [[package]]
name = "vergen" name = "vergen"
version = "8.2.1" version = "8.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b3c89c2c7e50f33e4d35527e5bf9c11d6d132226dbbd1753f0fbe9f19ef88c6" checksum = "bbc5ad0d9d26b2c49a5ab7da76c3e79d3ee37e7821799f8223fcb8f2f391a2e7"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"rustc_version", "rustc_version",
@ -3599,7 +3646,7 @@ dependencies = [
"once_cell", "once_cell",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -3633,7 +3680,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.22", "syn 2.0.25",
"wasm-bindgen-backend", "wasm-bindgen-backend",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -3706,21 +3753,6 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
dependencies = [
"windows_aarch64_gnullvm 0.42.2",
"windows_aarch64_msvc 0.42.2",
"windows_i686_gnu 0.42.2",
"windows_i686_msvc 0.42.2",
"windows_x86_64_gnu 0.42.2",
"windows_x86_64_gnullvm 0.42.2",
"windows_x86_64_msvc 0.42.2",
]
[[package]] [[package]]
name = "windows-sys" name = "windows-sys"
version = "0.45.0" version = "0.45.0"

Cargo.toml

@ -8,7 +8,7 @@ members = [
] ]
[workspace.package] [workspace.package]
version = "0.9.1" version = "0.9.3"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference" homepage = "https://github.com/huggingface/text-generation-inference"

Dockerfile

@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention # Build specific version of flash attention
RUN make build-flash-attention RUN make build-flash-attention
# Build Flash Attention v2 CUDA kernels
FROM kernel-builder as flash-att-v2-builder
WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2
RUN make build-flash-attention-v2
# Build Transformers CUDA kernels # Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder FROM kernel-builder as custom-kernels-builder
@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from custom kernels builder # Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy builds artifacts from vllm builder # Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

README.md

@ -1,5 +1,7 @@
<div align="center"> <div align="center">
![image](https://github.com/huggingface/text-generation-inference/assets/3841370/38ba1531-ea0d-4851-b31a-a6d4ddc944b0)
# Text Generation Inference # Text Generation Inference
<a href="https://github.com/huggingface/text-generation-inference"> <a href="https://github.com/huggingface/text-generation-inference">
@ -11,9 +13,6 @@
<a href="https://huggingface.github.io/text-generation-inference"> <a href="https://huggingface.github.io/text-generation-inference">
<img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational"> <img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational">
</a> </a>
![architecture](assets/architecture.jpg)
</div> </div>
A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co) A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
@ -64,6 +63,8 @@ to power LLMs api-inference widgets.
- [Starcoder](https://huggingface.co/bigcode/starcoder) - [Starcoder](https://huggingface.co/bigcode/starcoder)
- [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b) - [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b)
- [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b) - [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
- [Llama V2](https://huggingface.co/meta-llama)
Other architectures are supported on a best effort basis using: Other architectures are supported on a best effort basis using:
@ -133,6 +134,10 @@ print(text)
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference). The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
### Using private or gated models
You can use the `HUGGING_FACE_HUB_TOKEN` environment variable to set the token `text-generation-inference` uses to access protected resources.
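Elsewhere in this commit, the launcher forwards such a token to its child processes by copying `HF_API_TOKEN` into `HUGGING_FACE_HUB_TOKEN`. The snippet below is a minimal Rust sketch of that pattern, not the project's actual code; `some-worker` is a hypothetical placeholder command.

```rust
use std::env;
use std::ffi::OsString;
use std::process::Command;

// Minimal sketch: copy the parent environment and map HF_API_TOKEN to
// HUGGING_FACE_HUB_TOKEN for a child process. `some-worker` is a
// hypothetical placeholder, not a binary from this repository.
fn main() {
    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
    if let Ok(api_token) = env::var("HF_API_TOKEN") {
        envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()));
    }
    let status = Command::new("some-worker")
        .envs(envs)
        .status()
        .expect("failed to spawn worker");
    println!("worker exited with {status}");
}
```

In the launcher itself this mapping happens once for the download process and once per shard, as the hunks further down show.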
### Distributed Tracing ### Distributed Tracing
`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature `text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
@ -212,7 +217,7 @@ sudo apt-get install libssl-dev gcc -y
### CUDA Kernels ### CUDA Kernels
The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
the kernels by using the `BUILD_EXTENSIONS=False` environment variable. the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.
Be aware that the official Docker image has them enabled by default. Be aware that the official Docker image has them enabled by default.

docs/openapi.json

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "0.9.1" "version": "0.9.3"
}, },
"paths": { "paths": {
"/": { "/": {

launcher/Cargo.toml

@ -13,7 +13,7 @@ nix = "0.26.2"
serde = { version = "1.0.152", features = ["derive"] } serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93" serde_json = "1.0.93"
tracing = "0.1.37" tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json"] } tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
[dev-dependencies] [dev-dependencies]
float_eq = "1.0.1" float_eq = "1.0.1"

launcher/src/main.rs

@ -4,10 +4,10 @@ use nix::unistd::Pid;
use serde::Deserialize; use serde::Deserialize;
use std::env; use std::env;
use std::ffi::OsString; use std::ffi::OsString;
use std::io::{BufRead, BufReader, Read}; use std::io::{BufRead, BufReader, Lines, Read};
use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path; use std::path::Path;
use std::process::{Child, Command, Stdio}; use std::process::{Child, Command, ExitStatus, Stdio};
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError; use std::sync::mpsc::TryRecvError;
use std::sync::{mpsc, Arc}; use std::sync::{mpsc, Arc};
@ -15,6 +15,7 @@ use std::thread;
use std::thread::sleep; use std::thread::sleep;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{fs, io}; use std::{fs, io};
use tracing_subscriber::EnvFilter;
mod env_runtime; mod env_runtime;
@ -41,6 +42,7 @@ impl std::fmt::Display for Quantization {
#[derive(Clone, Copy, Debug, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype { enum Dtype {
Float16, Float16,
#[clap(name = "bfloat16")]
BFloat16, BFloat16,
} }
@ -182,8 +184,8 @@ struct Args {
/// depends on other parameters like if you're using quantization, flash attention /// depends on other parameters like if you're using quantization, flash attention
/// or the model implementation, text-generation-inference cannot infer this number /// or the model implementation, text-generation-inference cannot infer this number
/// automatically. /// automatically.
#[clap(default_value = "16000", long, env)] #[clap(long, env)]
max_batch_total_tokens: u32, max_batch_total_tokens: Option<u32>,
/// This setting defines how many tokens can be passed before forcing the waiting /// This setting defines how many tokens can be passed before forcing the waiting
/// queries to be put on the batch (if the size of the batch allows for it). /// queries to be put on the batch (if the size of the batch allows for it).
@ -265,17 +267,9 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
ngrok_authtoken: Option<String>, ngrok_authtoken: Option<String>,
/// ngrok domain name where the axum webserver will be available at /// ngrok edge
#[clap(long, env)] #[clap(long, env)]
ngrok_domain: Option<String>, ngrok_edge: Option<String>,
/// ngrok basic auth username
#[clap(long, env)]
ngrok_username: Option<String>,
/// ngrok basic auth password
#[clap(long, env)]
ngrok_password: Option<String>,
/// Display a lot of information about your runtime environment /// Display a lot of information about your runtime environment
#[clap(long, short, action)] #[clap(long, short, action)]
@ -285,7 +279,7 @@ struct Args {
#[derive(Debug)] #[derive(Debug)]
enum ShardStatus { enum ShardStatus {
Ready, Ready,
Failed((usize, Option<String>)), Failed(usize),
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
@ -310,6 +304,9 @@ fn shard_manager(
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
_shutdown_sender: mpsc::Sender<()>, _shutdown_sender: mpsc::Sender<()>,
) { ) {
// Enter shard-manager tracing span
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
// Get UDS path // Get UDS path
let uds_string = format!("{uds_path}-{rank}"); let uds_string = format!("{uds_path}-{rank}");
let uds = Path::new(&uds_string); let uds = Path::new(&uds_string);
@ -319,7 +316,7 @@ fn shard_manager(
} }
// Process args // Process args
let mut shard_argv = vec![ let mut shard_args = vec![
"serve".to_string(), "serve".to_string(),
model_id, model_id,
"--uds-path".to_string(), "--uds-path".to_string(),
@ -331,77 +328,71 @@ fn shard_manager(
// Activate trust remote code // Activate trust remote code
if trust_remote_code { if trust_remote_code {
shard_argv.push("--trust-remote-code".to_string()); shard_args.push("--trust-remote-code".to_string());
} }
// Activate tensor parallelism // Activate tensor parallelism
if world_size > 1 { if world_size > 1 {
shard_argv.push("--sharded".to_string()); shard_args.push("--sharded".to_string());
} }
if let Some(quantize) = quantize { if let Some(quantize) = quantize {
shard_argv.push("--quantize".to_string()); shard_args.push("--quantize".to_string());
shard_argv.push(quantize.to_string()) shard_args.push(quantize.to_string())
} }
if let Some(dtype) = dtype { if let Some(dtype) = dtype {
shard_argv.push("--dtype".to_string()); shard_args.push("--dtype".to_string());
shard_argv.push(dtype.to_string()) shard_args.push(dtype.to_string())
} }
// Model optional revision // Model optional revision
if let Some(revision) = revision { if let Some(revision) = revision {
shard_argv.push("--revision".to_string()); shard_args.push("--revision".to_string());
shard_argv.push(revision) shard_args.push(revision)
} }
// OpenTelemetry // OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint { if let Some(otlp_endpoint) = otlp_endpoint {
shard_argv.push("--otlp-endpoint".to_string()); shard_args.push("--otlp-endpoint".to_string());
shard_argv.push(otlp_endpoint); shard_args.push(otlp_endpoint);
} }
// Copy current process env // Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Use cuda allocator. It leads to less memory fragmentation
env.push((
"PYTORCH_CUDA_ALLOC_CONF".into(),
"backend:cudaMallocAsync".into(),
));
// Torch Distributed Env vars // Torch Distributed Env vars
env.push(("RANK".into(), rank.to_string().into())); envs.push(("RANK".into(), rank.to_string().into()));
env.push(("WORLD_SIZE".into(), world_size.to_string().into())); envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
env.push(("MASTER_ADDR".into(), master_addr.into())); envs.push(("MASTER_ADDR".into(), master_addr.into()));
env.push(("MASTER_PORT".into(), master_port.to_string().into())); envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into())); envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
// Safetensors load fast // Safetensors load fast
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// Enable hf transfer for insane download speeds // Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
env.push(( envs.push((
"HF_HUB_ENABLE_HF_TRANSFER".into(), "HF_HUB_ENABLE_HF_TRANSFER".into(),
enable_hf_transfer.into(), enable_hf_transfer.into(),
)); ));
// Parse Inference API token // Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") { if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
}; };
// If huggingface_hub_cache is some, pass it to the shard // If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container // Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache { if let Some(huggingface_hub_cache) = huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
}; };
// If weights_cache_override is some, pass it to the shard // If weights_cache_override is some, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint // Useful when running inside a HuggingFace Inference Endpoint
if let Some(weights_cache_override) = weights_cache_override { if let Some(weights_cache_override) = weights_cache_override {
env.push(( envs.push((
"WEIGHTS_CACHE_OVERRIDE".into(), "WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(), weights_cache_override.into(),
)); ));
@ -409,24 +400,24 @@ fn shard_manager(
// If disable_custom_kernels is true, pass it to the shard as an env var // If disable_custom_kernels is true, pass it to the shard as an env var
if disable_custom_kernels { if disable_custom_kernels {
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into())) envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
} }
// Watermark Gamma // Watermark Gamma
if let Some(watermark_gamma) = watermark_gamma { if let Some(watermark_gamma) = watermark_gamma {
env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into())) envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
} }
// Watermark Delta // Watermark Delta
if let Some(watermark_delta) = watermark_delta { if let Some(watermark_delta) = watermark_delta {
env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into())) envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
} }
// Start process // Start process
tracing::info!("Starting shard {rank}"); tracing::info!("Starting shard");
let mut p = match Command::new("text-generation-server") let mut p = match Command::new("text-generation-server")
.args(shard_argv) .args(shard_args)
.envs(env) .envs(envs)
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
.process_group(0) .process_group(0)
@ -437,30 +428,23 @@ fn shard_manager(
if err.kind() == io::ErrorKind::NotFound { if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH"); tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`") tracing::error!("Please install it with `make install-server`")
} else { }
{
tracing::error!("{}", err); tracing::error!("{}", err);
} }
status_sender status_sender.send(ShardStatus::Failed(rank)).unwrap();
.send(ShardStatus::Failed((rank, Some(err.to_string()))))
.unwrap();
return; return;
} }
}; };
// Redirect STDOUT to the console // Redirect STDOUT to the console
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
//stdout tracing thread
thread::spawn(move || { thread::spawn(move || {
// Enter shard-manager tracing span log_lines(shard_stdout_reader.lines());
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
for line in shard_stdout_reader.lines() {
// Parse loguru logs
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace();
}
}
}); });
let mut ready = false; let mut ready = false;
@ -469,30 +453,25 @@ fn shard_manager(
loop { loop {
// Process exited // Process exited
if let Some(exit_status) = p.try_wait().unwrap() { if let Some(exit_status) = p.try_wait().unwrap() {
// We read stderr in another thread as it seems that `read_to_string` can block // We read stderr in another thread as it seems that lines() can block in some cases
// indefinitely in some cases
let (err_sender, err_receiver) = mpsc::channel(); let (err_sender, err_receiver) = mpsc::channel();
thread::spawn(move || { thread::spawn(move || {
let mut err = String::new(); for line in shard_stderr_reader.lines().flatten() {
shard_stderr_reader.read_to_string(&mut err).unwrap(); err_sender.send(line).unwrap_or(());
err_sender.send(err).unwrap_or(()); }
}); });
let mut err = String::new();
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
err = err + "\n" + &line;
}
let err = err_receiver tracing::error!("Shard complete standard error output:\n{err}");
.recv_timeout(Duration::from_millis(100))
.map_err(|err| {
tracing::error!("Unable to read shard {rank} error from stderr");
err
})
.ok();
if let Some(signal) = exit_status.signal() { if let Some(signal) = exit_status.signal() {
tracing::error!("Shard process was signaled to shutdown with signal {signal}"); tracing::error!("Shard process was signaled to shutdown with signal {signal}");
} }
status_sender status_sender.send(ShardStatus::Failed(rank)).unwrap();
.send(ShardStatus::Failed((rank, err)))
.unwrap();
return; return;
} }
@ -500,17 +479,17 @@ fn shard_manager(
if shutdown.load(Ordering::SeqCst) { if shutdown.load(Ordering::SeqCst) {
p.kill().unwrap(); p.kill().unwrap();
let _ = p.wait(); let _ = p.wait();
tracing::info!("Shard {rank} terminated"); tracing::info!("Shard terminated");
return; return;
} }
// Shard is ready // Shard is ready
if uds.exists() && !ready { if uds.exists() && !ready {
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed()); tracing::info!("Shard ready in {:?}", start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap(); status_sender.send(ShardStatus::Ready).unwrap();
ready = true; ready = true;
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) { } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {rank} to be ready..."); tracing::info!("Waiting for shard to be ready...");
wait_time = Instant::now(); wait_time = Instant::now();
} }
sleep(Duration::from_millis(100)); sleep(Duration::from_millis(100));
@ -579,6 +558,23 @@ impl PythonLogMessage {
} }
} }
impl TryFrom<&String> for PythonLogMessage {
type Error = serde_json::Error;
fn try_from(value: &String) -> Result<Self, Self::Error> {
serde_json::from_str::<Self>(value)
}
}
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
for line in lines.flatten() {
match PythonLogMessage::try_from(&line) {
Ok(log) => log.trace(),
Err(_) => tracing::debug!("{line}"),
}
}
}
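The new `log_lines` helper above tries to parse every line a child process prints as a structured loguru-style JSON record and falls back to plain logging otherwise. Below is a self-contained sketch of the same pattern; the `PythonLogMessage` fields used here (`text`, `log_level`) are assumptions for illustration, since the real struct is defined outside this hunk, and `println!` stands in for the `tracing` calls.

```rust
use std::io::{BufRead, BufReader, Cursor};
use serde::Deserialize;

// Assumed shape of a loguru-style JSON record; the launcher's actual struct
// may differ. This is illustrative only.
#[derive(Deserialize)]
struct PythonLogMessage {
    text: String,
    log_level: String,
}

fn log_lines<R: BufRead>(reader: R) {
    for line in reader.lines().flatten() {
        match serde_json::from_str::<PythonLogMessage>(&line) {
            // Structured record: re-emit it with its level (simplified here).
            Ok(log) => println!("[{}] {}", log.log_level, log.text.trim_end()),
            // Anything that is not valid JSON is passed through as-is.
            Err(_) => println!("{line}"),
        }
    }
}

fn main() {
    let fake_stdout = Cursor::new(
        "{\"text\":\"Server started\",\"log_level\":\"INFO\"}\nplain, unstructured line\n",
    );
    log_lines(BufReader::new(fake_stdout));
}
```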
fn find_num_shards( fn find_num_shards(
sharded: Option<bool>, sharded: Option<bool>,
num_shard: Option<usize>, num_shard: Option<usize>,
@ -632,7 +628,10 @@ enum LauncherError {
} }
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> { fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
let mut download_argv = vec![ // Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
let mut download_args = vec![
"download-weights".to_string(), "download-weights".to_string(),
args.model_id.to_string(), args.model_id.to_string(),
"--extension".to_string(), "--extension".to_string(),
@ -644,35 +643,35 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Model optional revision // Model optional revision
if let Some(revision) = &args.revision { if let Some(revision) = &args.revision {
download_argv.push("--revision".to_string()); download_args.push("--revision".to_string());
download_argv.push(revision.to_string()) download_args.push(revision.to_string())
} }
// Copy current process env // Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// If huggingface_hub_cache is set, pass it to the download process // If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container // Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
}; };
// Enable hf transfer for insane download speeds // Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
env.push(( envs.push((
"HF_HUB_ENABLE_HF_TRANSFER".into(), "HF_HUB_ENABLE_HF_TRANSFER".into(),
enable_hf_transfer.into(), enable_hf_transfer.into(),
)); ));
// Parse Inference API token // Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") { if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
}; };
// If args.weights_cache_override is some, pass it to the download process // If args.weights_cache_override is some, pass it to the download process
// Useful when running inside a HuggingFace Inference Endpoint // Useful when running inside a HuggingFace Inference Endpoint
if let Some(weights_cache_override) = &args.weights_cache_override { if let Some(weights_cache_override) = &args.weights_cache_override {
env.push(( envs.push((
"WEIGHTS_CACHE_OVERRIDE".into(), "WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(), weights_cache_override.into(),
)); ));
@ -681,8 +680,8 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Start process // Start process
tracing::info!("Starting download process."); tracing::info!("Starting download process.");
let mut download_process = match Command::new("text-generation-server") let mut download_process = match Command::new("text-generation-server")
.args(download_argv) .args(download_args)
.envs(env) .envs(envs)
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
.process_group(0) .process_group(0)
@ -693,6 +692,8 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
if err.kind() == io::ErrorKind::NotFound { if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH"); tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`") tracing::error!("Please install it with `make install-server`")
} else {
tracing::error!("{}", err);
} }
return Err(LauncherError::DownloadError); return Err(LauncherError::DownloadError);
@ -701,16 +702,10 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// Redirect STDOUT to the console // Redirect STDOUT to the console
let download_stdout = download_process.stdout.take().unwrap(); let download_stdout = download_process.stdout.take().unwrap();
let stdout = BufReader::new(download_stdout);
thread::spawn(move || { thread::spawn(move || {
// Enter download tracing span log_lines(stdout.lines());
let stdout = BufReader::new(download_stdout);
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace();
}
}
}); });
loop { loop {
@ -738,10 +733,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
return Err(LauncherError::DownloadError); return Err(LauncherError::DownloadError);
} }
         if !running.load(Ordering::SeqCst) {
-            signal::kill(Pid::from_raw(download_process.id() as i32), Signal::SIGTERM).unwrap();
-            tracing::info!("Waiting for download process to gracefully shutdown");
-            download_process.wait().unwrap();
-            tracing::info!("Download process terminated");
+            terminate("download", download_process, Duration::from_secs(10)).unwrap();
             return Ok(());
         }
sleep(Duration::from_millis(100)); sleep(Duration::from_millis(100));
@ -760,16 +752,6 @@ fn spawn_shards(
status_sender: mpsc::Sender<ShardStatus>, status_sender: mpsc::Sender<ShardStatus>,
running: Arc<AtomicBool>, running: Arc<AtomicBool>,
) -> Result<(), LauncherError> { ) -> Result<(), LauncherError> {
if args.trust_remote_code {
tracing::warn!(
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
args.model_id
);
if args.revision.is_none() {
tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
}
}
// Start shard processes // Start shard processes
for rank in 0..num_shard { for rank in 0..num_shard {
let model_id = args.model_id.clone(); let model_id = args.model_id.clone();
@ -828,11 +810,8 @@ fn spawn_shards(
Err(TryRecvError::Empty) => { Err(TryRecvError::Empty) => {
sleep(Duration::from_millis(100)); sleep(Duration::from_millis(100));
} }
-            Ok(ShardStatus::Failed((rank, err))) => {
+            Ok(ShardStatus::Failed(rank)) => {
                 tracing::error!("Shard {rank} failed to start");
-                if let Some(err) = err {
-                    tracing::error!("{err}");
-                }
                 shutdown_shards(shutdown, shutdown_receiver);
                 return Err(LauncherError::ShardCannotStart);
             }
@ -854,7 +833,7 @@ fn spawn_webserver(
// All shard started // All shard started
// Start webserver // Start webserver
tracing::info!("Starting Webserver"); tracing::info!("Starting Webserver");
let mut argv = vec![ let mut router_args = vec![
"--max-concurrent-requests".to_string(), "--max-concurrent-requests".to_string(),
args.max_concurrent_requests.to_string(), args.max_concurrent_requests.to_string(),
"--max-best-of".to_string(), "--max-best-of".to_string(),
@ -867,8 +846,6 @@ fn spawn_webserver(
args.max_total_tokens.to_string(), args.max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(), "--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(), args.max_batch_prefill_tokens.to_string(),
"--max-batch-total-tokens".to_string(),
args.max_batch_total_tokens.to_string(),
"--waiting-served-ratio".to_string(), "--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(), args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(), "--max-waiting-tokens".to_string(),
@ -885,63 +862,54 @@ fn spawn_webserver(
args.model_id, args.model_id,
]; ];
// Model optional max batch total tokens
if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
router_args.push("--max-batch-total-tokens".to_string());
router_args.push(max_batch_total_tokens.to_string());
}
// Model optional revision // Model optional revision
if let Some(ref revision) = args.revision { if let Some(ref revision) = args.revision {
argv.push("--revision".to_string()); router_args.push("--revision".to_string());
argv.push(revision.to_string()) router_args.push(revision.to_string())
} }
if args.json_output { if args.json_output {
argv.push("--json-output".to_string()); router_args.push("--json-output".to_string());
} }
// OpenTelemetry // OpenTelemetry
if let Some(otlp_endpoint) = args.otlp_endpoint { if let Some(otlp_endpoint) = args.otlp_endpoint {
argv.push("--otlp-endpoint".to_string()); router_args.push("--otlp-endpoint".to_string());
argv.push(otlp_endpoint); router_args.push(otlp_endpoint);
} }
// CORS origins // CORS origins
for origin in args.cors_allow_origin.into_iter() { for origin in args.cors_allow_origin.into_iter() {
argv.push("--cors-allow-origin".to_string()); router_args.push("--cors-allow-origin".to_string());
argv.push(origin); router_args.push(origin);
} }
     // Ngrok
     if args.ngrok {
-        let authtoken = args.ngrok_authtoken.ok_or_else(|| {
-            tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling");
-            LauncherError::WebserverCannotStart
-        })?;
-        argv.push("--ngrok".to_string());
-        argv.push("--ngrok-authtoken".to_string());
-        argv.push(authtoken);
-        if let Some(domain) = args.ngrok_domain {
-            argv.push("--ngrok-domain".to_string());
-            argv.push(domain);
-        }
-        if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) {
-            argv.push("--ngrok-username".to_string());
-            argv.push(username);
-            argv.push("--ngrok-password".to_string());
-            argv.push(password);
-        }
+        router_args.push("--ngrok".to_string());
+        router_args.push("--ngrok-authtoken".to_string());
+        router_args.push(args.ngrok_authtoken.unwrap());
+        router_args.push("--ngrok-edge".to_string());
+        router_args.push(args.ngrok_edge.unwrap());
     }
// Copy current process env // Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Parse Inference API token // Parse Inference API token
if let Ok(api_token) = env::var("HF_API_TOKEN") { if let Ok(api_token) = env::var("HF_API_TOKEN") {
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
}; };
let mut webserver = match Command::new("text-generation-router") let mut webserver = match Command::new("text-generation-router")
.args(argv) .args(router_args)
.envs(env) .envs(envs)
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
.process_group(0) .process_group(0)
@ -979,14 +947,49 @@ fn spawn_webserver(
Ok(webserver) Ok(webserver)
} }
fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
tracing::info!("Terminating {process_name}");
let terminate_time = Instant::now();
signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();
tracing::info!("Waiting for {process_name} to gracefully shutdown");
while terminate_time.elapsed() < timeout {
if let Some(status) = process.try_wait()? {
tracing::info!("{process_name} terminated");
return Ok(status);
}
sleep(Duration::from_millis(100));
}
tracing::info!("Killing {process_name}");
process.kill()?;
let exit_status = process.wait()?;
tracing::info!("{process_name} killed");
Ok(exit_status)
}
 fn main() -> Result<(), LauncherError> {
     // Pattern match configuration
-    let args = Args::parse();
+    let args: Args = Args::parse();
+    // Filter events with LOG_LEVEL
+    let env_filter =
+        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
     if args.json_output {
-        tracing_subscriber::fmt().json().init();
+        tracing_subscriber::fmt()
+            .with_env_filter(env_filter)
+            .json()
+            .init();
     } else {
-        tracing_subscriber::fmt().compact().init();
+        tracing_subscriber::fmt()
+            .with_env_filter(env_filter)
+            .compact()
+            .init();
     }
if args.env { if args.env {
@ -1008,29 +1011,53 @@ fn main() -> Result<(), LauncherError> {
args.max_batch_prefill_tokens, args.max_input_length args.max_batch_prefill_tokens, args.max_input_length
))); )));
} }
if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, args.max_batch_total_tokens
)));
}
if args.max_total_tokens as u32 > args.max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, args.max_batch_total_tokens
)));
}
if args.validation_workers == 0 { if args.validation_workers == 0 {
return Err(LauncherError::ArgumentValidation( return Err(LauncherError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(), "`validation_workers` must be > 0".to_string(),
)); ));
} }
if args.trust_remote_code {
tracing::warn!(
"`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
args.model_id
);
}
let num_shard = find_num_shards(args.sharded, args.num_shard)?; let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 { if num_shard > 1 {
tracing::info!("Sharding model on {num_shard} processes"); tracing::info!("Sharding model on {num_shard} processes");
} }
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if args.max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if args.max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, max_batch_total_tokens
)));
}
}
if args.ngrok {
if args.ngrok_authtoken.is_none() {
return Err(LauncherError::ArgumentValidation(
"`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
));
}
if args.ngrok_edge.is_none() {
return Err(LauncherError::ArgumentValidation(
"`ngrok-edge` must be set when using ngrok tunneling".to_string(),
));
}
}
// Signal handler // Signal handler
let running = Arc::new(AtomicBool::new(true)); let running = Arc::new(AtomicBool::new(true));
let r = running.clone(); let r = running.clone();
@ -1042,6 +1069,11 @@ fn main() -> Result<(), LauncherError> {
// Download and convert model weights // Download and convert model weights
download_convert_model(&args, running.clone())?; download_convert_model(&args, running.clone())?;
if !running.load(Ordering::SeqCst) {
// Launcher was asked to stop
return Ok(());
}
// Shared shutdown bool // Shared shutdown bool
let shutdown = Arc::new(AtomicBool::new(false)); let shutdown = Arc::new(AtomicBool::new(false));
// Shared shutdown channel // Shared shutdown channel
@ -1078,11 +1110,8 @@ fn main() -> Result<(), LauncherError> {
let mut exit_code = Ok(()); let mut exit_code = Ok(());
while running.load(Ordering::SeqCst) { while running.load(Ordering::SeqCst) {
-        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
+        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
             tracing::error!("Shard {rank} crashed");
-            if let Some(err) = err {
-                tracing::error!("{err}");
-            }
             exit_code = Err(LauncherError::ShardFailed);
             break;
         };
@ -1100,10 +1129,7 @@ fn main() -> Result<(), LauncherError> {
} }
     // Graceful termination
-    signal::kill(Pid::from_raw(webserver.id() as i32), Signal::SIGTERM).unwrap();
-    tracing::info!("Waiting for webserver to gracefully shutdown");
-    webserver.wait().unwrap();
-    tracing::info!("Webserver terminated");
+    terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
     shutdown_shards(shutdown, &shutdown_receiver);
exit_code exit_code
@ -198,9 +198,10 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
-    /// Maximum number of tokens that the client will send
-    uint32 max_total_tokens = 2;
 }

 /// Empty response
-message WarmupResponse {}
+message WarmupResponse {
+    /// Maximum number of tokens supported by the model
+    optional uint32 max_supported_total_tokens = 1;
+}
@ -103,8 +103,7 @@ impl Client {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
-    ) -> Result<()> {
+    ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
@ -143,13 +142,9 @@ impl Client {
             max_tokens: 0,
         };
-        let request = tonic::Request::new(WarmupRequest {
-            batch: Some(batch),
-            max_total_tokens,
-        })
-        .inject_context();
-        self.stub.warmup(request).await?.into_inner();
-        Ok(())
+        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
+        let response = self.stub.warmup(request).await?.into_inner();
+        Ok(response.max_supported_total_tokens)
     }
/// Generate one token for each request in the given batch /// Generate one token for each request in the given batch
@ -95,14 +95,11 @@ impl ShardedClient {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
-    ) -> Result<()> {
+    ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
-            })
+            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
@ -53,7 +53,7 @@ impl Infer {
generation_health: Arc<AtomicBool>, generation_health: Arc<AtomicBool>,
) -> Self { ) -> Self {
// Infer shared state // Infer shared state
let queue = Queue::new(requires_padding); let queue = Queue::new(requires_padding, 16);
let shared = Arc::new(Shared { let shared = Arc::new(Shared {
batching_task: Notify::new(), batching_task: Notify::new(),
}); });
@ -37,8 +37,8 @@ struct Args {
waiting_served_ratio: f32, waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)] #[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
#[clap(default_value = "16000", long, env)] #[clap(long, env)]
max_batch_total_tokens: u32, max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)] #[clap(default_value = "20", long, env)]
max_waiting_tokens: usize, max_waiting_tokens: usize,
#[clap(default_value = "0.0.0.0", long, env)] #[clap(default_value = "0.0.0.0", long, env)]
@ -49,8 +49,8 @@ struct Args {
master_shard_uds_path: String, master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)] #[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String, tokenizer_name: String,
#[clap(default_value = "main", long, env)] #[clap(long, env)]
revision: String, revision: Option<String>,
#[clap(default_value = "2", long, env)] #[clap(default_value = "2", long, env)]
validation_workers: usize, validation_workers: usize,
#[clap(long, env)] #[clap(long, env)]
@ -64,11 +64,7 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
ngrok_authtoken: Option<String>, ngrok_authtoken: Option<String>,
#[clap(long, env)] #[clap(long, env)]
ngrok_domain: Option<String>, ngrok_edge: Option<String>,
#[clap(long, env)]
ngrok_username: Option<String>,
#[clap(long, env)]
ngrok_password: Option<String>,
} }
fn main() -> Result<(), RouterError> { fn main() -> Result<(), RouterError> {
@ -96,9 +92,7 @@ fn main() -> Result<(), RouterError> {
cors_allow_origin, cors_allow_origin,
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_domain, ngrok_edge,
ngrok_username,
ngrok_password,
} = args; } = args;
// Validate args // Validate args
@ -110,18 +104,22 @@ fn main() -> Result<(), RouterError> {
if max_input_length as u32 > max_batch_prefill_tokens { if max_input_length as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
} }
if max_batch_prefill_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
if validation_workers == 0 { if validation_workers == 0 {
return Err(RouterError::ArgumentValidation( return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(), "`validation_workers` must be > 0".to_string(),
)); ));
} }
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
// CORS allowed origins // CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue // map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin // Finally, convert to AllowOrigin
@ -147,7 +145,7 @@ fn main() -> Result<(), RouterError> {
// Download and instantiate tokenizer // Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime // We need to download it outside of the Tokio runtime
let params = FromPretrainedParameters { let params = FromPretrainedParameters {
revision: revision.clone(), revision: revision.clone().unwrap_or("main".to_string()),
auth_token: authorization_token.clone(), auth_token: authorization_token.clone(),
..Default::default() ..Default::default()
}; };
@ -175,7 +173,7 @@ fn main() -> Result<(), RouterError> {
sha: None, sha: None,
pipeline_tag: None, pipeline_tag: None,
}, },
false => get_model_info(&tokenizer_name, &revision, authorization_token) false => get_model_info(&tokenizer_name, revision, authorization_token)
.await .await
.unwrap_or_else(|| { .unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub."); tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
@ -210,14 +208,35 @@ fn main() -> Result<(), RouterError> {
// Warmup model // Warmup model
tracing::info!("Warming up model"); tracing::info!("Warming up model");
-    sharded_client
-        .warmup(
-            max_input_length as u32,
-            max_batch_prefill_tokens,
-            max_batch_total_tokens,
-        )
+    let max_supported_batch_total_tokens = match sharded_client
+        .warmup(max_input_length as u32, max_batch_prefill_tokens)
         .await
-        .map_err(RouterError::Warmup)?;
+        .map_err(RouterError::Warmup)?
{
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
);
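// illustration: with max_total_tokens = 2048 and max_batch_prefill_tokens = 4096, this
// falls back to max(16000, max(2048, 4096)) = 16000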
tracing::warn!("Model does not support automatic max batch total tokens");
max_batch_total_tokens
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
max_supported_batch_total_tokens
}
};
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
tracing::info!("Connected"); tracing::info!("Connected");
let addr = match hostname.parse() { let addr = match hostname.parse() {
@ -240,7 +259,7 @@ fn main() -> Result<(), RouterError> {
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens, max_batch_prefill_tokens,
max_batch_total_tokens, max_supported_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
sharded_client, sharded_client,
tokenizer, tokenizer,
@ -249,9 +268,7 @@ fn main() -> Result<(), RouterError> {
cors_allow_origin, cors_allow_origin,
ngrok, ngrok,
ngrok_authtoken, ngrok_authtoken,
ngrok_domain, ngrok_edge,
ngrok_username,
ngrok_password,
) )
.await?; .await?;
Ok(()) Ok(())
@ -316,9 +333,18 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
/// get model info from the Huggingface Hub /// get model info from the Huggingface Hub
pub async fn get_model_info( pub async fn get_model_info(
model_id: &str, model_id: &str,
revision: &str, revision: Option<String>,
token: Option<String>, token: Option<String>,
) -> Option<HubModelInfo> { ) -> Option<HubModelInfo> {
let revision = match revision {
None => {
tracing::warn!("`--revision` is not set");
tracing::warn!("We strongly advise to set it to a known supported commit.");
"main".to_string()
}
Some(revision) => revision,
};
let client = reqwest::Client::new(); let client = reqwest::Client::new();
// Poor man's urlencode // Poor man's urlencode
let revision = revision.replace('/', "%2F"); let revision = revision.replace('/', "%2F");
@ -331,9 +357,18 @@ pub async fn get_model_info(
let response = builder.send().await.ok()?; let response = builder.send().await.ok()?;
if response.status().is_success() { if response.status().is_success() {
-        return serde_json::from_str(&response.text().await.ok()?).ok();
+        let hub_model_info: HubModelInfo =
+            serde_json::from_str(&response.text().await.ok()?).ok()?;
+        if let Some(sha) = &hub_model_info.sha {
+            tracing::info!(
+                "Serving revision {sha} of model {}",
+                hub_model_info.model_id
+            );
+        }
+        Some(hub_model_info)
+    } else {
+        None
     }
-    None
 }
#[derive(Debug, Error)] #[derive(Debug, Error)]

View File

@ -33,12 +33,12 @@ pub(crate) struct Queue {
} }
impl Queue { impl Queue {
pub(crate) fn new(requires_padding: bool) -> Self { pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
// Create channel // Create channel
let (queue_sender, queue_receiver) = flume::unbounded(); let (queue_sender, queue_receiver) = flume::unbounded();
// Launch background queue task // Launch background queue task
tokio::spawn(queue_task(requires_padding, queue_receiver)); tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
Self { queue_sender } Self { queue_sender }
} }
@ -81,8 +81,12 @@ impl Queue {
} }
// Background task responsible of the queue state // Background task responsible of the queue state
-async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) {
-    let mut state = State::new(requires_padding);
+async fn queue_task(
+    requires_padding: bool,
+    block_size: u32,
+    receiver: flume::Receiver<QueueCommand>,
+) {
+    let mut state = State::new(requires_padding, block_size);
while let Ok(cmd) = receiver.recv_async().await { while let Ok(cmd) = receiver.recv_async().await {
match cmd { match cmd {
@ -119,15 +123,19 @@ struct State {
/// Whether the model is using padding /// Whether the model is using padding
requires_padding: bool, requires_padding: bool,
/// Paged Attention block size
block_size: u32,
} }
impl State { impl State {
fn new(requires_padding: bool) -> Self { fn new(requires_padding: bool, block_size: u32) -> Self {
Self { Self {
entries: VecDeque::with_capacity(128), entries: VecDeque::with_capacity(128),
next_id: 0, next_id: 0,
next_batch_id: 0, next_batch_id: 0,
requires_padding, requires_padding,
block_size,
} }
} }
@ -187,10 +195,21 @@ impl State {
                 max_input_length = max_input_length.max(entry.request.input_length);
                 prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
             } else {
-                prefill_tokens += entry.request.input_length;
+                // pad to block size
+                prefill_tokens += ((entry.request.input_length + self.block_size - 1)
+                    / self.block_size)
+                    * self.block_size;
             }
-            decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+            if self.requires_padding {
+                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+            } else {
+                // pad to block size
+                decode_tokens +=
+                    ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
+                        / self.block_size)
+                        * self.block_size;
+            }
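            // illustration: with block_size = 16, an input_length of 20 is counted as
            // ceil(20 / 16) * 16 = 32 prefill tokens, and max_new_tokens = 20 as 32 decode tokens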
if prefill_tokens > prefill_token_budget if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens) > token_budget || (prefill_tokens + decode_tokens) > token_budget
@ -321,7 +340,7 @@ mod tests {
#[test] #[test]
fn test_append() { fn test_append() {
let mut state = State::new(false); let mut state = State::new(false, 1);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0); assert_eq!(state.next_id, 0);
@ -337,7 +356,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_empty() { fn test_next_batch_empty() {
let mut state = State::new(false); let mut state = State::new(false, 1);
assert!(state.next_batch(None, 1, 1).is_none()); assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1, 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none());
@ -345,7 +364,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_min_size() { fn test_next_batch_min_size() {
let mut state = State::new(false); let mut state = State::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -377,7 +396,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_token_budget() { fn test_next_batch_token_budget() {
let mut state = State::new(false); let mut state = State::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -410,14 +429,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_append() { async fn test_queue_append() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
queue.append(entry); queue.append(entry);
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
@ -425,7 +444,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_min_size() { async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -458,7 +477,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_budget() { async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -483,7 +502,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_dropped_receiver() { async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);
@ -524,9 +524,7 @@ pub async fn run(
allow_origin: Option<AllowOrigin>, allow_origin: Option<AllowOrigin>,
ngrok: bool, ngrok: bool,
ngrok_authtoken: Option<String>, ngrok_authtoken: Option<String>,
ngrok_domain: Option<String>, ngrok_edge: Option<String>,
ngrok_username: Option<String>,
ngrok_password: Option<String>,
) -> Result<(), axum::BoxError> { ) -> Result<(), axum::BoxError> {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -696,32 +694,25 @@ pub async fn run(
#[cfg(feature = "ngrok")] #[cfg(feature = "ngrok")]
{ {
         use ngrok::config::TunnelBuilder;
-        use ngrok::tunnel::UrlTunnel;
         let _ = addr;
         let authtoken =
             ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling");
-        let mut tunnel = ngrok::Session::builder()
+        let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling");
+        let tunnel = ngrok::Session::builder()
             .authtoken(authtoken)
             .connect()
             .await
             .unwrap()
-            .http_endpoint();
-        if let Some(domain) = ngrok_domain {
-            tunnel = tunnel.domain(domain);
-        }
-        if let (Some(username), Some(password)) = (ngrok_username, ngrok_password) {
-            tunnel = tunnel.basic_auth(username, password);
-        }
+            .labeled_tunnel()
+            .label("edge", edge);
         let listener = tunnel.listen().await.unwrap();
         // Run server
-        tracing::info!("Ingress URL: {:?}", listener.url());
         axum::Server::builder(listener)
.serve(app.into_make_service()) .serve(app.into_make_service())
//Wait until all requests are finished to shut down //Wait until all requests are finished to shut down
@ -1,4 +1,5 @@
include Makefile-flash-att include Makefile-flash-att
include Makefile-flash-att-v2
include Makefile-vllm include Makefile-vllm
unit-tests: unit-tests:
@ -0,0 +1,13 @@
flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
flash-attention-v2:
# Clone flash attention
pip install packaging
git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
build-flash-attention-v2: flash-attention-v2
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
cd flash-attention-v2 && python setup.py build
install-flash-attention-v2: build-flash-attention-v2
cd flash-attention-v2 && python setup.py install
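# Typically invoked through the server Makefile (which includes this file), e.g.
# `cd server && make install-flash-attention-v2`; see the install hint in flash_attn.py below.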
@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "text-generation-server" name = "text-generation-server"
version = "0.9.1" version = "0.9.3"
description = "Text Generation Inference Python gRPC Server" description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@ -194,6 +194,8 @@ def quantize(
percdamp: float = 0.01, percdamp: float = 0.01,
act_order: bool = False, act_order: bool = False,
): ):
if revision is None:
revision = "main"
download_weights( download_weights(
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
@ -207,6 +209,7 @@ def quantize(
bits=4, bits=4,
groupsize=128, groupsize=128,
output_dir=output_dir, output_dir=output_dir,
revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
upload_to_model_id=upload_to_model_id, upload_to_model_id=upload_to_model_id,
percdamp=percdamp, percdamp=percdamp,
@ -42,51 +42,21 @@ __all__ = [
"get_model", "get_model",
] ]
-FLASH_ATT_ERROR_MESSAGE = (
-    "{} requires CUDA and Flash Attention kernels to be installed.\n"
-    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-    "or install flash attention with `cd server && make install install-flash-attention`"
-)
+FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

+FLASH_ATTENTION = True
 try:
-    if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
-        if not torch.cuda.is_available():
-            FLASH_ATT_ERROR_MESSAGE = (
-                "{} requires CUDA. No compatible CUDA devices found."
-            )
-            raise ImportError("CUDA is not available")
-        major, minor = torch.cuda.get_device_capability()
-        is_sm75 = major == 7 and minor == 5
-        is_sm8x = major == 8 and minor >= 0
-        is_sm90 = major == 9 and minor == 0
-        supported = is_sm75 or is_sm8x or is_sm90
-        if not supported:
-            FLASH_ATT_ERROR_MESSAGE = (
-                "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
-                "No compatible CUDA device found."
-            )
-            raise ImportError(
-                f"GPU with CUDA capability {major} {minor} is not supported"
-            )
-        from text_generation_server.models.flash_rw import FlashRWSharded
-        from text_generation_server.models.flash_neox import FlashNeoXSharded
-        from text_generation_server.models.flash_llama import (
-            FlashLlama,
-        )
-        from text_generation_server.models.flash_santacoder import (
-            FlashSantacoderSharded,
-        )
-        FLASH_ATTENTION = True
-    else:
-        FLASH_ATTENTION = False
-except ImportError:
-    logger.opt(exception=True).warning(
-        "Could not import Flash Attention enabled models"
-    )
+    from text_generation_server.models.flash_rw import FlashRWSharded
+    from text_generation_server.models.flash_neox import FlashNeoXSharded
+    from text_generation_server.models.flash_llama import (
+        FlashLlama,
+    )
+    from text_generation_server.models.flash_santacoder import (
+        FlashSantacoderSharded,
+    )
+except ImportError as e:
+    logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
if FLASH_ATTENTION: if FLASH_ATTENTION:
@ -23,25 +23,77 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
# Flash attention imports # Flash attention imports
import flash_attn_cuda
import dropout_layer_norm import dropout_layer_norm
# vllm imports # vllm imports
import vllm_cache_ops import vllm_cache_ops
import vllm_attention_ops import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, TensorParallelHead,
get_linear,
) )
class LlamaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class LlamaRMSNorm(nn.Module): class LlamaRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6): def __init__(self, prefix, weights, eps=1e-6):
""" """
@ -59,7 +111,8 @@ class LlamaRMSNorm(nn.Module):
hidden_states += residual hidden_states += residual
residual = hidden_states residual = hidden_states
-            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt( hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon variance + self.variance_epsilon
) )
@ -94,6 +147,27 @@ class LlamaRMSNorm(nn.Module):
return normed_hidden_states, res return normed_hidden_states, res
def _load_gqa(config, prefix: str, weights):
w = [
weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
]
weight = torch.cat(w, dim=0)
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
bias = None
assert config.hidden_size % config.num_attention_heads == 0
head_size = config.hidden_size // config.num_attention_heads
assert config.num_attention_heads % weights.process_group.size() == 0
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
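# illustration: with 32 attention heads, 8 key/value heads, head_size = 128 and a single shard,
# the fused q/k/v weight has (32 + 2 * 8) * 128 = 6144 output rows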
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
class FlashLlamaAttention(torch.nn.Module): class FlashLlamaAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
@ -118,22 +192,29 @@ class FlashLlamaAttention(torch.nn.Module):
f"and `num_shards`: {weights.process_group.size()}" f"and `num_shards`: {weights.process_group.size()}"
) )
self.num_heads = self.num_heads // weights.process_group.size() self.num_heads = self.num_heads // weights.process_group.size()
-        self.query_key_value = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=False,
+        self.num_key_value_heads = (
+            config.num_key_value_heads // weights.process_group.size()
         )
+        if config.num_attention_heads != config.num_key_value_heads:
+            self.query_key_value = _load_gqa(config, prefix, weights)
+        else:
+            self.query_key_value = TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                dim=0,
+                weights=weights,
+                bias=False,
+            )
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
weights=weights, weights=weights,
bias=False, bias=False,
) )
self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
-            0, self.num_heads, dtype=torch.int32, device=weights.device
-        )
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
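        # illustration: with num_heads = 32 and num_key_value_heads = 8, num_groups = 4 and the
        # mapping is [0, 0, 0, 0, 1, 1, 1, 1, ...], so each query head reads its group's shared KV head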
def forward( def forward(
self, self,
@ -148,38 +229,37 @@ class FlashLlamaAttention(torch.nn.Module):
max_s, max_s,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

         vllm_cache_ops.reshape_and_cache(
-            qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )

         # output tensor
-        attn_output = torch.empty_like(qkv[:, 0])
+        attn_output = torch.empty_like(query)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+            attention(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
# Decode # Decode
else: else:
@ -187,7 +267,7 @@ class FlashLlamaAttention(torch.nn.Module):
block_size = kv_cache[1].shape[3] block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention( vllm_attention_ops.single_query_cached_kv_attention(
attn_output, attn_output,
qkv[:, 0], query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
@ -323,6 +403,7 @@ class FlashLlamaModel(torch.nn.Module):
self.head_size = self.layers[0].self_attn.head_size self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward( def forward(
self, self,
@ -27,13 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports # vllm imports
import vllm_cache_ops import vllm_cache_ops
import vllm_attention_ops import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -153,22 +151,14 @@ class FlashNeoxAttention(torch.nn.Module):
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
# Decode # Decode
else: else:
@ -6,13 +6,11 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports # vllm imports
import vllm_cache_ops import vllm_cache_ops
import vllm_attention_ops import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -182,27 +180,15 @@ class FlashRWAttention(torch.nn.Module):
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
-            if self.num_heads_kv == 1:
-                # Expand to query shape
-                kv = kv.expand(-1, 2, self.num_heads, self.head_size)
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
# Decode # Decode
else: else:
@ -314,30 +300,15 @@ class FlashRWLargeAttention(torch.nn.Module):
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
-            # Expand to query shape
-            kv = (
-                kv.unsqueeze(2)
-                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
-                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
-            )
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
# Decode # Decode
else: else:
@ -5,13 +5,11 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports # vllm imports
import vllm_cache_ops import vllm_cache_ops
import vllm_attention_ops import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -265,26 +263,15 @@ class FlashMQAttention(torch.nn.Module):
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
-            # Expand from 1 to num_heads
-            key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
# Decode # Decode
else: else:
@ -712,14 +712,14 @@ class FlashCausalLM(Model):
def batch_type(self) -> Type[FlashCausalLMBatch]: def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch return FlashCausalLMBatch
-    def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
+    def warmup(self, batch: FlashCausalLMBatch):
         global CACHE_MANAGER
         torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats(self.device)
         try:
             CACHE_MANAGER = CacheManager(
-                # Adds some wiggle room
-                math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
+                batch.blocks,
                 self.num_layers,
                 self.num_kv_heads,
                 self.head_size,
@ -729,11 +729,43 @@ class FlashCausalLM(Model):
_, batch = self.generate_token(batch) _, batch = self.generate_token(batch)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} " f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
f"prefill tokens. " f"You need to decrease `--max-batch-prefill-tokens`"
f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
) from e ) from e
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize(self.device)
peak_memory = torch.cuda.max_memory_reserved(self.device)
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
# 0.98 to add some wiggle room
num_blocks = (
int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
+ batch.blocks
)
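# illustration (assuming BLOCK_SIZE = 16, fp16, 32 layers, 32 KV heads, head_size = 128): one cache
# block holds 16 tokens and takes 32 * 2 * 16 * 32 * 128 * 2 bytes = 8 MiB, so ~58 GB of headroom
# yields roughly 7400 blocks, i.e. a max batch total of ~118k tokens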
del CACHE_MANAGER
del batch del batch
torch.cuda.empty_cache()
CACHE_MANAGER = CacheManager(
num_blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.device,
)
return int(num_blocks * BLOCK_SIZE)
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
return self.tokenizer.decode( return self.tokenizer.decode(
@ -2,13 +2,13 @@ import torch
import torch.distributed import torch.distributed
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoConfig
from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
from typing import Optional from typing import Optional
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import ( from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM, FlashLlamaForCausalLM,
LlamaConfig,
) )
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
@ -52,7 +52,7 @@ class FlashLlama(FlashCausalLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config = AutoConfig.from_pretrained( config = LlamaConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
@ -70,7 +70,7 @@ class FlashLlama(FlashCausalLM):
tokenizer=tokenizer, tokenizer=tokenizer,
config=config, config=config,
num_layers=len(model.model.layers), num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads, num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size, head_size=model.model.head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
@ -107,8 +107,9 @@ class Model(ABC):
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError raise NotImplementedError
-    def warmup(self, batch: B, max_total_tokens: int):
+    def warmup(self, batch: B) -> Optional[int]:
         self.generate_token(batch)
+        return None
def decode_token( def decode_token(
self, self,
@ -51,21 +51,17 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
filtered_batch = batch.filter(request.request_ids) filtered_batch = batch.filter(request.request_ids)
self.cache.set(filtered_batch) self.cache.set(filtered_batch)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
-        self.model.warmup(batch, request.max_total_tokens)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        return generate_pb2.WarmupResponse()
+        max_supported_total_tokens = self.model.warmup(batch)
+        return generate_pb2.WarmupResponse(
+            max_supported_total_tokens=max_supported_total_tokens
+        )
async def Prefill(self, request, context): async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb( batch = self.model.batch_type.from_pb(
@ -96,8 +92,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) > 1: if len(batches) > 1:
batch = self.model.batch_type.concatenate(batches) batch = self.model.batch_type.concatenate(batches)
if torch.cuda.is_available():
torch.cuda.empty_cache()
else: else:
batch = batches[0] batch = batches[0]
@ -94,6 +94,14 @@ def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: Lis
# We do this instead of using tqdm because we want to parse the logs with the launcher # We do this instead of using tqdm because we want to parse the logs with the launcher
for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)): for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
# Skip blacklisted files
if (
"arguments" in pt_file.name
or "args" in pt_file.name
or "training" in pt_file.name
):
continue
start = datetime.datetime.now() start = datetime.datetime.now()
convert_file(pt_file, sf_file, discard_names) convert_file(pt_file, sf_file, discard_names)
elapsed = datetime.datetime.now() - start elapsed = datetime.datetime.now() - start
@ -0,0 +1,124 @@
import os
import torch
from loguru import logger
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False
try:
try:
import flash_attn_2_cuda
except ImportError:
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
)
if not (is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2 = True
except ImportError as e:
try:
import flash_attn_cuda
except ImportError:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
if not (is_sm75 or is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
logger.warning(f"Unable to use Flash Attention V2: {e}")
HAS_FLASH_ATTN = True
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
):
if HAS_FLASH_ATTN_V2:
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
if HAS_FLASH_ATTN:
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
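# illustration: for grouped attention with 32 query heads and 8 KV heads, each KV head is repeated
# 4 times by the expand + reshape above, so k and v end up with 32 heads matching q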
return flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
raise NotImplementedError("flash attention is not installed")
@ -13,6 +13,9 @@ import transformers
from huggingface_hub import HfApi from huggingface_hub import HfApi
import numpy as np import numpy as np
import torch import torch
from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files
from text_generation_server.utils.gptq.quant_linear import QuantLinear from text_generation_server.utils.gptq.quant_linear import QuantLinear
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
@ -38,7 +41,6 @@ class Quantizer(nn.Module):
maxshrink=0.8, maxshrink=0.8,
trits=False, trits=False,
): ):
self.maxq = torch.tensor(2**bits - 1) self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel self.perchannel = perchannel
self.sym = sym self.sym = sym
@ -600,6 +602,8 @@ def sequential(
nsamples, nsamples,
bits, bits,
groupsize, groupsize,
*,
hooks,
percdamp=0.01, percdamp=0.01,
sym: bool = False, sym: bool = False,
act_order: bool = False, act_order: bool = False,
@ -637,7 +641,7 @@ def sequential(
layers[0] = Catcher(layers[0]) layers[0] = Catcher(layers[0])
for batch in dataloader: for batch in dataloader:
try: try:
model(batch[0]) model(batch[0].cuda())
except ValueError: except ValueError:
pass pass
layers[0] = layers[0].module layers[0] = layers[0].module
@@ -646,6 +650,8 @@ def sequential(
# model.model.embed_tokens = model.model.embed_tokens.cpu()
# model.model.norm = model.model.norm.cpu()
torch.cuda.empty_cache()
for hook in hooks:
hook.remove()
outs = torch.zeros_like(inps)
@@ -662,10 +668,8 @@ def sequential(
print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
print("+==================+==============+============+===========+=======+")
- from accelerate.hooks import remove_hook_from_submodules
- layer = layers[i].to(dev)
- remove_hook_from_submodules(layer)
+ layer = layers[i]
+ layer.load()
full = find_layers(layer)
sequential = [list(full.keys())]
@@ -677,6 +681,7 @@ def sequential(
gptq[name].quantizer.configure(
bits, perchannel=True, sym=sym, mse=False
)
pass
def add_batch(name):
def tmp(_, inp, out):
@@ -688,7 +693,6 @@ def sequential(
for name in subset:
handles.append(subset[name].register_forward_hook(add_batch(name)))
for j in range(nsamples):
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
for h in handles:
h.remove()
@@ -714,7 +718,7 @@ def sequential(
for j in range(nsamples):
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
- layers[i] = layer.cpu()
+ layer.unload()
del layer
del gptq
torch.cuda.empty_cache()
@@ -768,24 +772,136 @@ def pack(model, quantizers, bits, groupsize):
return model
def setdeepattr(module, full_name, tensor):
current = module
tokens = full_name.split(".")
for token in tokens[:-1]:
current = getattr(current, token)
setattr(current, tokens[-1], tensor)
def getdeepattr(module, full_name):
current = module
tokens = full_name.split(".")
for token in tokens:
current = getattr(current, token)
return current
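# For example, getdeepattr(model, "model.layers.0.mlp.down_proj.weight") walks the
# dotted path one attribute at a time, and setdeepattr with the same path replaces
# the leaf tensor in place. The hooks below rely on this to swap meta tensors for
# real weights (and back) without rebuilding modules. (Path shown is illustrative.)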
def load_weights_pre_hook(module_name, weights, recursive=False):
def inner(module, args):
print(f"Pre hook {module_name}")
local_params = {}
for k, v in module.named_parameters():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for k, v in module.named_buffers():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for local_param in local_params:
current_tensor = getdeepattr(module, local_param)
if current_tensor.device == torch.device("meta"):
# print(f"Loading {local_param}")
if module_name:
tensor_name = f"{module_name}.{local_param}"
else:
tensor_name = local_param
tensor = weights.get_tensor(tensor_name)
setdeepattr(module, local_param, nn.Parameter(tensor))
else:
setdeepattr(
module,
local_param,
nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
)
return inner
def load_weights_post_hook(module_name, weights, recursive=False):
def inner(module, args, output):
print(f"Post hook {module_name}")
local_params = {}
for k, v in module.named_parameters():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for k, v in module.named_buffers():
if not recursive and k.count(".") != 1:
continue
local_params[k] = v
for local_param in local_params:
# print(f"Unloading {local_param}")
current_tensor = getdeepattr(module, local_param)
setdeepattr(
module,
local_param,
nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
)
return output
return inner
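# Taken together, the pre-hook materialises a module's meta/CPU parameters onto
# cuda:0 right before its forward pass, and the post-hook moves them back to CPU
# afterwards, so only the layer currently being calibrated keeps its weights on
# the GPU.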
def quantize(
model_id: str,
bits: int,
groupsize: int,
output_dir: str,
revision: str,
trust_remote_code: bool,
upload_to_model_id: Optional[str],
percdamp: float,
act_order: bool,
):
print("loading model")
- model = AutoModelForCausalLM.from_pretrained(
+ config = AutoConfig.from_pretrained(
model_id,
- torch_dtype=torch.float16,
- device_map="balanced_low_0",
trust_remote_code=trust_remote_code,
)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
model = model.eval()
print("LOADED model") print("LOADED model")
files = weight_files(model_id, revision, extension=".safetensors")
process_group, _, _ = initialize_torch_distributed()
weights = Weights(
files,
device=torch.device("cuda:0"),
dtype=torch.float16,
process_group=process_group,
aliases={"embed_tokens.weight": ["lm_head.weight"]},
)
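# aliases presumably lets a tied lm_head resolve its weight from embed_tokens.weight
# when the checkpoint does not store lm_head.weight separately (assumption about the
# Weights helper's alias semantics).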
hooks = []
for name, module in model.named_modules():
def load(module, name):
def _load():
load_weights_pre_hook(name, weights, recursive=True)(module, None)
return _load
def unload(module, name):
def _unload():
load_weights_post_hook(name, weights, recursive=True)(
module, None, None
)
return _unload
module.load = load(module, name)
module.unload = unload(module, name)
hooks.append(
module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
)
hooks.append(
module.register_forward_hook(load_weights_post_hook(name, weights))
)
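# Every module now carries explicit load()/unload() helpers (used by sequential()
# for whole decoder layers) in addition to automatic per-module forward hooks; the
# hooks list is handed to sequential(), which removes the automatic hooks once the
# calibration inputs have been captured.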
model.seqlen = 2048
dataset = "wikitext2"
@@ -806,6 +922,7 @@ def quantize(
groupsize,
percdamp=percdamp,
act_order=act_order,
hooks=hooks,
)
print(time.time() - tick)
@@ -858,7 +975,6 @@ def quantize(
logger.info("Saved tokenizer")
if upload_to_model_id:
api = HfApi()
api.upload_folder(