mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 20:12:07 +00:00
Merge branch 'main' into gptq-cuda-kernels
This commit is contained in:
commit
edfbfdfb3f
302
Cargo.lock
generated
302
Cargo.lock
generated
@ -148,18 +148,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.68"
|
||||
version = "0.1.71"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842"
|
||||
checksum = "a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -410,9 +410,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.3.10"
|
||||
version = "4.3.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "384e169cc618c613d5e3ca6404dda77a8685a63e08660dcc64abaf7da7cb0c7a"
|
||||
checksum = "1640e5cc7fb47dbb8338fd471b105e7ed6c3cb2aeb00c2e067127ffd3764a05d"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
"clap_derive",
|
||||
@ -421,9 +421,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.3.10"
|
||||
version = "4.3.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef137bbe35aab78bdb468ccfba75a5f4d8321ae011d34063770780545176af2d"
|
||||
checksum = "98c59138d527eeaf9b53f35a77fcc1fad9d883116070c63d5de1c7dc7b00c72b"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
@ -440,7 +440,7 @@ dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -492,9 +492,9 @@ checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.8"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "03e69e28e9f7f77debdedbaafa2866e1de9ba56df55a8bd7cfc724c25a09987c"
|
||||
checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
@ -633,12 +633,12 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.4.0"
|
||||
version = "5.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
|
||||
checksum = "6943ae99c34386c84a470c499d3414f66502a41340aa895406e0d2e4a207b91d"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"hashbrown 0.12.3",
|
||||
"hashbrown 0.14.0",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core",
|
||||
@ -736,6 +736,12 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.1"
|
||||
@ -924,7 +930,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1023,7 +1029,7 @@ dependencies = [
|
||||
"futures-sink",
|
||||
"futures-util",
|
||||
"http",
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
@ -1038,13 +1044,19 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.13.2"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
|
||||
checksum = "33ff8ae62cd3a9102e5637afc8452c55acf3844001bd5374e0b0bd7b6616c038"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.1"
|
||||
@ -1053,9 +1065,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.3.1"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
|
||||
checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b"
|
||||
|
||||
[[package]]
|
||||
name = "hmac"
|
||||
@ -1190,6 +1202,16 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"hashbrown 0.12.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown 0.14.0",
|
||||
"serde",
|
||||
]
|
||||
|
||||
@ -1254,12 +1276,12 @@ checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6"
|
||||
|
||||
[[package]]
|
||||
name = "is-terminal"
|
||||
version = "0.4.8"
|
||||
version = "0.4.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24fddda5af7e54bf7da53067d6e802dbcc381d0a8eef629df528e3ebf68755cb"
|
||||
checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"rustix 0.38.1",
|
||||
"rustix 0.38.4",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
@ -1292,9 +1314,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.6"
|
||||
version = "1.0.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
|
||||
checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a"
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
@ -1397,7 +1419,7 @@ version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
|
||||
dependencies = [
|
||||
"regex-automata",
|
||||
"regex-automata 0.1.10",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1432,9 +1454,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "metrics"
|
||||
version = "0.21.0"
|
||||
version = "0.21.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aa8ebbd1a9e57bbab77b9facae7f5136aea44c356943bf9a198f647da64285d6"
|
||||
checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"metrics-macros",
|
||||
@ -1449,7 +1471,7 @@ checksum = "8a4964177ddfdab1e3a2b37aec7cf320e14169abb0ed73999f558136409178d5"
|
||||
dependencies = [
|
||||
"base64 0.21.2",
|
||||
"hyper",
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
"ipnet",
|
||||
"metrics",
|
||||
"metrics-util",
|
||||
@ -1467,18 +1489,18 @@ checksum = "ddece26afd34c31585c74a4db0630c376df271c285d682d1e55012197830b6df"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metrics-util"
|
||||
version = "0.15.0"
|
||||
version = "0.15.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "111cb375987443c3de8d503580b536f77dc8416d32db62d9456db5d93bd7ac47"
|
||||
checksum = "4de2ed6e491ed114b40b732e4d1659a9d53992ebd87490c44a6ffe23739d973e"
|
||||
dependencies = [
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
"hashbrown 0.13.2",
|
||||
"hashbrown 0.13.1",
|
||||
"metrics",
|
||||
"num_cpus",
|
||||
"quanta",
|
||||
@ -1530,9 +1552,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "monostate"
|
||||
version = "0.1.6"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a"
|
||||
checksum = "3f3f57a8802842f648026a33c3d2e3bb41bb309a35b1609bd7ef2b060b8b6b1b"
|
||||
dependencies = [
|
||||
"monostate-impl",
|
||||
"serde",
|
||||
@ -1540,13 +1562,13 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "monostate-impl"
|
||||
version = "0.1.6"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
|
||||
checksum = "e72f4d2e10fde62a0f2fcb4b44ccbf4f9899dcc30c9193449f8dfb9123d71377"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1701,6 +1723,15 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_threads"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "number_prefix"
|
||||
version = "0.3.0"
|
||||
@ -1773,7 +1804,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1854,7 +1885,7 @@ dependencies = [
|
||||
"fnv",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
"js-sys",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
@ -1870,7 +1901,7 @@ dependencies = [
|
||||
"fnv",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"thiserror",
|
||||
@ -1974,9 +2005,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.12"
|
||||
version = "1.0.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79"
|
||||
checksum = "b4b27ab7be369122c218afc2079489cdcb4b517c0a3fc386ff11e1fedfcc2b35"
|
||||
|
||||
[[package]]
|
||||
name = "pbkdf2"
|
||||
@ -2003,34 +2034,34 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4"
|
||||
dependencies = [
|
||||
"fixedbitset",
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.1.1"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e138fdd8263907a2b0e1b4e80b7e58c721126479b6e6eedfb1b402acea7b9bd"
|
||||
checksum = "030ad2bc4db10a8944cb0d837f158bdfec4d4a4873ab701a95046770d11f8842"
|
||||
dependencies = [
|
||||
"pin-project-internal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-internal"
|
||||
version = "1.1.1"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d1fef411b303e3e12d534fb6e7852de82da56edd937d895125821fb7c09436c7"
|
||||
checksum = "ec2e072ecce94ec471b13398d5402c188e76ac03cf74dd1a975161b23a3f6d9c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.9"
|
||||
version = "0.2.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116"
|
||||
checksum = "4c40d25201921e5ff0c862a505c6557ea88568a4e3ace775ab55e93f2f4f9d57"
|
||||
|
||||
[[package]]
|
||||
name = "pin-utils"
|
||||
@ -2046,9 +2077,9 @@ checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964"
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "1.3.3"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "767eb9f07d4a5ebcb39bbf2d452058a93c011373abf6832e24194a1c3f004794"
|
||||
checksum = "d220334a184db82b31b83f5ff093e3315280fb2b6bbc032022b2304a509aab7a"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
@ -2092,9 +2123,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.63"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb"
|
||||
checksum = "78803b62cbf1f46fde80d7c0e803111524b9877184cfe7c3033659490ac7a7da"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
@ -2294,13 +2325,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.8.4"
|
||||
version = "1.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f"
|
||||
checksum = "b2eae68fc220f7cf2532e4494aded17545fce192d59cd996e0fe7887f4ceb575"
|
||||
dependencies = [
|
||||
"aho-corasick 1.0.2",
|
||||
"memchr",
|
||||
"regex-syntax 0.7.2",
|
||||
"regex-automata 0.3.3",
|
||||
"regex-syntax 0.7.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2312,6 +2344,17 @@ dependencies = [
|
||||
"regex-syntax 0.6.29",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39354c10dd07468c2e73926b23bb9c2caca74c5501e38a35da70406f1d923310"
|
||||
dependencies = [
|
||||
"aho-corasick 1.0.2",
|
||||
"memchr",
|
||||
"regex-syntax 0.7.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.6.29"
|
||||
@ -2320,9 +2363,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.7.2"
|
||||
version = "0.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78"
|
||||
checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
@ -2397,7 +2440,7 @@ dependencies = [
|
||||
"quote",
|
||||
"rust-embed-utils",
|
||||
"shellexpand",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
@ -2428,9 +2471,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.37.21"
|
||||
version = "0.37.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62f25693a73057a1b4cb56179dd3c7ea21a7c6c5ee7d85781f5749b46f34b79c"
|
||||
checksum = "4d69718bf81c6127a49dc64e44a742e8bb9213c0ff8869a22c308f84c1d4ab06"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"errno",
|
||||
@ -2442,9 +2485,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.1"
|
||||
version = "0.38.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fbc6396159432b5c8490d4e301d8c705f61860b8b6c863bf79942ce5401968f3"
|
||||
checksum = "0a962918ea88d644592894bc6dc55acc6c0956488adcebbfb6e273506b7fd6e5"
|
||||
dependencies = [
|
||||
"bitflags 2.3.3",
|
||||
"errno",
|
||||
@ -2476,15 +2519,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.12"
|
||||
version = "1.0.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06"
|
||||
checksum = "dc31bd9b61a32c31f9650d18add92aa83a49ba979c143eefd27fe7177b05bd5f"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.13"
|
||||
version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
|
||||
checksum = "fe232bdf6be8c8de797b22184ee71118d63780ea42ac85b61d1baa6d3b782ae9"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
@ -2497,11 +2540,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.21"
|
||||
version = "0.1.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3"
|
||||
checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88"
|
||||
dependencies = [
|
||||
"windows-sys 0.42.0",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2551,29 +2594,29 @@ checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.164"
|
||||
version = "1.0.171"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d"
|
||||
checksum = "30e27d1e4fd7659406c492fd6cfaf2066ba8773de45ca75e855590f856dc34a9"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.164"
|
||||
version = "1.0.171"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68"
|
||||
checksum = "389894603bd18c46fa56231694f8d827779c0951a667087194cf9de94ed24682"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.99"
|
||||
version = "1.0.102"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3"
|
||||
checksum = "b5062a995d481b2308b6064e9af76011f2921c35f97b0468811ed9f6cd91dfed"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
@ -2582,10 +2625,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_path_to_error"
|
||||
version = "0.1.11"
|
||||
version = "0.1.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f7f05c1d5476066defcdfacce1f52fc3cae3af1d3089727100c02ae92e5abbe0"
|
||||
checksum = "8acc4422959dd87a76cb117c191dcbffc20467f06c9100b76721dab370f24d3a"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"serde",
|
||||
]
|
||||
|
||||
@ -2697,9 +2741,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.10.0"
|
||||
version = "1.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
|
||||
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
@ -2769,9 +2813,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.22"
|
||||
version = "2.0.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2efbeae7acf4eabd6bcdcbd11c92f45231ddda7539edc7806bd1a04a03b24616"
|
||||
checksum = "15e3fc8c0c74267e2df136e5e5fb656a464158aa57624053375eb9c8c6e25ae2"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -2786,9 +2830,9 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
|
||||
|
||||
[[package]]
|
||||
name = "sysinfo"
|
||||
version = "0.29.3"
|
||||
version = "0.29.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5bcd0346f90b6bc83526c7b180039a8acd26a5c848cc556d457f6472eb148122"
|
||||
checksum = "751e810399bba86e9326f5762b7f32ac5a085542df78da6a78d94e07d14d7c11"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"core-foundation-sys",
|
||||
@ -2824,9 +2868,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tar"
|
||||
version = "0.4.38"
|
||||
version = "0.4.39"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b55807c0344e1e6c04d7c965f5289c39a8d94ae23ed5c0b57aabac549f871c6"
|
||||
checksum = "ec96d2ffad078296368d46ff1cb309be1c23c513b4ab0e22a45de0185275ac96"
|
||||
dependencies = [
|
||||
"filetime",
|
||||
"libc",
|
||||
@ -2843,13 +2887,13 @@ dependencies = [
|
||||
"cfg-if",
|
||||
"fastrand",
|
||||
"redox_syscall 0.3.5",
|
||||
"rustix 0.37.21",
|
||||
"rustix 0.37.23",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-benchmark"
|
||||
version = "0.9.1"
|
||||
version = "0.9.3"
|
||||
dependencies = [
|
||||
"average",
|
||||
"clap",
|
||||
@ -2869,7 +2913,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-client"
|
||||
version = "0.9.1"
|
||||
version = "0.9.3"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"grpc-metadata",
|
||||
@ -2885,7 +2929,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-launcher"
|
||||
version = "0.9.1"
|
||||
version = "0.9.3"
|
||||
dependencies = [
|
||||
"clap",
|
||||
"ctrlc",
|
||||
@ -2901,7 +2945,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-router"
|
||||
version = "0.9.1"
|
||||
version = "0.9.3"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"axum",
|
||||
@ -2934,22 +2978,22 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.40"
|
||||
version = "1.0.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac"
|
||||
checksum = "a35fc5b8971143ca348fa6df4f024d4d55264f3468c71ad1c2f365b0a4d58c42"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.40"
|
||||
version = "1.0.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
|
||||
checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2964,11 +3008,13 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.22"
|
||||
version = "0.3.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd"
|
||||
checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"libc",
|
||||
"num_threads",
|
||||
"serde",
|
||||
"time-core",
|
||||
"time-macros",
|
||||
@ -2982,9 +3028,9 @@ checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
|
||||
|
||||
[[package]]
|
||||
name = "time-macros"
|
||||
version = "0.2.9"
|
||||
version = "0.2.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b"
|
||||
checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4"
|
||||
dependencies = [
|
||||
"time-core",
|
||||
]
|
||||
@ -3078,7 +3124,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3209,7 +3255,7 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
"pin-project",
|
||||
"pin-project-lite",
|
||||
"rand",
|
||||
@ -3291,7 +3337,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3413,9 +3459,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.9"
|
||||
version = "1.0.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0"
|
||||
checksum = "22049a19f4a68748a168c0fc439f9516686aa045927ff767eca0a85101fb6e73"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-normalization"
|
||||
@ -3484,11 +3530,11 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
|
||||
|
||||
[[package]]
|
||||
name = "utoipa"
|
||||
version = "3.3.0"
|
||||
version = "3.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68ae74ef183fae36d650f063ae7bde1cacbe1cd7e72b617cbe1e985551878b98"
|
||||
checksum = "520434cac5c98120177d5cc15be032703f6dca7d5ef82e725c798113b375000a"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"indexmap 2.0.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"utoipa-gen",
|
||||
@ -3496,21 +3542,22 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "utoipa-gen"
|
||||
version = "3.3.0"
|
||||
version = "3.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ea8ac818da7e746a63285594cce8a96f5e00ee31994e655bd827569cb8b137b"
|
||||
checksum = "6e22e88a487b6e0374533871b79b1f5ded05671bd0936bd547eb42f82fb9060d"
|
||||
dependencies = [
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"regex",
|
||||
"syn 2.0.25",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utoipa-swagger-ui"
|
||||
version = "3.1.3"
|
||||
version = "3.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "062bba5a3568e126ac72049a63254f4cb1da2eb713db0c1ab2a4c76be191db8c"
|
||||
checksum = "4602d7100d3cfd8a086f30494e68532402ab662fa366c9d201d677e33cee138d"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"mime_guess",
|
||||
@ -3536,9 +3583,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||
|
||||
[[package]]
|
||||
name = "vergen"
|
||||
version = "8.2.1"
|
||||
version = "8.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b3c89c2c7e50f33e4d35527e5bf9c11d6d132226dbbd1753f0fbe9f19ef88c6"
|
||||
checksum = "bbc5ad0d9d26b2c49a5ab7da76c3e79d3ee37e7821799f8223fcb8f2f391a2e7"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"rustc_version",
|
||||
@ -3599,7 +3646,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
@ -3633,7 +3680,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.22",
|
||||
"syn 2.0.25",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
@ -3706,21 +3753,6 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.42.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm 0.42.2",
|
||||
"windows_aarch64_msvc 0.42.2",
|
||||
"windows_i686_gnu 0.42.2",
|
||||
"windows_i686_msvc 0.42.2",
|
||||
"windows_x86_64_gnu 0.42.2",
|
||||
"windows_x86_64_gnullvm 0.42.2",
|
||||
"windows_x86_64_msvc 0.42.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.45.0"
|
||||
|
@ -8,7 +8,7 @@ members = [
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.9.1"
|
||||
version = "0.9.3"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
|
15
Dockerfile
15
Dockerfile
@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile
|
||||
# Build specific version of flash attention
|
||||
RUN make build-flash-attention
|
||||
|
||||
# Build Flash Attention v2 CUDA kernels
|
||||
FROM kernel-builder as flash-att-v2-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY server/Makefile-flash-att-v2 Makefile
|
||||
|
||||
# Build specific version of flash attention v2
|
||||
RUN make build-flash-attention-v2
|
||||
|
||||
# Build Transformers CUDA kernels
|
||||
FROM kernel-builder as custom-kernels-builder
|
||||
|
||||
@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
|
||||
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
|
||||
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
|
||||
|
||||
# Copy build artifacts from flash attention v2 builder
|
||||
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
|
||||
|
||||
# Copy build artifacts from custom kernels builder
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
|
||||
|
||||
# Copy builds artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
|
||||
|
13
README.md
13
README.md
@ -1,5 +1,7 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||
|
||||
# Text Generation Inference
|
||||
|
||||
<a href="https://github.com/huggingface/text-generation-inference">
|
||||
@ -11,9 +13,6 @@
|
||||
<a href="https://huggingface.github.io/text-generation-inference">
|
||||
<img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational">
|
||||
</a>
|
||||
|
||||

|
||||
|
||||
</div>
|
||||
|
||||
A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
|
||||
@ -64,6 +63,8 @@ to power LLMs api-inference widgets.
|
||||
- [Starcoder](https://huggingface.co/bigcode/starcoder)
|
||||
- [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b)
|
||||
- [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
|
||||
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
|
||||
- [Llama V2](https://huggingface.co/meta-llama)
|
||||
|
||||
Other architectures are supported on a best effort basis using:
|
||||
|
||||
@ -133,6 +134,10 @@ print(text)
|
||||
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
|
||||
The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
|
||||
|
||||
### Using on private models or gated models
|
||||
|
||||
You can use `HUGGING_FACE_HUB_TOKEN` environment variable to set the token used by `text-generation-inference` to give access to protected ressources.
|
||||
|
||||
### Distributed Tracing
|
||||
|
||||
`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
|
||||
@ -212,7 +217,7 @@ sudo apt-get install libssl-dev gcc -y
|
||||
### CUDA Kernels
|
||||
|
||||
The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
|
||||
the kernels by using the `BUILD_EXTENSIONS=False` environment variable.
|
||||
the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable.
|
||||
|
||||
Be aware that the official Docker image has them enabled by default.
|
||||
|
||||
|
@ -10,7 +10,7 @@
|
||||
"name": "Apache 2.0",
|
||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||
},
|
||||
"version": "0.9.1"
|
||||
"version": "0.9.3"
|
||||
},
|
||||
"paths": {
|
||||
"/": {
|
||||
|
@ -13,7 +13,7 @@ nix = "0.26.2"
|
||||
serde = { version = "1.0.152", features = ["derive"] }
|
||||
serde_json = "1.0.93"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.16", features = ["json"] }
|
||||
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
||||
|
||||
[dev-dependencies]
|
||||
float_eq = "1.0.1"
|
||||
|
@ -4,10 +4,10 @@ use nix::unistd::Pid;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
use std::io::{BufRead, BufReader, Read};
|
||||
use std::io::{BufRead, BufReader, Lines, Read};
|
||||
use std::os::unix::process::{CommandExt, ExitStatusExt};
|
||||
use std::path::Path;
|
||||
use std::process::{Child, Command, Stdio};
|
||||
use std::process::{Child, Command, ExitStatus, Stdio};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::mpsc::TryRecvError;
|
||||
use std::sync::{mpsc, Arc};
|
||||
@ -15,6 +15,7 @@ use std::thread;
|
||||
use std::thread::sleep;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{fs, io};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
mod env_runtime;
|
||||
|
||||
@ -41,6 +42,7 @@ impl std::fmt::Display for Quantization {
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Dtype {
|
||||
Float16,
|
||||
#[clap(name = "bfloat16")]
|
||||
BFloat16,
|
||||
}
|
||||
|
||||
@ -182,8 +184,8 @@ struct Args {
|
||||
/// depends on other parameters like if you're using quantization, flash attention
|
||||
/// or the model implementation, text-generation-inference cannot infer this number
|
||||
/// automatically.
|
||||
#[clap(default_value = "16000", long, env)]
|
||||
max_batch_total_tokens: u32,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
|
||||
/// This setting defines how many tokens can be passed before forcing the waiting
|
||||
/// queries to be put on the batch (if the size of the batch allows for it).
|
||||
@ -265,17 +267,9 @@ struct Args {
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
|
||||
/// ngrok domain name where the axum webserver will be available at
|
||||
/// ngrok edge
|
||||
#[clap(long, env)]
|
||||
ngrok_domain: Option<String>,
|
||||
|
||||
/// ngrok basic auth username
|
||||
#[clap(long, env)]
|
||||
ngrok_username: Option<String>,
|
||||
|
||||
/// ngrok basic auth password
|
||||
#[clap(long, env)]
|
||||
ngrok_password: Option<String>,
|
||||
ngrok_edge: Option<String>,
|
||||
|
||||
/// Display a lot of information about your runtime environment
|
||||
#[clap(long, short, action)]
|
||||
@ -285,7 +279,7 @@ struct Args {
|
||||
#[derive(Debug)]
|
||||
enum ShardStatus {
|
||||
Ready,
|
||||
Failed((usize, Option<String>)),
|
||||
Failed(usize),
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@ -310,6 +304,9 @@ fn shard_manager(
|
||||
shutdown: Arc<AtomicBool>,
|
||||
_shutdown_sender: mpsc::Sender<()>,
|
||||
) {
|
||||
// Enter shard-manager tracing span
|
||||
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
|
||||
|
||||
// Get UDS path
|
||||
let uds_string = format!("{uds_path}-{rank}");
|
||||
let uds = Path::new(&uds_string);
|
||||
@ -319,7 +316,7 @@ fn shard_manager(
|
||||
}
|
||||
|
||||
// Process args
|
||||
let mut shard_argv = vec![
|
||||
let mut shard_args = vec![
|
||||
"serve".to_string(),
|
||||
model_id,
|
||||
"--uds-path".to_string(),
|
||||
@ -331,77 +328,71 @@ fn shard_manager(
|
||||
|
||||
// Activate trust remote code
|
||||
if trust_remote_code {
|
||||
shard_argv.push("--trust-remote-code".to_string());
|
||||
shard_args.push("--trust-remote-code".to_string());
|
||||
}
|
||||
|
||||
// Activate tensor parallelism
|
||||
if world_size > 1 {
|
||||
shard_argv.push("--sharded".to_string());
|
||||
shard_args.push("--sharded".to_string());
|
||||
}
|
||||
|
||||
if let Some(quantize) = quantize {
|
||||
shard_argv.push("--quantize".to_string());
|
||||
shard_argv.push(quantize.to_string())
|
||||
shard_args.push("--quantize".to_string());
|
||||
shard_args.push(quantize.to_string())
|
||||
}
|
||||
|
||||
if let Some(dtype) = dtype {
|
||||
shard_argv.push("--dtype".to_string());
|
||||
shard_argv.push(dtype.to_string())
|
||||
shard_args.push("--dtype".to_string());
|
||||
shard_args.push(dtype.to_string())
|
||||
}
|
||||
|
||||
// Model optional revision
|
||||
if let Some(revision) = revision {
|
||||
shard_argv.push("--revision".to_string());
|
||||
shard_argv.push(revision)
|
||||
shard_args.push("--revision".to_string());
|
||||
shard_args.push(revision)
|
||||
}
|
||||
|
||||
// OpenTelemetry
|
||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||
shard_argv.push("--otlp-endpoint".to_string());
|
||||
shard_argv.push(otlp_endpoint);
|
||||
shard_args.push("--otlp-endpoint".to_string());
|
||||
shard_args.push(otlp_endpoint);
|
||||
}
|
||||
|
||||
// Copy current process env
|
||||
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
// Use cuda allocator. It leads to less memory fragmentation
|
||||
env.push((
|
||||
"PYTORCH_CUDA_ALLOC_CONF".into(),
|
||||
"backend:cudaMallocAsync".into(),
|
||||
));
|
||||
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
// Torch Distributed Env vars
|
||||
env.push(("RANK".into(), rank.to_string().into()));
|
||||
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
|
||||
env.push(("MASTER_ADDR".into(), master_addr.into()));
|
||||
env.push(("MASTER_PORT".into(), master_port.to_string().into()));
|
||||
env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
|
||||
envs.push(("RANK".into(), rank.to_string().into()));
|
||||
envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
|
||||
envs.push(("MASTER_ADDR".into(), master_addr.into()));
|
||||
envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
|
||||
envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
|
||||
|
||||
// Safetensors load fast
|
||||
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
|
||||
envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
|
||||
|
||||
// Enable hf transfer for insane download speeds
|
||||
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
|
||||
env.push((
|
||||
envs.push((
|
||||
"HF_HUB_ENABLE_HF_TRANSFER".into(),
|
||||
enable_hf_transfer.into(),
|
||||
));
|
||||
|
||||
// Parse Inference API token
|
||||
if let Ok(api_token) = env::var("HF_API_TOKEN") {
|
||||
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||
};
|
||||
|
||||
// If huggingface_hub_cache is some, pass it to the shard
|
||||
// Useful when running inside a docker container
|
||||
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
||||
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||
};
|
||||
|
||||
// If weights_cache_override is some, pass it to the shard
|
||||
// Useful when running inside a HuggingFace Inference Endpoint
|
||||
if let Some(weights_cache_override) = weights_cache_override {
|
||||
env.push((
|
||||
envs.push((
|
||||
"WEIGHTS_CACHE_OVERRIDE".into(),
|
||||
weights_cache_override.into(),
|
||||
));
|
||||
@ -409,24 +400,24 @@ fn shard_manager(
|
||||
|
||||
// If disable_custom_kernels is true, pass it to the shard as an env var
|
||||
if disable_custom_kernels {
|
||||
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
|
||||
envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
|
||||
}
|
||||
|
||||
// Watermark Gamma
|
||||
if let Some(watermark_gamma) = watermark_gamma {
|
||||
env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
|
||||
envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
|
||||
}
|
||||
|
||||
// Watermark Delta
|
||||
if let Some(watermark_delta) = watermark_delta {
|
||||
env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
|
||||
envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
|
||||
}
|
||||
|
||||
// Start process
|
||||
tracing::info!("Starting shard {rank}");
|
||||
tracing::info!("Starting shard");
|
||||
let mut p = match Command::new("text-generation-server")
|
||||
.args(shard_argv)
|
||||
.envs(env)
|
||||
.args(shard_args)
|
||||
.envs(envs)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.process_group(0)
|
||||
@ -437,30 +428,23 @@ fn shard_manager(
|
||||
if err.kind() == io::ErrorKind::NotFound {
|
||||
tracing::error!("text-generation-server not found in PATH");
|
||||
tracing::error!("Please install it with `make install-server`")
|
||||
} else {
|
||||
}
|
||||
{
|
||||
tracing::error!("{}", err);
|
||||
}
|
||||
|
||||
status_sender
|
||||
.send(ShardStatus::Failed((rank, Some(err.to_string()))))
|
||||
.unwrap();
|
||||
status_sender.send(ShardStatus::Failed(rank)).unwrap();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Redirect STDOUT to the console
|
||||
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
|
||||
let mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
|
||||
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
|
||||
|
||||
//stdout tracing thread
|
||||
thread::spawn(move || {
|
||||
// Enter shard-manager tracing span
|
||||
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
|
||||
for line in shard_stdout_reader.lines() {
|
||||
// Parse loguru logs
|
||||
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
|
||||
log.trace();
|
||||
}
|
||||
}
|
||||
log_lines(shard_stdout_reader.lines());
|
||||
});
|
||||
|
||||
let mut ready = false;
|
||||
@ -469,30 +453,25 @@ fn shard_manager(
|
||||
loop {
|
||||
// Process exited
|
||||
if let Some(exit_status) = p.try_wait().unwrap() {
|
||||
// We read stderr in another thread as it seems that `read_to_string` can block
|
||||
// indefinitely in some cases
|
||||
// We read stderr in another thread as it seems that lines() can block in some cases
|
||||
let (err_sender, err_receiver) = mpsc::channel();
|
||||
thread::spawn(move || {
|
||||
let mut err = String::new();
|
||||
shard_stderr_reader.read_to_string(&mut err).unwrap();
|
||||
err_sender.send(err).unwrap_or(());
|
||||
for line in shard_stderr_reader.lines().flatten() {
|
||||
err_sender.send(line).unwrap_or(());
|
||||
}
|
||||
});
|
||||
let mut err = String::new();
|
||||
while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
|
||||
err = err + "\n" + &line;
|
||||
}
|
||||
|
||||
let err = err_receiver
|
||||
.recv_timeout(Duration::from_millis(100))
|
||||
.map_err(|err| {
|
||||
tracing::error!("Unable to read shard {rank} error from stderr");
|
||||
err
|
||||
})
|
||||
.ok();
|
||||
tracing::error!("Shard complete standard error output:\n{err}");
|
||||
|
||||
if let Some(signal) = exit_status.signal() {
|
||||
tracing::error!("Shard process was signaled to shutdown with signal {signal}");
|
||||
}
|
||||
|
||||
status_sender
|
||||
.send(ShardStatus::Failed((rank, err)))
|
||||
.unwrap();
|
||||
status_sender.send(ShardStatus::Failed(rank)).unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
@ -500,17 +479,17 @@ fn shard_manager(
|
||||
if shutdown.load(Ordering::SeqCst) {
|
||||
p.kill().unwrap();
|
||||
let _ = p.wait();
|
||||
tracing::info!("Shard {rank} terminated");
|
||||
tracing::info!("Shard terminated");
|
||||
return;
|
||||
}
|
||||
|
||||
// Shard is ready
|
||||
if uds.exists() && !ready {
|
||||
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
|
||||
tracing::info!("Shard ready in {:?}", start_time.elapsed());
|
||||
status_sender.send(ShardStatus::Ready).unwrap();
|
||||
ready = true;
|
||||
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
|
||||
tracing::info!("Waiting for shard {rank} to be ready...");
|
||||
tracing::info!("Waiting for shard to be ready...");
|
||||
wait_time = Instant::now();
|
||||
}
|
||||
sleep(Duration::from_millis(100));
|
||||
@ -579,6 +558,23 @@ impl PythonLogMessage {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&String> for PythonLogMessage {
|
||||
type Error = serde_json::Error;
|
||||
|
||||
fn try_from(value: &String) -> Result<Self, Self::Error> {
|
||||
serde_json::from_str::<Self>(value)
|
||||
}
|
||||
}
|
||||
|
||||
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
|
||||
for line in lines.flatten() {
|
||||
match PythonLogMessage::try_from(&line) {
|
||||
Ok(log) => log.trace(),
|
||||
Err(_) => tracing::debug!("{line}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn find_num_shards(
|
||||
sharded: Option<bool>,
|
||||
num_shard: Option<usize>,
|
||||
@ -632,7 +628,10 @@ enum LauncherError {
|
||||
}
|
||||
|
||||
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
|
||||
let mut download_argv = vec![
|
||||
// Enter download tracing span
|
||||
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
||||
|
||||
let mut download_args = vec![
|
||||
"download-weights".to_string(),
|
||||
args.model_id.to_string(),
|
||||
"--extension".to_string(),
|
||||
@ -644,35 +643,35 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||
|
||||
// Model optional revision
|
||||
if let Some(revision) = &args.revision {
|
||||
download_argv.push("--revision".to_string());
|
||||
download_argv.push(revision.to_string())
|
||||
download_args.push("--revision".to_string());
|
||||
download_args.push(revision.to_string())
|
||||
}
|
||||
|
||||
// Copy current process env
|
||||
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
// If huggingface_hub_cache is set, pass it to the download process
|
||||
// Useful when running inside a docker container
|
||||
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
|
||||
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||
};
|
||||
|
||||
// Enable hf transfer for insane download speeds
|
||||
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
|
||||
env.push((
|
||||
envs.push((
|
||||
"HF_HUB_ENABLE_HF_TRANSFER".into(),
|
||||
enable_hf_transfer.into(),
|
||||
));
|
||||
|
||||
// Parse Inference API token
|
||||
if let Ok(api_token) = env::var("HF_API_TOKEN") {
|
||||
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||
};
|
||||
|
||||
// If args.weights_cache_override is some, pass it to the download process
|
||||
// Useful when running inside a HuggingFace Inference Endpoint
|
||||
if let Some(weights_cache_override) = &args.weights_cache_override {
|
||||
env.push((
|
||||
envs.push((
|
||||
"WEIGHTS_CACHE_OVERRIDE".into(),
|
||||
weights_cache_override.into(),
|
||||
));
|
||||
@ -681,8 +680,8 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||
// Start process
|
||||
tracing::info!("Starting download process.");
|
||||
let mut download_process = match Command::new("text-generation-server")
|
||||
.args(download_argv)
|
||||
.envs(env)
|
||||
.args(download_args)
|
||||
.envs(envs)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.process_group(0)
|
||||
@ -693,6 +692,8 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||
if err.kind() == io::ErrorKind::NotFound {
|
||||
tracing::error!("text-generation-server not found in PATH");
|
||||
tracing::error!("Please install it with `make install-server`")
|
||||
} else {
|
||||
tracing::error!("{}", err);
|
||||
}
|
||||
|
||||
return Err(LauncherError::DownloadError);
|
||||
@ -701,16 +702,10 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||
|
||||
// Redirect STDOUT to the console
|
||||
let download_stdout = download_process.stdout.take().unwrap();
|
||||
let stdout = BufReader::new(download_stdout);
|
||||
|
||||
thread::spawn(move || {
|
||||
// Enter download tracing span
|
||||
let stdout = BufReader::new(download_stdout);
|
||||
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
||||
for line in stdout.lines() {
|
||||
// Parse loguru logs
|
||||
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
|
||||
log.trace();
|
||||
}
|
||||
}
|
||||
log_lines(stdout.lines());
|
||||
});
|
||||
|
||||
loop {
|
||||
@ -738,10 +733,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||
return Err(LauncherError::DownloadError);
|
||||
}
|
||||
if !running.load(Ordering::SeqCst) {
|
||||
signal::kill(Pid::from_raw(download_process.id() as i32), Signal::SIGTERM).unwrap();
|
||||
tracing::info!("Waiting for download process to gracefully shutdown");
|
||||
download_process.wait().unwrap();
|
||||
tracing::info!("Download process terminated");
|
||||
terminate("download", download_process, Duration::from_secs(10)).unwrap();
|
||||
return Ok(());
|
||||
}
|
||||
sleep(Duration::from_millis(100));
|
||||
@ -760,16 +752,6 @@ fn spawn_shards(
|
||||
status_sender: mpsc::Sender<ShardStatus>,
|
||||
running: Arc<AtomicBool>,
|
||||
) -> Result<(), LauncherError> {
|
||||
if args.trust_remote_code {
|
||||
tracing::warn!(
|
||||
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
|
||||
args.model_id
|
||||
);
|
||||
if args.revision.is_none() {
|
||||
tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
|
||||
}
|
||||
}
|
||||
|
||||
// Start shard processes
|
||||
for rank in 0..num_shard {
|
||||
let model_id = args.model_id.clone();
|
||||
@ -828,11 +810,8 @@ fn spawn_shards(
|
||||
Err(TryRecvError::Empty) => {
|
||||
sleep(Duration::from_millis(100));
|
||||
}
|
||||
Ok(ShardStatus::Failed((rank, err))) => {
|
||||
Ok(ShardStatus::Failed(rank)) => {
|
||||
tracing::error!("Shard {rank} failed to start");
|
||||
if let Some(err) = err {
|
||||
tracing::error!("{err}");
|
||||
}
|
||||
shutdown_shards(shutdown, shutdown_receiver);
|
||||
return Err(LauncherError::ShardCannotStart);
|
||||
}
|
||||
@ -854,7 +833,7 @@ fn spawn_webserver(
|
||||
// All shard started
|
||||
// Start webserver
|
||||
tracing::info!("Starting Webserver");
|
||||
let mut argv = vec![
|
||||
let mut router_args = vec![
|
||||
"--max-concurrent-requests".to_string(),
|
||||
args.max_concurrent_requests.to_string(),
|
||||
"--max-best-of".to_string(),
|
||||
@ -867,8 +846,6 @@ fn spawn_webserver(
|
||||
args.max_total_tokens.to_string(),
|
||||
"--max-batch-prefill-tokens".to_string(),
|
||||
args.max_batch_prefill_tokens.to_string(),
|
||||
"--max-batch-total-tokens".to_string(),
|
||||
args.max_batch_total_tokens.to_string(),
|
||||
"--waiting-served-ratio".to_string(),
|
||||
args.waiting_served_ratio.to_string(),
|
||||
"--max-waiting-tokens".to_string(),
|
||||
@ -885,63 +862,54 @@ fn spawn_webserver(
|
||||
args.model_id,
|
||||
];
|
||||
|
||||
// Model optional max batch total tokens
|
||||
if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
|
||||
router_args.push("--max-batch-total-tokens".to_string());
|
||||
router_args.push(max_batch_total_tokens.to_string());
|
||||
}
|
||||
|
||||
// Model optional revision
|
||||
if let Some(ref revision) = args.revision {
|
||||
argv.push("--revision".to_string());
|
||||
argv.push(revision.to_string())
|
||||
router_args.push("--revision".to_string());
|
||||
router_args.push(revision.to_string())
|
||||
}
|
||||
|
||||
if args.json_output {
|
||||
argv.push("--json-output".to_string());
|
||||
router_args.push("--json-output".to_string());
|
||||
}
|
||||
|
||||
// OpenTelemetry
|
||||
if let Some(otlp_endpoint) = args.otlp_endpoint {
|
||||
argv.push("--otlp-endpoint".to_string());
|
||||
argv.push(otlp_endpoint);
|
||||
router_args.push("--otlp-endpoint".to_string());
|
||||
router_args.push(otlp_endpoint);
|
||||
}
|
||||
|
||||
// CORS origins
|
||||
for origin in args.cors_allow_origin.into_iter() {
|
||||
argv.push("--cors-allow-origin".to_string());
|
||||
argv.push(origin);
|
||||
router_args.push("--cors-allow-origin".to_string());
|
||||
router_args.push(origin);
|
||||
}
|
||||
|
||||
// Ngrok
|
||||
if args.ngrok {
|
||||
let authtoken = args.ngrok_authtoken.ok_or_else(|| {
|
||||
tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling");
|
||||
LauncherError::WebserverCannotStart
|
||||
})?;
|
||||
|
||||
argv.push("--ngrok".to_string());
|
||||
argv.push("--ngrok-authtoken".to_string());
|
||||
argv.push(authtoken);
|
||||
|
||||
if let Some(domain) = args.ngrok_domain {
|
||||
argv.push("--ngrok-domain".to_string());
|
||||
argv.push(domain);
|
||||
}
|
||||
|
||||
if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) {
|
||||
argv.push("--ngrok-username".to_string());
|
||||
argv.push(username);
|
||||
argv.push("--ngrok-password".to_string());
|
||||
argv.push(password);
|
||||
}
|
||||
router_args.push("--ngrok".to_string());
|
||||
router_args.push("--ngrok-authtoken".to_string());
|
||||
router_args.push(args.ngrok_authtoken.unwrap());
|
||||
router_args.push("--ngrok-edge".to_string());
|
||||
router_args.push(args.ngrok_edge.unwrap());
|
||||
}
|
||||
|
||||
// Copy current process env
|
||||
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
// Parse Inference API token
|
||||
if let Ok(api_token) = env::var("HF_API_TOKEN") {
|
||||
env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||
};
|
||||
|
||||
let mut webserver = match Command::new("text-generation-router")
|
||||
.args(argv)
|
||||
.envs(env)
|
||||
.args(router_args)
|
||||
.envs(envs)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.process_group(0)
|
||||
@ -979,14 +947,49 @@ fn spawn_webserver(
|
||||
Ok(webserver)
|
||||
}
|
||||
|
||||
fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {
|
||||
tracing::info!("Terminating {process_name}");
|
||||
|
||||
let terminate_time = Instant::now();
|
||||
signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();
|
||||
|
||||
tracing::info!("Waiting for {process_name} to gracefully shutdown");
|
||||
|
||||
while terminate_time.elapsed() < timeout {
|
||||
if let Some(status) = process.try_wait()? {
|
||||
tracing::info!("{process_name} terminated");
|
||||
return Ok(status);
|
||||
}
|
||||
sleep(Duration::from_millis(100));
|
||||
}
|
||||
|
||||
tracing::info!("Killing {process_name}");
|
||||
|
||||
process.kill()?;
|
||||
let exit_status = process.wait()?;
|
||||
|
||||
tracing::info!("{process_name} killed");
|
||||
Ok(exit_status)
|
||||
}
|
||||
|
||||
fn main() -> Result<(), LauncherError> {
|
||||
// Pattern match configuration
|
||||
let args = Args::parse();
|
||||
let args: Args = Args::parse();
|
||||
|
||||
// Filter events with LOG_LEVEL
|
||||
let env_filter =
|
||||
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
|
||||
if args.json_output {
|
||||
tracing_subscriber::fmt().json().init();
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(env_filter)
|
||||
.json()
|
||||
.init();
|
||||
} else {
|
||||
tracing_subscriber::fmt().compact().init();
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(env_filter)
|
||||
.compact()
|
||||
.init();
|
||||
}
|
||||
|
||||
if args.env {
|
||||
@ -1008,29 +1011,53 @@ fn main() -> Result<(), LauncherError> {
|
||||
args.max_batch_prefill_tokens, args.max_input_length
|
||||
)));
|
||||
}
|
||||
if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
args.max_batch_prefill_tokens, args.max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
if args.max_total_tokens as u32 > args.max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
args.max_total_tokens, args.max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
|
||||
if args.validation_workers == 0 {
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
if args.trust_remote_code {
|
||||
tracing::warn!(
|
||||
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
|
||||
args.model_id
|
||||
);
|
||||
}
|
||||
|
||||
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
|
||||
if num_shard > 1 {
|
||||
tracing::info!("Sharding model on {num_shard} processes");
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
||||
if args.max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
args.max_batch_prefill_tokens, max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
if args.max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
args.max_total_tokens, max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if args.ngrok {
|
||||
if args.ngrok_authtoken.is_none() {
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
"`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if args.ngrok_edge.is_none() {
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
"`ngrok-edge` must be set when using ngrok tunneling".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Signal handler
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
let r = running.clone();
|
||||
@ -1042,6 +1069,11 @@ fn main() -> Result<(), LauncherError> {
|
||||
// Download and convert model weights
|
||||
download_convert_model(&args, running.clone())?;
|
||||
|
||||
if !running.load(Ordering::SeqCst) {
|
||||
// Launcher was asked to stop
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Shared shutdown bool
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
// Shared shutdown channel
|
||||
@ -1078,11 +1110,8 @@ fn main() -> Result<(), LauncherError> {
|
||||
let mut exit_code = Ok(());
|
||||
|
||||
while running.load(Ordering::SeqCst) {
|
||||
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
|
||||
if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
|
||||
tracing::error!("Shard {rank} crashed");
|
||||
if let Some(err) = err {
|
||||
tracing::error!("{err}");
|
||||
}
|
||||
exit_code = Err(LauncherError::ShardFailed);
|
||||
break;
|
||||
};
|
||||
@ -1100,10 +1129,7 @@ fn main() -> Result<(), LauncherError> {
|
||||
}
|
||||
|
||||
// Graceful termination
|
||||
signal::kill(Pid::from_raw(webserver.id() as i32), Signal::SIGTERM).unwrap();
|
||||
tracing::info!("Waiting for webserver to gracefully shutdown");
|
||||
webserver.wait().unwrap();
|
||||
tracing::info!("Webserver terminated");
|
||||
terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
|
||||
shutdown_shards(shutdown, &shutdown_receiver);
|
||||
|
||||
exit_code
|
||||
|
@ -198,9 +198,10 @@ message DecodeResponse {
|
||||
message WarmupRequest {
|
||||
/// Batch to warmup on
|
||||
Batch batch = 1;
|
||||
/// Maximum number of tokens that the client will send
|
||||
uint32 max_total_tokens = 2;
|
||||
}
|
||||
|
||||
/// Empty response
|
||||
message WarmupResponse {}
|
||||
message WarmupResponse {
|
||||
/// Maximum number of tokens supported by the model
|
||||
optional uint32 max_supported_total_tokens = 1;
|
||||
}
|
||||
|
@ -103,8 +103,7 @@ impl Client {
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
) -> Result<()> {
|
||||
) -> Result<Option<u32>> {
|
||||
let mut n_tokens = 0;
|
||||
let mut requests = Vec::new();
|
||||
|
||||
@ -143,13 +142,9 @@ impl Client {
|
||||
max_tokens: 0,
|
||||
};
|
||||
|
||||
let request = tonic::Request::new(WarmupRequest {
|
||||
batch: Some(batch),
|
||||
max_total_tokens,
|
||||
})
|
||||
.inject_context();
|
||||
self.stub.warmup(request).await?.into_inner();
|
||||
Ok(())
|
||||
let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
|
||||
let response = self.stub.warmup(request).await?.into_inner();
|
||||
Ok(response.max_supported_total_tokens)
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
|
@ -95,14 +95,11 @@ impl ShardedClient {
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
) -> Result<()> {
|
||||
) -> Result<Option<u32>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| {
|
||||
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
|
||||
})
|
||||
.map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
|
||||
.collect();
|
||||
// all shards return the same message
|
||||
join_all(futures).await.pop().unwrap()
|
||||
|
@ -53,7 +53,7 @@ impl Infer {
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let queue = Queue::new(requires_padding);
|
||||
let queue = Queue::new(requires_padding, 16);
|
||||
let shared = Arc::new(Shared {
|
||||
batching_task: Notify::new(),
|
||||
});
|
||||
|
@ -37,8 +37,8 @@ struct Args {
|
||||
waiting_served_ratio: f32,
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
#[clap(default_value = "16000", long, env)]
|
||||
max_batch_total_tokens: u32,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
@ -49,8 +49,8 @@ struct Args {
|
||||
master_shard_uds_path: String,
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(default_value = "main", long, env)]
|
||||
revision: String,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
@ -64,11 +64,7 @@ struct Args {
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_domain: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_username: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_password: Option<String>,
|
||||
ngrok_edge: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), RouterError> {
|
||||
@ -96,9 +92,7 @@ fn main() -> Result<(), RouterError> {
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_domain,
|
||||
ngrok_username,
|
||||
ngrok_password,
|
||||
ngrok_edge,
|
||||
} = args;
|
||||
|
||||
// Validate args
|
||||
@ -110,18 +104,22 @@ fn main() -> Result<(), RouterError> {
|
||||
if max_input_length as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
|
||||
}
|
||||
if max_batch_prefill_tokens > max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
}
|
||||
|
||||
// CORS allowed origins
|
||||
// map to go inside the option and then map to parse from String to HeaderValue
|
||||
// Finally, convert to AllowOrigin
|
||||
@ -147,7 +145,7 @@ fn main() -> Result<(), RouterError> {
|
||||
// Download and instantiate tokenizer
|
||||
// We need to download it outside of the Tokio runtime
|
||||
let params = FromPretrainedParameters {
|
||||
revision: revision.clone(),
|
||||
revision: revision.clone().unwrap_or("main".to_string()),
|
||||
auth_token: authorization_token.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
@ -175,7 +173,7 @@ fn main() -> Result<(), RouterError> {
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
},
|
||||
false => get_model_info(&tokenizer_name, &revision, authorization_token)
|
||||
false => get_model_info(&tokenizer_name, revision, authorization_token)
|
||||
.await
|
||||
.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
@ -210,14 +208,35 @@ fn main() -> Result<(), RouterError> {
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
sharded_client
|
||||
.warmup(
|
||||
max_input_length as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
)
|
||||
let max_supported_batch_total_tokens = match sharded_client
|
||||
.warmup(max_input_length as u32, max_batch_prefill_tokens)
|
||||
.await
|
||||
.map_err(RouterError::Warmup)?;
|
||||
.map_err(RouterError::Warmup)?
|
||||
{
|
||||
// Older models do not support automatic max-batch-total-tokens
|
||||
None => {
|
||||
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
|
||||
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
|
||||
);
|
||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||
max_batch_total_tokens
|
||||
}
|
||||
// Flash attention models return their max supported total tokens
|
||||
Some(max_supported_batch_total_tokens) => {
|
||||
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||
if max_batch_total_tokens.is_some() {
|
||||
tracing::warn!(
|
||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||
Attention models."
|
||||
);
|
||||
tracing::warn!(
|
||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||
);
|
||||
}
|
||||
max_supported_batch_total_tokens
|
||||
}
|
||||
};
|
||||
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
|
||||
tracing::info!("Connected");
|
||||
|
||||
let addr = match hostname.parse() {
|
||||
@ -240,7 +259,7 @@ fn main() -> Result<(), RouterError> {
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_supported_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
sharded_client,
|
||||
tokenizer,
|
||||
@ -249,9 +268,7 @@ fn main() -> Result<(), RouterError> {
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_domain,
|
||||
ngrok_username,
|
||||
ngrok_password,
|
||||
ngrok_edge,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
@ -316,9 +333,18 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
||||
/// get model info from the Huggingface Hub
|
||||
pub async fn get_model_info(
|
||||
model_id: &str,
|
||||
revision: &str,
|
||||
revision: Option<String>,
|
||||
token: Option<String>,
|
||||
) -> Option<HubModelInfo> {
|
||||
let revision = match revision {
|
||||
None => {
|
||||
tracing::warn!("`--revision` is not set");
|
||||
tracing::warn!("We strongly advise to set it to a known supported commit.");
|
||||
"main".to_string()
|
||||
}
|
||||
Some(revision) => revision,
|
||||
};
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
// Poor man's urlencode
|
||||
let revision = revision.replace('/', "%2F");
|
||||
@ -331,9 +357,18 @@ pub async fn get_model_info(
|
||||
let response = builder.send().await.ok()?;
|
||||
|
||||
if response.status().is_success() {
|
||||
return serde_json::from_str(&response.text().await.ok()?).ok();
|
||||
let hub_model_info: HubModelInfo =
|
||||
serde_json::from_str(&response.text().await.ok()?).ok()?;
|
||||
if let Some(sha) = &hub_model_info.sha {
|
||||
tracing::info!(
|
||||
"Serving revision {sha} of model {}",
|
||||
hub_model_info.model_id
|
||||
);
|
||||
}
|
||||
Some(hub_model_info)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
@ -33,12 +33,12 @@ pub(crate) struct Queue {
|
||||
}
|
||||
|
||||
impl Queue {
|
||||
pub(crate) fn new(requires_padding: bool) -> Self {
|
||||
pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
|
||||
// Create channel
|
||||
let (queue_sender, queue_receiver) = flume::unbounded();
|
||||
|
||||
// Launch background queue task
|
||||
tokio::spawn(queue_task(requires_padding, queue_receiver));
|
||||
tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
|
||||
|
||||
Self { queue_sender }
|
||||
}
|
||||
@ -81,8 +81,12 @@ impl Queue {
|
||||
}
|
||||
|
||||
// Background task responsible of the queue state
|
||||
async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) {
|
||||
let mut state = State::new(requires_padding);
|
||||
async fn queue_task(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
receiver: flume::Receiver<QueueCommand>,
|
||||
) {
|
||||
let mut state = State::new(requires_padding, block_size);
|
||||
|
||||
while let Ok(cmd) = receiver.recv_async().await {
|
||||
match cmd {
|
||||
@ -119,15 +123,19 @@ struct State {
|
||||
|
||||
/// Whether the model is using padding
|
||||
requires_padding: bool,
|
||||
|
||||
/// Paged Attention block size
|
||||
block_size: u32,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(requires_padding: bool) -> Self {
|
||||
fn new(requires_padding: bool, block_size: u32) -> Self {
|
||||
Self {
|
||||
entries: VecDeque::with_capacity(128),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
requires_padding,
|
||||
block_size,
|
||||
}
|
||||
}
|
||||
|
||||
@ -187,10 +195,21 @@ impl State {
|
||||
max_input_length = max_input_length.max(entry.request.input_length);
|
||||
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
|
||||
} else {
|
||||
prefill_tokens += entry.request.input_length;
|
||||
// pad to block size
|
||||
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
|
||||
/ self.block_size)
|
||||
* self.block_size;
|
||||
}
|
||||
|
||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||
if self.requires_padding {
|
||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||
} else {
|
||||
// pad to block size
|
||||
decode_tokens +=
|
||||
((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
|
||||
/ self.block_size)
|
||||
* self.block_size;
|
||||
}
|
||||
|
||||
if prefill_tokens > prefill_token_budget
|
||||
|| (prefill_tokens + decode_tokens) > token_budget
|
||||
@ -321,7 +340,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_append() {
|
||||
let mut state = State::new(false);
|
||||
let mut state = State::new(false, 1);
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
@ -337,7 +356,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_empty() {
|
||||
let mut state = State::new(false);
|
||||
let mut state = State::new(false, 1);
|
||||
|
||||
assert!(state.next_batch(None, 1, 1).is_none());
|
||||
assert!(state.next_batch(Some(1), 1, 1).is_none());
|
||||
@ -345,7 +364,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false);
|
||||
let mut state = State::new(false, 1);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
@ -377,7 +396,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false);
|
||||
let mut state = State::new(false, 1);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
@ -410,14 +429,14 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = Queue::new(false);
|
||||
let queue = Queue::new(false, 1);
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new(false);
|
||||
let queue = Queue::new(false, 1);
|
||||
|
||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
||||
@ -425,7 +444,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new(false);
|
||||
let queue = Queue::new(false, 1);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
@ -458,7 +477,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false);
|
||||
let queue = Queue::new(false, 1);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
@ -483,7 +502,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new(false);
|
||||
let queue = Queue::new(false, 1);
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
|
@ -524,9 +524,7 @@ pub async fn run(
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
ngrok: bool,
|
||||
ngrok_authtoken: Option<String>,
|
||||
ngrok_domain: Option<String>,
|
||||
ngrok_username: Option<String>,
|
||||
ngrok_password: Option<String>,
|
||||
ngrok_edge: Option<String>,
|
||||
) -> Result<(), axum::BoxError> {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
@ -696,32 +694,25 @@ pub async fn run(
|
||||
#[cfg(feature = "ngrok")]
|
||||
{
|
||||
use ngrok::config::TunnelBuilder;
|
||||
use ngrok::tunnel::UrlTunnel;
|
||||
|
||||
let _ = addr;
|
||||
|
||||
let authtoken =
|
||||
ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling");
|
||||
|
||||
let mut tunnel = ngrok::Session::builder()
|
||||
let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling");
|
||||
|
||||
let tunnel = ngrok::Session::builder()
|
||||
.authtoken(authtoken)
|
||||
.connect()
|
||||
.await
|
||||
.unwrap()
|
||||
.http_endpoint();
|
||||
|
||||
if let Some(domain) = ngrok_domain {
|
||||
tunnel = tunnel.domain(domain);
|
||||
}
|
||||
|
||||
if let (Some(username), Some(password)) = (ngrok_username, ngrok_password) {
|
||||
tunnel = tunnel.basic_auth(username, password);
|
||||
}
|
||||
.labeled_tunnel()
|
||||
.label("edge", edge);
|
||||
|
||||
let listener = tunnel.listen().await.unwrap();
|
||||
|
||||
// Run server
|
||||
tracing::info!("Ingress URL: {:?}", listener.url());
|
||||
axum::Server::builder(listener)
|
||||
.serve(app.into_make_service())
|
||||
//Wait until all requests are finished to shut down
|
||||
|
@ -1,4 +1,5 @@
|
||||
include Makefile-flash-att
|
||||
include Makefile-flash-att-v2
|
||||
include Makefile-vllm
|
||||
|
||||
unit-tests:
|
||||
|
13
server/Makefile-flash-att-v2
Normal file
13
server/Makefile-flash-att-v2
Normal file
@ -0,0 +1,13 @@
|
||||
flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
|
||||
|
||||
flash-attention-v2:
|
||||
# Clone flash attention
|
||||
pip install packaging
|
||||
git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
|
||||
|
||||
build-flash-attention-v2: flash-attention-v2
|
||||
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
|
||||
cd flash-attention-v2 && python setup.py build
|
||||
|
||||
install-flash-attention-v2: build-flash-attention-v2
|
||||
cd flash-attention-v2 && python setup.py install
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "text-generation-server"
|
||||
version = "0.9.1"
|
||||
version = "0.9.3"
|
||||
description = "Text Generation Inference Python gRPC Server"
|
||||
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||
|
||||
|
@ -194,6 +194,8 @@ def quantize(
|
||||
percdamp: float = 0.01,
|
||||
act_order: bool = False,
|
||||
):
|
||||
if revision is None:
|
||||
revision = "main"
|
||||
download_weights(
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
@ -207,6 +209,7 @@ def quantize(
|
||||
bits=4,
|
||||
groupsize=128,
|
||||
output_dir=output_dir,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
upload_to_model_id=upload_to_model_id,
|
||||
percdamp=percdamp,
|
||||
|
@ -42,51 +42,21 @@ __all__ = [
|
||||
"get_model",
|
||||
]
|
||||
|
||||
FLASH_ATT_ERROR_MESSAGE = (
|
||||
"{} requires CUDA and Flash Attention kernels to be installed.\n"
|
||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||
"or install flash attention with `cd server && make install install-flash-attention`"
|
||||
)
|
||||
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
|
||||
|
||||
FLASH_ATTENTION = True
|
||||
try:
|
||||
if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
if not torch.cuda.is_available():
|
||||
FLASH_ATT_ERROR_MESSAGE = (
|
||||
"{} requires CUDA. No compatible CUDA devices found."
|
||||
)
|
||||
raise ImportError("CUDA is not available")
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
is_sm75 = major == 7 and minor == 5
|
||||
is_sm8x = major == 8 and minor >= 0
|
||||
is_sm90 = major == 9 and minor == 0
|
||||
|
||||
supported = is_sm75 or is_sm8x or is_sm90
|
||||
if not supported:
|
||||
FLASH_ATT_ERROR_MESSAGE = (
|
||||
"{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
|
||||
"No compatible CUDA device found."
|
||||
)
|
||||
raise ImportError(
|
||||
f"GPU with CUDA capability {major} {minor} is not supported"
|
||||
)
|
||||
|
||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
||||
from text_generation_server.models.flash_llama import (
|
||||
FlashLlama,
|
||||
)
|
||||
from text_generation_server.models.flash_santacoder import (
|
||||
FlashSantacoderSharded,
|
||||
)
|
||||
|
||||
FLASH_ATTENTION = True
|
||||
else:
|
||||
FLASH_ATTENTION = False
|
||||
except ImportError:
|
||||
logger.opt(exception=True).warning(
|
||||
"Could not import Flash Attention enabled models"
|
||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
||||
from text_generation_server.models.flash_llama import (
|
||||
FlashLlama,
|
||||
)
|
||||
from text_generation_server.models.flash_santacoder import (
|
||||
FlashSantacoderSharded,
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
||||
FLASH_ATTENTION = False
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
|
@ -23,25 +23,77 @@ import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
import dropout_layer_norm
|
||||
|
||||
# vllm imports
|
||||
import vllm_cache_ops
|
||||
import vllm_attention_ops
|
||||
|
||||
from text_generation_server.utils.flash_attn import attention
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
|
||||
class LlamaConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=False,
|
||||
rope_scaling=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, prefix, weights, eps=1e-6):
|
||||
"""
|
||||
@ -59,7 +111,8 @@ class LlamaRMSNorm(nn.Module):
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(
|
||||
variance + self.variance_epsilon
|
||||
)
|
||||
@ -94,6 +147,27 @@ class LlamaRMSNorm(nn.Module):
|
||||
return normed_hidden_states, res
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
w = [
|
||||
weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
|
||||
weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
|
||||
weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
|
||||
]
|
||||
weight = torch.cat(w, dim=0)
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
bias = None
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -118,22 +192,29 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
self.query_key_value = _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_heads, dtype=torch.int32, device=weights.device
|
||||
)
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -148,38 +229,37 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
max_s,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(qkv[:, 0], cos, sin)
|
||||
self.rotary_emb(qkv[:, 1], cos, sin)
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
vllm_cache_ops.reshape_and_cache(
|
||||
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(qkv[:, 0])
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
@ -187,7 +267,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
block_size = kv_cache[1].shape[3]
|
||||
vllm_attention_ops.single_query_cached_kv_attention(
|
||||
attn_output,
|
||||
qkv[:, 0],
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
@ -323,6 +403,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -27,13 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.gpt_neox import GPTNeoXConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
|
||||
# vllm imports
|
||||
import vllm_cache_ops
|
||||
import vllm_attention_ops
|
||||
|
||||
from text_generation_server.utils.flash_attn import attention
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
@ -153,22 +151,14 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
attention(
|
||||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
|
@ -6,13 +6,11 @@ from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
|
||||
# vllm imports
|
||||
import vllm_cache_ops
|
||||
import vllm_attention_ops
|
||||
|
||||
from text_generation_server.utils.flash_attn import attention
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
@ -182,27 +180,15 @@ class FlashRWAttention(torch.nn.Module):
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
if self.num_heads_kv == 1:
|
||||
# Expand to query shape
|
||||
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
|
||||
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
@ -314,30 +300,15 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# Expand to query shape
|
||||
kv = (
|
||||
kv.unsqueeze(2)
|
||||
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
|
||||
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
|
||||
)
|
||||
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
attention(
|
||||
query,
|
||||
torch.select(kv, dim=2, index=0),
|
||||
torch.select(kv, dim=2, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
|
@ -5,13 +5,11 @@ from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
|
||||
# vllm imports
|
||||
import vllm_cache_ops
|
||||
import vllm_attention_ops
|
||||
|
||||
from text_generation_server.utils.flash_attn import attention
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
@ -265,26 +263,15 @@ class FlashMQAttention(torch.nn.Module):
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# Expand from 1 to num_heads
|
||||
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
|
||||
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
attention(
|
||||
query,
|
||||
torch.select(key_value, dim=1, index=0),
|
||||
torch.select(key_value, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
|
@ -712,14 +712,14 @@ class FlashCausalLM(Model):
|
||||
def batch_type(self) -> Type[FlashCausalLMBatch]:
|
||||
return FlashCausalLMBatch
|
||||
|
||||
def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
|
||||
def warmup(self, batch: FlashCausalLMBatch):
|
||||
global CACHE_MANAGER
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats(self.device)
|
||||
try:
|
||||
CACHE_MANAGER = CacheManager(
|
||||
# Adds some wiggle room
|
||||
math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
|
||||
batch.blocks,
|
||||
self.num_layers,
|
||||
self.num_kv_heads,
|
||||
self.head_size,
|
||||
@ -729,11 +729,43 @@ class FlashCausalLM(Model):
|
||||
_, batch = self.generate_token(batch)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
|
||||
f"prefill tokens. "
|
||||
f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
|
||||
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
|
||||
f"You need to decrease `--max-batch-prefill-tokens`"
|
||||
) from e
|
||||
|
||||
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
||||
# Calculate the number of blocks that can be allocated with the
|
||||
# profiled peak memory.
|
||||
torch.cuda.synchronize(self.device)
|
||||
peak_memory = torch.cuda.max_memory_reserved(self.device)
|
||||
|
||||
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
|
||||
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
||||
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
||||
|
||||
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
|
||||
|
||||
# 0.98 to add some wiggle room
|
||||
num_blocks = (
|
||||
int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
|
||||
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
|
||||
+ batch.blocks
|
||||
)
|
||||
|
||||
del CACHE_MANAGER
|
||||
del batch
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
CACHE_MANAGER = CacheManager(
|
||||
num_blocks,
|
||||
self.num_layers,
|
||||
self.num_kv_heads,
|
||||
self.head_size,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
|
||||
return int(num_blocks * BLOCK_SIZE)
|
||||
|
||||
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
|
||||
return self.tokenizer.decode(
|
||||
|
@ -2,13 +2,13 @@ import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoConfig
|
||||
from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
LlamaConfig,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
@ -52,7 +52,7 @@ class FlashLlama(FlashCausalLM):
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
config = LlamaConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
@ -70,7 +70,7 @@ class FlashLlama(FlashCausalLM):
|
||||
tokenizer=tokenizer,
|
||||
config=config,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_heads,
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
|
@ -107,8 +107,9 @@ class Model(ABC):
|
||||
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def warmup(self, batch: B, max_total_tokens: int):
|
||||
def warmup(self, batch: B) -> Optional[int]:
|
||||
self.generate_token(batch)
|
||||
return None
|
||||
|
||||
def decode_token(
|
||||
self,
|
||||
|
@ -51,21 +51,17 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
filtered_batch = batch.filter(request.request_ids)
|
||||
self.cache.set(filtered_batch)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
||||
|
||||
async def Warmup(self, request, context):
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||
)
|
||||
self.model.warmup(batch, request.max_total_tokens)
|
||||
max_supported_total_tokens = self.model.warmup(batch)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return generate_pb2.WarmupResponse()
|
||||
return generate_pb2.WarmupResponse(
|
||||
max_supported_total_tokens=max_supported_total_tokens
|
||||
)
|
||||
|
||||
async def Prefill(self, request, context):
|
||||
batch = self.model.batch_type.from_pb(
|
||||
@ -96,8 +92,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
|
||||
if len(batches) > 1:
|
||||
batch = self.model.batch_type.concatenate(batches)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
batch = batches[0]
|
||||
|
||||
|
@ -94,6 +94,14 @@ def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: Lis
|
||||
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
||||
|
||||
for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
|
||||
# Skip blacklisted files
|
||||
if (
|
||||
"arguments" in pt_file.name
|
||||
or "args" in pt_file.name
|
||||
or "training" in pt_file.name
|
||||
):
|
||||
continue
|
||||
|
||||
start = datetime.datetime.now()
|
||||
convert_file(pt_file, sf_file, discard_names)
|
||||
elapsed = datetime.datetime.now() - start
|
||||
|
124
server/text_generation_server/utils/flash_attn.py
Normal file
124
server/text_generation_server/utils/flash_attn.py
Normal file
@ -0,0 +1,124 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
from loguru import logger
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise ImportError("CUDA is not available")
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
is_sm75 = major == 7 and minor == 5
|
||||
is_sm8x = major == 8 and minor >= 0
|
||||
is_sm90 = major == 9 and minor == 0
|
||||
|
||||
HAS_FLASH_ATTN = False
|
||||
HAS_FLASH_ATTN_V2 = False
|
||||
try:
|
||||
try:
|
||||
import flash_attn_2_cuda
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Flash Attention V2 is not installed.\n"
|
||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||
"or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
|
||||
)
|
||||
if not (is_sm8x or is_sm90):
|
||||
raise ImportError(
|
||||
f"GPU with CUDA capability {major} {minor} is not supported for "
|
||||
"Flash Attention V2"
|
||||
)
|
||||
HAS_FLASH_ATTN_V2 = True
|
||||
except ImportError as e:
|
||||
try:
|
||||
import flash_attn_cuda
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Flash Attention is not installed.\n"
|
||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||
"or install flash attention with `cd server && make install install-flash-attention`"
|
||||
) from e
|
||||
|
||||
if not (is_sm75 or is_sm8x or is_sm90):
|
||||
raise ImportError(
|
||||
f"GPU with CUDA capability {major} {minor} is not supported"
|
||||
) from e
|
||||
logger.warning(f"Unable to use Flash Attention V2: {e}")
|
||||
HAS_FLASH_ATTN = True
|
||||
|
||||
|
||||
def attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
):
|
||||
if HAS_FLASH_ATTN_V2:
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
if HAS_FLASH_ATTN:
|
||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||
if k.shape[1] != q.shape[1]:
|
||||
# MQA expand
|
||||
if k.shape[1] == 1:
|
||||
k = k.expand(-1, q.shape[1], -1)
|
||||
# Grouped attention reshape
|
||||
else:
|
||||
original_shape = k.shape
|
||||
k = (
|
||||
k.unsqueeze(2)
|
||||
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
||||
.reshape(original_shape[0], -1, original_shape[2])
|
||||
)
|
||||
if v.shape[1] != q.shape[1]:
|
||||
# MQA expand
|
||||
if v.shape[1] == 1:
|
||||
v = v.expand(-1, q.shape[1], -1)
|
||||
# Grouped attention reshape
|
||||
else:
|
||||
original_shape = v.shape
|
||||
v = (
|
||||
v.unsqueeze(2)
|
||||
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
||||
.reshape(original_shape[0], -1, original_shape[2])
|
||||
)
|
||||
|
||||
return flash_attn_cuda.fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
|
||||
raise NotImplementedError("flash attention is not installed")
|
@ -13,6 +13,9 @@ import transformers
|
||||
from huggingface_hub import HfApi
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from text_generation_server.utils import initialize_torch_distributed, Weights
|
||||
from text_generation_server.utils.hub import weight_files
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
@ -38,7 +41,6 @@ class Quantizer(nn.Module):
|
||||
maxshrink=0.8,
|
||||
trits=False,
|
||||
):
|
||||
|
||||
self.maxq = torch.tensor(2**bits - 1)
|
||||
self.perchannel = perchannel
|
||||
self.sym = sym
|
||||
@ -600,6 +602,8 @@ def sequential(
|
||||
nsamples,
|
||||
bits,
|
||||
groupsize,
|
||||
*,
|
||||
hooks,
|
||||
percdamp=0.01,
|
||||
sym: bool = False,
|
||||
act_order: bool = False,
|
||||
@ -637,7 +641,7 @@ def sequential(
|
||||
layers[0] = Catcher(layers[0])
|
||||
for batch in dataloader:
|
||||
try:
|
||||
model(batch[0])
|
||||
model(batch[0].cuda())
|
||||
except ValueError:
|
||||
pass
|
||||
layers[0] = layers[0].module
|
||||
@ -646,6 +650,8 @@ def sequential(
|
||||
# model.model.embed_tokens = model.model.embed_tokens.cpu()
|
||||
# model.model.norm = model.model.norm.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
outs = torch.zeros_like(inps)
|
||||
|
||||
@ -662,10 +668,8 @@ def sequential(
|
||||
print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
|
||||
print("+==================+==============+============+===========+=======+")
|
||||
|
||||
from accelerate.hooks import remove_hook_from_submodules
|
||||
|
||||
layer = layers[i].to(dev)
|
||||
remove_hook_from_submodules(layer)
|
||||
layer = layers[i]
|
||||
layer.load()
|
||||
full = find_layers(layer)
|
||||
sequential = [list(full.keys())]
|
||||
|
||||
@ -677,6 +681,7 @@ def sequential(
|
||||
gptq[name].quantizer.configure(
|
||||
bits, perchannel=True, sym=sym, mse=False
|
||||
)
|
||||
pass
|
||||
|
||||
def add_batch(name):
|
||||
def tmp(_, inp, out):
|
||||
@ -688,7 +693,6 @@ def sequential(
|
||||
for name in subset:
|
||||
handles.append(subset[name].register_forward_hook(add_batch(name)))
|
||||
for j in range(nsamples):
|
||||
|
||||
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
|
||||
for h in handles:
|
||||
h.remove()
|
||||
@ -714,7 +718,7 @@ def sequential(
|
||||
for j in range(nsamples):
|
||||
outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
|
||||
|
||||
layers[i] = layer.cpu()
|
||||
layer.unload()
|
||||
del layer
|
||||
del gptq
|
||||
torch.cuda.empty_cache()
|
||||
@ -768,24 +772,136 @@ def pack(model, quantizers, bits, groupsize):
|
||||
return model
|
||||
|
||||
|
||||
def setdeepattr(module, full_name, tensor):
|
||||
current = module
|
||||
tokens = full_name.split(".")
|
||||
for token in tokens[:-1]:
|
||||
current = getattr(current, token)
|
||||
setattr(current, tokens[-1], tensor)
|
||||
|
||||
|
||||
def getdeepattr(module, full_name):
|
||||
current = module
|
||||
tokens = full_name.split(".")
|
||||
for token in tokens:
|
||||
current = getattr(current, token)
|
||||
return current
|
||||
|
||||
|
||||
def load_weights_pre_hook(module_name, weights, recursive=False):
|
||||
def inner(module, args):
|
||||
print(f"Pre hook {module_name}")
|
||||
local_params = {}
|
||||
for k, v in module.named_parameters():
|
||||
if not recursive and k.count(".") != 1:
|
||||
continue
|
||||
local_params[k] = v
|
||||
for k, v in module.named_buffers():
|
||||
if not recursive and k.count(".") != 1:
|
||||
continue
|
||||
local_params[k] = v
|
||||
|
||||
for local_param in local_params:
|
||||
current_tensor = getdeepattr(module, local_param)
|
||||
if current_tensor.device == torch.device("meta"):
|
||||
# print(f"Loading {local_param}")
|
||||
if module_name:
|
||||
tensor_name = f"{module_name}.{local_param}"
|
||||
else:
|
||||
tensor_name = local_param
|
||||
tensor = weights.get_tensor(tensor_name)
|
||||
setdeepattr(module, local_param, nn.Parameter(tensor))
|
||||
else:
|
||||
setdeepattr(
|
||||
module,
|
||||
local_param,
|
||||
nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
|
||||
)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def load_weights_post_hook(module_name, weights, recursive=False):
|
||||
def inner(module, args, output):
|
||||
print(f"Post hook {module_name}")
|
||||
local_params = {}
|
||||
for k, v in module.named_parameters():
|
||||
if not recursive and k.count(".") != 1:
|
||||
continue
|
||||
local_params[k] = v
|
||||
for k, v in module.named_buffers():
|
||||
if not recursive and k.count(".") != 1:
|
||||
continue
|
||||
local_params[k] = v
|
||||
for local_param in local_params:
|
||||
# print(f"Unloading {local_param}")
|
||||
current_tensor = getdeepattr(module, local_param)
|
||||
setdeepattr(
|
||||
module,
|
||||
local_param,
|
||||
nn.Parameter(current_tensor.to(device=torch.device("cpu"))),
|
||||
)
|
||||
return output
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def quantize(
|
||||
model_id: str,
|
||||
bits: int,
|
||||
groupsize: int,
|
||||
output_dir: str,
|
||||
revision: str,
|
||||
trust_remote_code: bool,
|
||||
upload_to_model_id: Optional[str],
|
||||
percdamp: float,
|
||||
act_order: bool,
|
||||
):
|
||||
print("loading model")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced_low_0",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
|
||||
model = model.eval()
|
||||
|
||||
print("LOADED model")
|
||||
files = weight_files(model_id, revision, extension=".safetensors")
|
||||
process_group, _, _ = initialize_torch_distributed()
|
||||
weights = Weights(
|
||||
files,
|
||||
device=torch.device("cuda:0"),
|
||||
dtype=torch.float16,
|
||||
process_group=process_group,
|
||||
aliases={"embed_tokens.weight": ["lm_head.weight"]},
|
||||
)
|
||||
hooks = []
|
||||
for name, module in model.named_modules():
|
||||
|
||||
def load(module, name):
|
||||
def _load():
|
||||
load_weights_pre_hook(name, weights, recursive=True)(module, None)
|
||||
|
||||
return _load
|
||||
|
||||
def unload(module, name):
|
||||
def _unload():
|
||||
load_weights_post_hook(name, weights, recursive=True)(
|
||||
module, None, None
|
||||
)
|
||||
|
||||
return _unload
|
||||
|
||||
module.load = load(module, name)
|
||||
module.unload = unload(module, name)
|
||||
hooks.append(
|
||||
module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
|
||||
)
|
||||
hooks.append(
|
||||
module.register_forward_hook(load_weights_post_hook(name, weights))
|
||||
)
|
||||
model.seqlen = 2048
|
||||
|
||||
dataset = "wikitext2"
|
||||
@ -806,6 +922,7 @@ def quantize(
|
||||
groupsize,
|
||||
percdamp=percdamp,
|
||||
act_order=act_order,
|
||||
hooks=hooks,
|
||||
)
|
||||
print(time.time() - tick)
|
||||
|
||||
@ -858,7 +975,6 @@ def quantize(
|
||||
logger.info("Saved tokenizer")
|
||||
|
||||
if upload_to_model_id:
|
||||
|
||||
api = HfApi()
|
||||
|
||||
api.upload_folder(
|
||||
|
Loading…
Reference in New Issue
Block a user