diff --git a/Cargo.lock b/Cargo.lock index 867503f2..539cf124 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -148,18 +148,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] name = "async-trait" -version = "0.1.68" +version = "0.1.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" +checksum = "a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] @@ -410,9 +410,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.3.10" +version = "4.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "384e169cc618c613d5e3ca6404dda77a8685a63e08660dcc64abaf7da7cb0c7a" +checksum = "1640e5cc7fb47dbb8338fd471b105e7ed6c3cb2aeb00c2e067127ffd3764a05d" dependencies = [ "clap_builder", "clap_derive", @@ -421,9 +421,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.3.10" +version = "4.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef137bbe35aab78bdb468ccfba75a5f4d8321ae011d34063770780545176af2d" +checksum = "98c59138d527eeaf9b53f35a77fcc1fad9d883116070c63d5de1c7dc7b00c72b" dependencies = [ "anstream", "anstyle", @@ -440,7 +440,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] @@ -492,9 +492,9 @@ checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] name = "cpufeatures" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03e69e28e9f7f77debdedbaafa2866e1de9ba56df55a8bd7cfc724c25a09987c" +checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1" dependencies = [ "libc", ] @@ -633,12 +633,12 @@ dependencies = [ [[package]] name = "dashmap" -version = "5.4.0" +version = "5.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +checksum = "6943ae99c34386c84a470c499d3414f66502a41340aa895406e0d2e4a207b91d" dependencies = [ "cfg-if", - "hashbrown 0.12.3", + "hashbrown 0.14.0", "lock_api", "once_cell", "parking_lot_core", @@ -736,6 +736,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.1" @@ -924,7 +930,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] @@ -1023,7 +1029,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 1.9.3", "slab", "tokio", "tokio-util", @@ -1038,13 +1044,19 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.13.2" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +checksum = "33ff8ae62cd3a9102e5637afc8452c55acf3844001bd5374e0b0bd7b6616c038" dependencies = [ "ahash", ] +[[package]] +name = "hashbrown" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" + [[package]] name = "heck" version = "0.4.1" @@ -1053,9 +1065,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" [[package]] name = "hmac" @@ -1190,6 +1202,16 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", "serde", ] @@ -1254,12 +1276,12 @@ checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" [[package]] name = "is-terminal" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24fddda5af7e54bf7da53067d6e802dbcc381d0a8eef629df528e3ebf68755cb" +checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ "hermit-abi", - "rustix 0.38.1", + "rustix 0.38.4", "windows-sys 0.48.0", ] @@ -1292,9 +1314,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.6" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a" [[package]] name = "jobserver" @@ -1397,7 +1419,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" dependencies = [ - "regex-automata", + "regex-automata 0.1.10", ] [[package]] @@ -1432,9 +1454,9 @@ dependencies = [ [[package]] name = "metrics" -version = "0.21.0" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa8ebbd1a9e57bbab77b9facae7f5136aea44c356943bf9a198f647da64285d6" +checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5" dependencies = [ "ahash", "metrics-macros", @@ -1449,7 +1471,7 @@ checksum = "8a4964177ddfdab1e3a2b37aec7cf320e14169abb0ed73999f558136409178d5" dependencies = [ "base64 0.21.2", "hyper", - "indexmap", + "indexmap 1.9.3", "ipnet", "metrics", "metrics-util", @@ -1467,18 +1489,18 @@ checksum = "ddece26afd34c31585c74a4db0630c376df271c285d682d1e55012197830b6df" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] name = "metrics-util" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "111cb375987443c3de8d503580b536f77dc8416d32db62d9456db5d93bd7ac47" +checksum = "4de2ed6e491ed114b40b732e4d1659a9d53992ebd87490c44a6ffe23739d973e" dependencies = [ "crossbeam-epoch", "crossbeam-utils", - "hashbrown 0.13.2", + "hashbrown 0.13.1", "metrics", "num_cpus", "quanta", @@ -1530,9 +1552,9 @@ dependencies = [ [[package]] name = "monostate" -version = "0.1.6" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a" +checksum = 
"3f3f57a8802842f648026a33c3d2e3bb41bb309a35b1609bd7ef2b060b8b6b1b" dependencies = [ "monostate-impl", "serde", @@ -1540,13 +1562,13 @@ dependencies = [ [[package]] name = "monostate-impl" -version = "0.1.6" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b" +checksum = "e72f4d2e10fde62a0f2fcb4b44ccbf4f9899dcc30c9193449f8dfb9123d71377" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] @@ -1701,6 +1723,15 @@ dependencies = [ "libc", ] +[[package]] +name = "num_threads" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44" +dependencies = [ + "libc", +] + [[package]] name = "number_prefix" version = "0.3.0" @@ -1773,7 +1804,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] @@ -1854,7 +1885,7 @@ dependencies = [ "fnv", "futures-channel", "futures-util", - "indexmap", + "indexmap 1.9.3", "js-sys", "once_cell", "pin-project-lite", @@ -1870,7 +1901,7 @@ dependencies = [ "fnv", "futures-channel", "futures-util", - "indexmap", + "indexmap 1.9.3", "once_cell", "pin-project-lite", "thiserror", @@ -1974,9 +2005,9 @@ dependencies = [ [[package]] name = "paste" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +checksum = "b4b27ab7be369122c218afc2079489cdcb4b517c0a3fc386ff11e1fedfcc2b35" [[package]] name = "pbkdf2" @@ -2003,34 +2034,34 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" dependencies = [ "fixedbitset", - "indexmap", + "indexmap 1.9.3", ] [[package]] name = "pin-project" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e138fdd8263907a2b0e1b4e80b7e58c721126479b6e6eedfb1b402acea7b9bd" +checksum = "030ad2bc4db10a8944cb0d837f158bdfec4d4a4873ab701a95046770d11f8842" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1fef411b303e3e12d534fb6e7852de82da56edd937d895125821fb7c09436c7" +checksum = "ec2e072ecce94ec471b13398d5402c188e76ac03cf74dd1a975161b23a3f6d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] name = "pin-project-lite" -version = "0.2.9" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +checksum = "4c40d25201921e5ff0c862a505c6557ea88568a4e3ace775ab55e93f2f4f9d57" [[package]] name = "pin-utils" @@ -2046,9 +2077,9 @@ checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" [[package]] name = "portable-atomic" -version = "1.3.3" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "767eb9f07d4a5ebcb39bbf2d452058a93c011373abf6832e24194a1c3f004794" +checksum = "d220334a184db82b31b83f5ff093e3315280fb2b6bbc032022b2304a509aab7a" [[package]] name = "ppv-lite86" @@ -2092,9 +2123,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.63" +version = 
"1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb" +checksum = "78803b62cbf1f46fde80d7c0e803111524b9877184cfe7c3033659490ac7a7da" dependencies = [ "unicode-ident", ] @@ -2294,13 +2325,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.8.4" +version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +checksum = "b2eae68fc220f7cf2532e4494aded17545fce192d59cd996e0fe7887f4ceb575" dependencies = [ "aho-corasick 1.0.2", "memchr", - "regex-syntax 0.7.2", + "regex-automata 0.3.3", + "regex-syntax 0.7.4", ] [[package]] @@ -2312,6 +2344,17 @@ dependencies = [ "regex-syntax 0.6.29", ] +[[package]] +name = "regex-automata" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39354c10dd07468c2e73926b23bb9c2caca74c5501e38a35da70406f1d923310" +dependencies = [ + "aho-corasick 1.0.2", + "memchr", + "regex-syntax 0.7.4", +] + [[package]] name = "regex-syntax" version = "0.6.29" @@ -2320,9 +2363,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.7.2" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" +checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" [[package]] name = "reqwest" @@ -2397,7 +2440,7 @@ dependencies = [ "quote", "rust-embed-utils", "shellexpand", - "syn 2.0.22", + "syn 2.0.25", "walkdir", ] @@ -2428,9 +2471,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.37.21" +version = "0.37.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f25693a73057a1b4cb56179dd3c7ea21a7c6c5ee7d85781f5749b46f34b79c" +checksum = "4d69718bf81c6127a49dc64e44a742e8bb9213c0ff8869a22c308f84c1d4ab06" dependencies = [ "bitflags 1.3.2", "errno", @@ -2442,9 +2485,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.1" +version = "0.38.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc6396159432b5c8490d4e301d8c705f61860b8b6c863bf79942ce5401968f3" +checksum = "0a962918ea88d644592894bc6dc55acc6c0956488adcebbfb6e273506b7fd6e5" dependencies = [ "bitflags 2.3.3", "errno", @@ -2476,15 +2519,15 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06" +checksum = "dc31bd9b61a32c31f9650d18add92aa83a49ba979c143eefd27fe7177b05bd5f" [[package]] name = "ryu" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +checksum = "fe232bdf6be8c8de797b22184ee71118d63780ea42ac85b61d1baa6d3b782ae9" [[package]] name = "same-file" @@ -2497,11 +2540,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.21" +version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" +checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ - "windows-sys 0.42.0", + "windows-sys 0.48.0", ] [[package]] @@ -2551,29 +2594,29 @@ checksum = 
"bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "serde" -version = "1.0.164" +version = "1.0.171" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" +checksum = "30e27d1e4fd7659406c492fd6cfaf2066ba8773de45ca75e855590f856dc34a9" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.164" +version = "1.0.171" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" +checksum = "389894603bd18c46fa56231694f8d827779c0951a667087194cf9de94ed24682" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] name = "serde_json" -version = "1.0.99" +version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3" +checksum = "b5062a995d481b2308b6064e9af76011f2921c35f97b0468811ed9f6cd91dfed" dependencies = [ "itoa", "ryu", @@ -2582,10 +2625,11 @@ dependencies = [ [[package]] name = "serde_path_to_error" -version = "0.1.11" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7f05c1d5476066defcdfacce1f52fc3cae3af1d3089727100c02ae92e5abbe0" +checksum = "8acc4422959dd87a76cb117c191dcbffc20467f06c9100b76721dab370f24d3a" dependencies = [ + "itoa", "serde", ] @@ -2697,9 +2741,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" [[package]] name = "socket2" @@ -2769,9 +2813,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.22" +version = "2.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2efbeae7acf4eabd6bcdcbd11c92f45231ddda7539edc7806bd1a04a03b24616" +checksum = "15e3fc8c0c74267e2df136e5e5fb656a464158aa57624053375eb9c8c6e25ae2" dependencies = [ "proc-macro2", "quote", @@ -2786,9 +2830,9 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "sysinfo" -version = "0.29.3" +version = "0.29.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bcd0346f90b6bc83526c7b180039a8acd26a5c848cc556d457f6472eb148122" +checksum = "751e810399bba86e9326f5762b7f32ac5a085542df78da6a78d94e07d14d7c11" dependencies = [ "cfg-if", "core-foundation-sys", @@ -2824,9 +2868,9 @@ dependencies = [ [[package]] name = "tar" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b55807c0344e1e6c04d7c965f5289c39a8d94ae23ed5c0b57aabac549f871c6" +checksum = "ec96d2ffad078296368d46ff1cb309be1c23c513b4ab0e22a45de0185275ac96" dependencies = [ "filetime", "libc", @@ -2843,13 +2887,13 @@ dependencies = [ "cfg-if", "fastrand", "redox_syscall 0.3.5", - "rustix 0.37.21", + "rustix 0.37.23", "windows-sys 0.48.0", ] [[package]] name = "text-generation-benchmark" -version = "0.9.1" +version = "0.9.3" dependencies = [ "average", "clap", @@ -2869,7 +2913,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "0.9.1" +version = "0.9.3" dependencies = [ "futures", "grpc-metadata", @@ -2885,7 +2929,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "0.9.1" 
+version = "0.9.3" dependencies = [ "clap", "ctrlc", @@ -2901,7 +2945,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "0.9.1" +version = "0.9.3" dependencies = [ "async-stream", "axum", @@ -2934,22 +2978,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "a35fc5b8971143ca348fa6df4f024d4d55264f3468c71ad1c2f365b0a4d58c42" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] @@ -2964,11 +3008,13 @@ dependencies = [ [[package]] name = "time" -version = "0.3.22" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd" +checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446" dependencies = [ "itoa", + "libc", + "num_threads", "serde", "time-core", "time-macros", @@ -2982,9 +3028,9 @@ checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" [[package]] name = "time-macros" -version = "0.2.9" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b" +checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4" dependencies = [ "time-core", ] @@ -3078,7 +3124,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] @@ -3209,7 +3255,7 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", - "indexmap", + "indexmap 1.9.3", "pin-project", "pin-project-lite", "rand", @@ -3291,7 +3337,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", ] [[package]] @@ -3413,9 +3459,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" +checksum = "22049a19f4a68748a168c0fc439f9516686aa045927ff767eca0a85101fb6e73" [[package]] name = "unicode-normalization" @@ -3484,11 +3530,11 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "utoipa" -version = "3.3.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68ae74ef183fae36d650f063ae7bde1cacbe1cd7e72b617cbe1e985551878b98" +checksum = "520434cac5c98120177d5cc15be032703f6dca7d5ef82e725c798113b375000a" dependencies = [ - "indexmap", + "indexmap 2.0.0", "serde", "serde_json", "utoipa-gen", @@ -3496,21 +3542,22 @@ dependencies = [ [[package]] name = "utoipa-gen" -version = "3.3.0" +version = "3.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7ea8ac818da7e746a63285594cce8a96f5e00ee31994e655bd827569cb8b137b" +checksum = "6e22e88a487b6e0374533871b79b1f5ded05671bd0936bd547eb42f82fb9060d" dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.22", + "regex", + "syn 2.0.25", ] [[package]] name = "utoipa-swagger-ui" -version = "3.1.3" +version = "3.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "062bba5a3568e126ac72049a63254f4cb1da2eb713db0c1ab2a4c76be191db8c" +checksum = "4602d7100d3cfd8a086f30494e68532402ab662fa366c9d201d677e33cee138d" dependencies = [ "axum", "mime_guess", @@ -3536,9 +3583,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vergen" -version = "8.2.1" +version = "8.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b3c89c2c7e50f33e4d35527e5bf9c11d6d132226dbbd1753f0fbe9f19ef88c6" +checksum = "bbc5ad0d9d26b2c49a5ab7da76c3e79d3ee37e7821799f8223fcb8f2f391a2e7" dependencies = [ "anyhow", "rustc_version", @@ -3599,7 +3646,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", "wasm-bindgen-shared", ] @@ -3633,7 +3680,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.22", + "syn 2.0.25", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3706,21 +3753,6 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" -[[package]] -name = "windows-sys" -version = "0.42.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" -dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", -] - [[package]] name = "windows-sys" version = "0.45.0" diff --git a/Cargo.toml b/Cargo.toml index fdfb274e..49b7717a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ ] [workspace.package] -version = "0.9.1" +version = "0.9.3" edition = "2021" authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" diff --git a/Dockerfile b/Dockerfile index 66e0091d..168f2f97 100644 --- a/Dockerfile +++ b/Dockerfile @@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile # Build specific version of flash attention RUN make build-flash-attention +# Build Flash Attention v2 CUDA kernels +FROM kernel-builder as flash-att-v2-builder + +WORKDIR /usr/src + +COPY server/Makefile-flash-att-v2 Makefile + +# Build specific version of flash attention v2 +RUN make build-flash-attention-v2 + # Build Transformers CUDA kernels FROM kernel-builder as custom-kernels-builder @@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages +# Copy build artifacts from flash attention v2 builder +COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages + # Copy build artifacts from custom kernels 
builder -COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels +COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages # Copy builds artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages diff --git a/README.md b/README.md index d31c176b..43388d00 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@
+![image](https://github.com/huggingface/text-generation-inference/assets/3841370/38ba1531-ea0d-4851-b31a-a6d4ddc944b0) + # Text Generation Inference @@ -11,9 +13,6 @@ Swagger API documentation - -![architecture](assets/architecture.jpg) -
A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co) @@ -64,6 +63,8 @@ to power LLMs api-inference widgets. - [Starcoder](https://huggingface.co/bigcode/starcoder) - [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b) - [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b) +- [MPT](https://huggingface.co/mosaicml/mpt-30b) +- [Llama V2](https://huggingface.co/meta-llama) Other architectures are supported on a best effort basis using: @@ -133,6 +134,10 @@ print(text) You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference). +### Using private or gated models + +You can use the `HUGGING_FACE_HUB_TOKEN` environment variable to set the token used by `text-generation-inference` to give access to protected resources. + ### Distributed Tracing `text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature @@ -212,7 +217,7 @@ sudo apt-get install libssl-dev gcc -y ### CUDA Kernels The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove -the kernels by using the `BUILD_EXTENSIONS=False` environment variable. +the kernels by using the `DISABLE_CUSTOM_KERNELS=True` environment variable. Be aware that the official Docker image has them enabled by default. diff --git a/docs/openapi.json b/docs/openapi.json index aaa360b0..80240460 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "0.9.1" + "version": "0.9.3" }, "paths": { "/": { diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index ae0694da..3e7f86d4 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -13,7 +13,7 @@ nix = "0.26.2" serde = { version = "1.0.152", features = ["derive"] } serde_json = "1.0.93" tracing = "0.1.37" -tracing-subscriber = { version = "0.3.16", features = ["json"] } +tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } [dev-dependencies] float_eq = "1.0.1" diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 8b34dfe3..53de36b2 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -4,10 +4,10 @@ use nix::unistd::Pid; use serde::Deserialize; use std::env; use std::ffi::OsString; -use std::io::{BufRead, BufReader, Read}; +use std::io::{BufRead, BufReader, Lines, Read}; use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::path::Path; -use std::process::{Child, Command, Stdio}; +use std::process::{Child, Command, ExitStatus, Stdio}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::TryRecvError; use std::sync::{mpsc, Arc}; @@ -15,6 +15,7 @@ use std::thread; use std::thread::sleep; use std::time::{Duration, Instant}; use std::{fs, io}; +use tracing_subscriber::EnvFilter; mod env_runtime; @@ -41,6 +42,7 @@ impl std::fmt::Display for Quantization { #[derive(Clone, Copy, Debug, ValueEnum)] enum Dtype { Float16, + #[clap(name = "bfloat16")] BFloat16, } @@ -182,8 +184,8 @@ struct Args { /// depends on other parameters like if you're using quantization, flash attention /// or the model implementation, text-generation-inference cannot infer this number /// automatically.
- #[clap(default_value = "16000", long, env)] - max_batch_total_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, /// This setting defines how many tokens can be passed before forcing the waiting /// queries to be put on the batch (if the size of the batch allows for it). @@ -265,17 +267,9 @@ struct Args { #[clap(long, env)] ngrok_authtoken: Option, - /// ngrok domain name where the axum webserver will be available at + /// ngrok edge #[clap(long, env)] - ngrok_domain: Option, - - /// ngrok basic auth username - #[clap(long, env)] - ngrok_username: Option, - - /// ngrok basic auth password - #[clap(long, env)] - ngrok_password: Option, + ngrok_edge: Option, /// Display a lot of information about your runtime environment #[clap(long, short, action)] @@ -285,7 +279,7 @@ struct Args { #[derive(Debug)] enum ShardStatus { Ready, - Failed((usize, Option)), + Failed(usize), } #[allow(clippy::too_many_arguments)] @@ -310,6 +304,9 @@ fn shard_manager( shutdown: Arc, _shutdown_sender: mpsc::Sender<()>, ) { + // Enter shard-manager tracing span + let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); + // Get UDS path let uds_string = format!("{uds_path}-{rank}"); let uds = Path::new(&uds_string); @@ -319,7 +316,7 @@ fn shard_manager( } // Process args - let mut shard_argv = vec![ + let mut shard_args = vec![ "serve".to_string(), model_id, "--uds-path".to_string(), @@ -331,77 +328,71 @@ fn shard_manager( // Activate trust remote code if trust_remote_code { - shard_argv.push("--trust-remote-code".to_string()); + shard_args.push("--trust-remote-code".to_string()); } // Activate tensor parallelism if world_size > 1 { - shard_argv.push("--sharded".to_string()); + shard_args.push("--sharded".to_string()); } if let Some(quantize) = quantize { - shard_argv.push("--quantize".to_string()); - shard_argv.push(quantize.to_string()) + shard_args.push("--quantize".to_string()); + shard_args.push(quantize.to_string()) } if let Some(dtype) = dtype { - shard_argv.push("--dtype".to_string()); - shard_argv.push(dtype.to_string()) + shard_args.push("--dtype".to_string()); + shard_args.push(dtype.to_string()) } // Model optional revision if let Some(revision) = revision { - shard_argv.push("--revision".to_string()); - shard_argv.push(revision) + shard_args.push("--revision".to_string()); + shard_args.push(revision) } // OpenTelemetry if let Some(otlp_endpoint) = otlp_endpoint { - shard_argv.push("--otlp-endpoint".to_string()); - shard_argv.push(otlp_endpoint); + shard_args.push("--otlp-endpoint".to_string()); + shard_args.push(otlp_endpoint); } // Copy current process env - let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); - - // Use cuda allocator. 
It leads to less memory fragmentation - env.push(( - "PYTORCH_CUDA_ALLOC_CONF".into(), - "backend:cudaMallocAsync".into(), - )); + let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); // Torch Distributed Env vars - env.push(("RANK".into(), rank.to_string().into())); - env.push(("WORLD_SIZE".into(), world_size.to_string().into())); - env.push(("MASTER_ADDR".into(), master_addr.into())); - env.push(("MASTER_PORT".into(), master_port.to_string().into())); - env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into())); + envs.push(("RANK".into(), rank.to_string().into())); + envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); + envs.push(("MASTER_ADDR".into(), master_addr.into())); + envs.push(("MASTER_PORT".into(), master_port.to_string().into())); + envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into())); // Safetensors load fast - env.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); + envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); // Enable hf transfer for insane download speeds let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); - env.push(( + envs.push(( "HF_HUB_ENABLE_HF_TRANSFER".into(), enable_hf_transfer.into(), )); // Parse Inference API token if let Ok(api_token) = env::var("HF_API_TOKEN") { - env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) }; // If huggingface_hub_cache is some, pass it to the shard // Useful when running inside a docker container if let Some(huggingface_hub_cache) = huggingface_hub_cache { - env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); + envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); }; // If weights_cache_override is some, pass it to the shard // Useful when running inside a HuggingFace Inference Endpoint if let Some(weights_cache_override) = weights_cache_override { - env.push(( + envs.push(( "WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into(), )); @@ -409,24 +400,24 @@ fn shard_manager( // If disable_custom_kernels is true, pass it to the shard as an env var if disable_custom_kernels { - env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into())) + envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into())) } // Watermark Gamma if let Some(watermark_gamma) = watermark_gamma { - env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into())) + envs.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into())) } // Watermark Delta if let Some(watermark_delta) = watermark_delta { - env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into())) + envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into())) } // Start process - tracing::info!("Starting shard {rank}"); + tracing::info!("Starting shard"); let mut p = match Command::new("text-generation-server") - .args(shard_argv) - .envs(env) + .args(shard_args) + .envs(envs) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .process_group(0) @@ -437,30 +428,23 @@ fn shard_manager( if err.kind() == io::ErrorKind::NotFound { tracing::error!("text-generation-server not found in PATH"); tracing::error!("Please install it with `make install-server`") - } else { + } + { tracing::error!("{}", err); } - status_sender - .send(ShardStatus::Failed((rank, Some(err.to_string())))) - .unwrap(); + status_sender.send(ShardStatus::Failed(rank)).unwrap(); return; } }; // Redirect STDOUT to the console let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); - let 
mut shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); + let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); + //stdout tracing thread thread::spawn(move || { - // Enter shard-manager tracing span - let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); - for line in shard_stdout_reader.lines() { - // Parse loguru logs - if let Ok(log) = serde_json::from_str::(&line.unwrap()) { - log.trace(); - } - } + log_lines(shard_stdout_reader.lines()); }); let mut ready = false; @@ -469,30 +453,25 @@ fn shard_manager( loop { // Process exited if let Some(exit_status) = p.try_wait().unwrap() { - // We read stderr in another thread as it seems that `read_to_string` can block - // indefinitely in some cases + // We read stderr in another thread as it seems that lines() can block in some cases let (err_sender, err_receiver) = mpsc::channel(); thread::spawn(move || { - let mut err = String::new(); - shard_stderr_reader.read_to_string(&mut err).unwrap(); - err_sender.send(err).unwrap_or(()); + for line in shard_stderr_reader.lines().flatten() { + err_sender.send(line).unwrap_or(()); + } }); + let mut err = String::new(); + while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) { + err = err + "\n" + &line; + } - let err = err_receiver - .recv_timeout(Duration::from_millis(100)) - .map_err(|err| { - tracing::error!("Unable to read shard {rank} error from stderr"); - err - }) - .ok(); + tracing::error!("Shard complete standard error output:\n{err}"); if let Some(signal) = exit_status.signal() { tracing::error!("Shard process was signaled to shutdown with signal {signal}"); } - status_sender - .send(ShardStatus::Failed((rank, err))) - .unwrap(); + status_sender.send(ShardStatus::Failed(rank)).unwrap(); return; } @@ -500,17 +479,17 @@ fn shard_manager( if shutdown.load(Ordering::SeqCst) { p.kill().unwrap(); let _ = p.wait(); - tracing::info!("Shard {rank} terminated"); + tracing::info!("Shard terminated"); return; } // Shard is ready if uds.exists() && !ready { - tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed()); + tracing::info!("Shard ready in {:?}", start_time.elapsed()); status_sender.send(ShardStatus::Ready).unwrap(); ready = true; } else if !ready && wait_time.elapsed() > Duration::from_secs(10) { - tracing::info!("Waiting for shard {rank} to be ready..."); + tracing::info!("Waiting for shard to be ready..."); wait_time = Instant::now(); } sleep(Duration::from_millis(100)); @@ -579,6 +558,23 @@ impl PythonLogMessage { } } +impl TryFrom<&String> for PythonLogMessage { + type Error = serde_json::Error; + + fn try_from(value: &String) -> Result { + serde_json::from_str::(value) + } +} + +fn log_lines(lines: Lines) { + for line in lines.flatten() { + match PythonLogMessage::try_from(&line) { + Ok(log) => log.trace(), + Err(_) => tracing::debug!("{line}"), + } + } +} + fn find_num_shards( sharded: Option, num_shard: Option, @@ -632,7 +628,10 @@ enum LauncherError { } fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { - let mut download_argv = vec![ + // Enter download tracing span + let _span = tracing::span!(tracing::Level::INFO, "download").entered(); + + let mut download_args = vec![ "download-weights".to_string(), args.model_id.to_string(), "--extension".to_string(), @@ -644,35 +643,35 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // Model optional revision if let Some(revision) = &args.revision { - download_argv.push("--revision".to_string()); - 
download_argv.push(revision.to_string()) + download_args.push("--revision".to_string()); + download_args.push(revision.to_string()) } // Copy current process env - let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); + let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); // If huggingface_hub_cache is set, pass it to the download process // Useful when running inside a docker container if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache { - env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); + envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); }; // Enable hf transfer for insane download speeds let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string()); - env.push(( + envs.push(( "HF_HUB_ENABLE_HF_TRANSFER".into(), enable_hf_transfer.into(), )); // Parse Inference API token if let Ok(api_token) = env::var("HF_API_TOKEN") { - env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) }; // If args.weights_cache_override is some, pass it to the download process // Useful when running inside a HuggingFace Inference Endpoint if let Some(weights_cache_override) = &args.weights_cache_override { - env.push(( + envs.push(( "WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into(), )); @@ -681,8 +680,8 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // Start process tracing::info!("Starting download process."); let mut download_process = match Command::new("text-generation-server") - .args(download_argv) - .envs(env) + .args(download_args) + .envs(envs) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .process_group(0) @@ -693,6 +692,8 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L if err.kind() == io::ErrorKind::NotFound { tracing::error!("text-generation-server not found in PATH"); tracing::error!("Please install it with `make install-server`") + } else { + tracing::error!("{}", err); } return Err(LauncherError::DownloadError); @@ -701,16 +702,10 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L // Redirect STDOUT to the console let download_stdout = download_process.stdout.take().unwrap(); + let stdout = BufReader::new(download_stdout); + thread::spawn(move || { - // Enter download tracing span - let stdout = BufReader::new(download_stdout); - let _span = tracing::span!(tracing::Level::INFO, "download").entered(); - for line in stdout.lines() { - // Parse loguru logs - if let Ok(log) = serde_json::from_str::(&line.unwrap()) { - log.trace(); - } - } + log_lines(stdout.lines()); }); loop { @@ -738,10 +733,7 @@ fn download_convert_model(args: &Args, running: Arc) -> Result<(), L return Err(LauncherError::DownloadError); } if !running.load(Ordering::SeqCst) { - signal::kill(Pid::from_raw(download_process.id() as i32), Signal::SIGTERM).unwrap(); - tracing::info!("Waiting for download process to gracefully shutdown"); - download_process.wait().unwrap(); - tracing::info!("Download process terminated"); + terminate("download", download_process, Duration::from_secs(10)).unwrap(); return Ok(()); } sleep(Duration::from_millis(100)); @@ -760,16 +752,6 @@ fn spawn_shards( status_sender: mpsc::Sender, running: Arc, ) -> Result<(), LauncherError> { - if args.trust_remote_code { - tracing::warn!( - "`trust_remote_code` is set. 
Trusting that model `{}` do not contain malicious code.", - args.model_id - ); - if args.revision.is_none() { - tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision."); - } - } - // Start shard processes for rank in 0..num_shard { let model_id = args.model_id.clone(); @@ -828,11 +810,8 @@ fn spawn_shards( Err(TryRecvError::Empty) => { sleep(Duration::from_millis(100)); } - Ok(ShardStatus::Failed((rank, err))) => { + Ok(ShardStatus::Failed(rank)) => { tracing::error!("Shard {rank} failed to start"); - if let Some(err) = err { - tracing::error!("{err}"); - } shutdown_shards(shutdown, shutdown_receiver); return Err(LauncherError::ShardCannotStart); } @@ -854,7 +833,7 @@ fn spawn_webserver( // All shard started // Start webserver tracing::info!("Starting Webserver"); - let mut argv = vec![ + let mut router_args = vec![ "--max-concurrent-requests".to_string(), args.max_concurrent_requests.to_string(), "--max-best-of".to_string(), @@ -867,8 +846,6 @@ fn spawn_webserver( args.max_total_tokens.to_string(), "--max-batch-prefill-tokens".to_string(), args.max_batch_prefill_tokens.to_string(), - "--max-batch-total-tokens".to_string(), - args.max_batch_total_tokens.to_string(), "--waiting-served-ratio".to_string(), args.waiting_served_ratio.to_string(), "--max-waiting-tokens".to_string(), @@ -885,63 +862,54 @@ fn spawn_webserver( args.model_id, ]; + // Model optional max batch total tokens + if let Some(max_batch_total_tokens) = args.max_batch_total_tokens { + router_args.push("--max-batch-total-tokens".to_string()); + router_args.push(max_batch_total_tokens.to_string()); + } + // Model optional revision if let Some(ref revision) = args.revision { - argv.push("--revision".to_string()); - argv.push(revision.to_string()) + router_args.push("--revision".to_string()); + router_args.push(revision.to_string()) } if args.json_output { - argv.push("--json-output".to_string()); + router_args.push("--json-output".to_string()); } // OpenTelemetry if let Some(otlp_endpoint) = args.otlp_endpoint { - argv.push("--otlp-endpoint".to_string()); - argv.push(otlp_endpoint); + router_args.push("--otlp-endpoint".to_string()); + router_args.push(otlp_endpoint); } // CORS origins for origin in args.cors_allow_origin.into_iter() { - argv.push("--cors-allow-origin".to_string()); - argv.push(origin); + router_args.push("--cors-allow-origin".to_string()); + router_args.push(origin); } // Ngrok if args.ngrok { - let authtoken = args.ngrok_authtoken.ok_or_else(|| { - tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling"); - LauncherError::WebserverCannotStart - })?; - - argv.push("--ngrok".to_string()); - argv.push("--ngrok-authtoken".to_string()); - argv.push(authtoken); - - if let Some(domain) = args.ngrok_domain { - argv.push("--ngrok-domain".to_string()); - argv.push(domain); - } - - if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) { - argv.push("--ngrok-username".to_string()); - argv.push(username); - argv.push("--ngrok-password".to_string()); - argv.push(password); - } + router_args.push("--ngrok".to_string()); + router_args.push("--ngrok-authtoken".to_string()); + router_args.push(args.ngrok_authtoken.unwrap()); + router_args.push("--ngrok-edge".to_string()); + router_args.push(args.ngrok_edge.unwrap()); } // Copy current process env - let mut env: Vec<(OsString, OsString)> = env::vars_os().collect(); + let mut envs: Vec<(OsString, OsString)> = 
env::vars_os().collect(); // Parse Inference API token if let Ok(api_token) = env::var("HF_API_TOKEN") { - env.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) + envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into())) }; let mut webserver = match Command::new("text-generation-router") - .args(argv) - .envs(env) + .args(router_args) + .envs(envs) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .process_group(0) @@ -979,14 +947,49 @@ fn spawn_webserver( Ok(webserver) } +fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result { + tracing::info!("Terminating {process_name}"); + + let terminate_time = Instant::now(); + signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap(); + + tracing::info!("Waiting for {process_name} to gracefully shutdown"); + + while terminate_time.elapsed() < timeout { + if let Some(status) = process.try_wait()? { + tracing::info!("{process_name} terminated"); + return Ok(status); + } + sleep(Duration::from_millis(100)); + } + + tracing::info!("Killing {process_name}"); + + process.kill()?; + let exit_status = process.wait()?; + + tracing::info!("{process_name} killed"); + Ok(exit_status) +} + fn main() -> Result<(), LauncherError> { // Pattern match configuration - let args = Args::parse(); + let args: Args = Args::parse(); + + // Filter events with LOG_LEVEL + let env_filter = + EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")); if args.json_output { - tracing_subscriber::fmt().json().init(); + tracing_subscriber::fmt() + .with_env_filter(env_filter) + .json() + .init(); } else { - tracing_subscriber::fmt().compact().init(); + tracing_subscriber::fmt() + .with_env_filter(env_filter) + .compact() + .init(); } if args.env { @@ -1008,29 +1011,53 @@ fn main() -> Result<(), LauncherError> { args.max_batch_prefill_tokens, args.max_input_length ))); } - if args.max_batch_prefill_tokens > args.max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - args.max_batch_prefill_tokens, args.max_batch_total_tokens - ))); - } - if args.max_total_tokens as u32 > args.max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - args.max_total_tokens, args.max_batch_total_tokens - ))); - } + if args.validation_workers == 0 { return Err(LauncherError::ArgumentValidation( "`validation_workers` must be > 0".to_string(), )); } + if args.trust_remote_code { + tracing::warn!( + "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.", + args.model_id + ); + } let num_shard = find_num_shards(args.sharded, args.num_shard)?; if num_shard > 1 { tracing::info!("Sharding model on {num_shard} processes"); } + if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { + if args.max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", + args.max_batch_prefill_tokens, max_batch_total_tokens + ))); + } + if args.max_total_tokens as u32 > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {} and {}", + args.max_total_tokens, max_batch_total_tokens + ))); + } + } + + if args.ngrok { + if args.ngrok_authtoken.is_none() { + return Err(LauncherError::ArgumentValidation( + "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(), + )); + } + + if args.ngrok_edge.is_none() { + return Err(LauncherError::ArgumentValidation( + "`ngrok-edge` must be set when using ngrok tunneling".to_string(), + )); + } + } + // Signal handler let running = Arc::new(AtomicBool::new(true)); let r = running.clone(); @@ -1042,6 +1069,11 @@ fn main() -> Result<(), LauncherError> { // Download and convert model weights download_convert_model(&args, running.clone())?; + if !running.load(Ordering::SeqCst) { + // Launcher was asked to stop + return Ok(()); + } + // Shared shutdown bool let shutdown = Arc::new(AtomicBool::new(false)); // Shared shutdown channel @@ -1078,11 +1110,8 @@ fn main() -> Result<(), LauncherError> { let mut exit_code = Ok(()); while running.load(Ordering::SeqCst) { - if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { + if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() { tracing::error!("Shard {rank} crashed"); - if let Some(err) = err { - tracing::error!("{err}"); - } exit_code = Err(LauncherError::ShardFailed); break; }; @@ -1100,10 +1129,7 @@ fn main() -> Result<(), LauncherError> { } // Graceful termination - signal::kill(Pid::from_raw(webserver.id() as i32), Signal::SIGTERM).unwrap(); - tracing::info!("Waiting for webserver to gracefully shutdown"); - webserver.wait().unwrap(); - tracing::info!("Webserver terminated"); + terminate("webserver", webserver, Duration::from_secs(90)).unwrap(); shutdown_shards(shutdown, &shutdown_receiver); exit_code diff --git a/proto/generate.proto b/proto/generate.proto index 5e061941..57d79bca 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -198,9 +198,10 @@ message DecodeResponse { message WarmupRequest { /// Batch to warmup on Batch batch = 1; - /// Maximum number of tokens that the client will send - uint32 max_total_tokens = 2; } /// Empty response -message WarmupResponse {} +message WarmupResponse { + /// Maximum number of tokens supported by the model + optional uint32 max_supported_total_tokens = 1; +} diff --git a/router/client/src/client.rs b/router/client/src/client.rs index b9607a5d..7753f307 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -103,8 +103,7 @@ impl Client { &mut self, max_input_length: u32, max_prefill_tokens: u32, - max_total_tokens: u32, - ) -> Result<()> { + ) -> Result> { let mut n_tokens = 0; let mut requests = Vec::new(); @@ -143,13 +142,9 @@ impl Client { max_tokens: 0, }; - let request = tonic::Request::new(WarmupRequest { - batch: Some(batch), - max_total_tokens, - }) - .inject_context(); - self.stub.warmup(request).await?.into_inner(); - Ok(()) + let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) } /// Generate one token for each request in the given batch diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 9dd173a0..6d146bc5 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -95,14 +95,11 @@ impl ShardedClient { &mut self, max_input_length: u32, max_prefill_tokens: u32, - max_total_tokens: u32, - ) -> Result<()> { + ) -> Result> { let futures: Vec<_> = self .clients .iter_mut() - 
.map(|client| { - Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens)) - }) + .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens))) .collect(); // all shards return the same message join_all(futures).await.pop().unwrap() diff --git a/router/src/infer.rs b/router/src/infer.rs index d0d22d3b..188ddc64 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -53,7 +53,7 @@ impl Infer { generation_health: Arc, ) -> Self { // Infer shared state - let queue = Queue::new(requires_padding); + let queue = Queue::new(requires_padding, 16); let shared = Arc::new(Shared { batching_task: Notify::new(), }); diff --git a/router/src/main.rs b/router/src/main.rs index 57ddd5ba..059f8692 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -37,8 +37,8 @@ struct Args { waiting_served_ratio: f32, #[clap(default_value = "4096", long, env)] max_batch_prefill_tokens: u32, - #[clap(default_value = "16000", long, env)] - max_batch_total_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, #[clap(default_value = "0.0.0.0", long, env)] @@ -49,8 +49,8 @@ struct Args { master_shard_uds_path: String, #[clap(default_value = "bigscience/bloom", long, env)] tokenizer_name: String, - #[clap(default_value = "main", long, env)] - revision: String, + #[clap(long, env)] + revision: Option, #[clap(default_value = "2", long, env)] validation_workers: usize, #[clap(long, env)] @@ -64,11 +64,7 @@ struct Args { #[clap(long, env)] ngrok_authtoken: Option, #[clap(long, env)] - ngrok_domain: Option, - #[clap(long, env)] - ngrok_username: Option, - #[clap(long, env)] - ngrok_password: Option, + ngrok_edge: Option, } fn main() -> Result<(), RouterError> { @@ -96,9 +92,7 @@ fn main() -> Result<(), RouterError> { cors_allow_origin, ngrok, ngrok_authtoken, - ngrok_domain, - ngrok_username, - ngrok_password, + ngrok_edge, } = args; // Validate args @@ -110,18 +104,22 @@ fn main() -> Result<(), RouterError> { if max_input_length as u32 > max_batch_prefill_tokens { return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); } - if max_batch_prefill_tokens > max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); - } - if max_total_tokens as u32 > max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); - } + if validation_workers == 0 { return Err(RouterError::ArgumentValidation( "`validation_workers` must be > 0".to_string(), )); } + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + // CORS allowed origins // map to go inside the option and then map to parse from String to HeaderValue // Finally, convert to AllowOrigin @@ -147,7 +145,7 @@ fn main() -> Result<(), RouterError> { // Download and instantiate tokenizer // We need to download it outside of the Tokio runtime let params = FromPretrainedParameters { - revision: revision.clone(), + revision: revision.clone().unwrap_or("main".to_string()), auth_token: authorization_token.clone(), ..Default::default() }; @@ -175,7 +173,7 @@ fn main() -> Result<(), RouterError> { sha: None, pipeline_tag: None, }, - false => get_model_info(&tokenizer_name, &revision, authorization_token) + false => get_model_info(&tokenizer_name, revision, authorization_token) .await .unwrap_or_else(|| { tracing::warn!("Could not retrieve model info from the Hugging Face hub."); @@ -210,14 +208,35 @@ fn main() -> Result<(), RouterError> { // Warmup model tracing::info!("Warming up model"); - sharded_client - .warmup( - max_input_length as u32, - max_batch_prefill_tokens, - max_batch_total_tokens, - ) + let max_supported_batch_total_tokens = match sharded_client + .warmup(max_input_length as u32, max_batch_prefill_tokens) .await - .map_err(RouterError::Warmup)?; + .map_err(RouterError::Warmup)? + { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( + 16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)), + ); + tracing::warn!("Model does not support automatic max batch total tokens"); + max_batch_total_tokens + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." 
+ ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + max_supported_batch_total_tokens + } + }; + tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); tracing::info!("Connected"); let addr = match hostname.parse() { @@ -240,7 +259,7 @@ fn main() -> Result<(), RouterError> { max_total_tokens, waiting_served_ratio, max_batch_prefill_tokens, - max_batch_total_tokens, + max_supported_batch_total_tokens, max_waiting_tokens, sharded_client, tokenizer, @@ -249,9 +268,7 @@ fn main() -> Result<(), RouterError> { cors_allow_origin, ngrok, ngrok_authtoken, - ngrok_domain, - ngrok_username, - ngrok_password, + ngrok_edge, ) .await?; Ok(()) @@ -316,9 +333,18 @@ fn init_logging(otlp_endpoint: Option, json_output: bool) { /// get model info from the Huggingface Hub pub async fn get_model_info( model_id: &str, - revision: &str, + revision: Option, token: Option, ) -> Option { + let revision = match revision { + None => { + tracing::warn!("`--revision` is not set"); + tracing::warn!("We strongly advise to set it to a known supported commit."); + "main".to_string() + } + Some(revision) => revision, + }; + let client = reqwest::Client::new(); // Poor man's urlencode let revision = revision.replace('/', "%2F"); @@ -331,9 +357,18 @@ pub async fn get_model_info( let response = builder.send().await.ok()?; if response.status().is_success() { - return serde_json::from_str(&response.text().await.ok()?).ok(); + let hub_model_info: HubModelInfo = + serde_json::from_str(&response.text().await.ok()?).ok()?; + if let Some(sha) = &hub_model_info.sha { + tracing::info!( + "Serving revision {sha} of model {}", + hub_model_info.model_id + ); + } + Some(hub_model_info) + } else { + None } - None } #[derive(Debug, Error)] diff --git a/router/src/queue.rs b/router/src/queue.rs index 48e483a1..2d8d6d1c 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -33,12 +33,12 @@ pub(crate) struct Queue { } impl Queue { - pub(crate) fn new(requires_padding: bool) -> Self { + pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self { // Create channel let (queue_sender, queue_receiver) = flume::unbounded(); // Launch background queue task - tokio::spawn(queue_task(requires_padding, queue_receiver)); + tokio::spawn(queue_task(requires_padding, block_size, queue_receiver)); Self { queue_sender } } @@ -81,8 +81,12 @@ impl Queue { } // Background task responsible of the queue state -async fn queue_task(requires_padding: bool, receiver: flume::Receiver) { - let mut state = State::new(requires_padding); +async fn queue_task( + requires_padding: bool, + block_size: u32, + receiver: flume::Receiver, +) { + let mut state = State::new(requires_padding, block_size); while let Ok(cmd) = receiver.recv_async().await { match cmd { @@ -119,15 +123,19 @@ struct State { /// Whether the model is using padding requires_padding: bool, + + /// Paged Attention block size + block_size: u32, } impl State { - fn new(requires_padding: bool) -> Self { + fn new(requires_padding: bool, block_size: u32) -> Self { Self { entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, requires_padding, + block_size, } } @@ -187,10 +195,21 @@ impl State { max_input_length = max_input_length.max(entry.request.input_length); prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length } else { - prefill_tokens += entry.request.input_length; + // pad to block size + prefill_tokens += ((entry.request.input_length + self.block_size - 1) + / self.block_size) 
+ * self.block_size; } - decode_tokens += entry.request.stopping_parameters.max_new_tokens; + if self.requires_padding { + decode_tokens += entry.request.stopping_parameters.max_new_tokens; + } else { + // pad to block size + decode_tokens += + ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1) + / self.block_size) + * self.block_size; + } if prefill_tokens > prefill_token_budget || (prefill_tokens + decode_tokens) > token_budget @@ -321,7 +340,7 @@ mod tests { #[test] fn test_append() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -337,7 +356,7 @@ mod tests { #[test] fn test_next_batch_empty() { - let mut state = State::new(false); + let mut state = State::new(false, 1); assert!(state.next_batch(None, 1, 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none()); @@ -345,7 +364,7 @@ mod tests { #[test] fn test_next_batch_min_size() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -377,7 +396,7 @@ mod tests { #[test] fn test_next_batch_token_budget() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -410,14 +429,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); @@ -425,7 +444,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -458,7 +477,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -483,7 +502,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry, _) = default_entry(); queue.append(entry); diff --git a/router/src/server.rs b/router/src/server.rs index 8ca463c2..bfeee375 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -524,9 +524,7 @@ pub async fn run( allow_origin: Option, ngrok: bool, ngrok_authtoken: Option, - ngrok_domain: Option, - ngrok_username: Option, - ngrok_password: Option, + ngrok_edge: Option, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -696,32 +694,25 @@ pub async fn run( #[cfg(feature = "ngrok")] { use ngrok::config::TunnelBuilder; - use ngrok::tunnel::UrlTunnel; let _ = addr; let authtoken = ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling"); - let mut tunnel = ngrok::Session::builder() + let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling"); + + let tunnel = ngrok::Session::builder() .authtoken(authtoken) .connect() .await .unwrap() - .http_endpoint(); - - if let 
Some(domain) = ngrok_domain { - tunnel = tunnel.domain(domain); - } - - if let (Some(username), Some(password)) = (ngrok_username, ngrok_password) { - tunnel = tunnel.basic_auth(username, password); - } + .labeled_tunnel() + .label("edge", edge); let listener = tunnel.listen().await.unwrap(); // Run server - tracing::info!("Ingress URL: {:?}", listener.url()); axum::Server::builder(listener) .serve(app.into_make_service()) //Wait until all requests are finished to shut down diff --git a/server/Makefile b/server/Makefile index d0086928..0dc0b5c9 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,4 +1,5 @@ include Makefile-flash-att +include Makefile-flash-att-v2 include Makefile-vllm unit-tests: diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 new file mode 100644 index 00000000..a7d63356 --- /dev/null +++ b/server/Makefile-flash-att-v2 @@ -0,0 +1,13 @@ +flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc + +flash-attention-v2: + # Clone flash attention + pip install packaging + git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 + +build-flash-attention-v2: flash-attention-v2 + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit) + cd flash-attention-v2 && python setup.py build + +install-flash-attention-v2: build-flash-attention-v2 + cd flash-attention-v2 && python setup.py install \ No newline at end of file diff --git a/server/pyproject.toml b/server/pyproject.toml index a696d7be..be79da51 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation-server" -version = "0.9.1" +version = "0.9.3" description = "Text Generation Inference Python gRPC Server" authors = ["Olivier Dehaene "] diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 7a55e919..e74c0331 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -194,6 +194,8 @@ def quantize( percdamp: float = 0.01, act_order: bool = False, ): + if revision is None: + revision = "main" download_weights( model_id=model_id, revision=revision, @@ -207,6 +209,7 @@ def quantize( bits=4, groupsize=128, output_dir=output_dir, + revision=revision, trust_remote_code=trust_remote_code, upload_to_model_id=upload_to_model_id, percdamp=percdamp, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fd97f8b1..ffc224cc 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -42,51 +42,21 @@ __all__ = [ "get_model", ] -FLASH_ATT_ERROR_MESSAGE = ( - "{} requires CUDA and Flash Attention kernels to be installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" -) +FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." +FLASH_ATTENTION = True try: - if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": - if not torch.cuda.is_available(): - FLASH_ATT_ERROR_MESSAGE = ( - "{} requires CUDA. No compatible CUDA devices found." 
- ) - raise ImportError("CUDA is not available") - - major, minor = torch.cuda.get_device_capability() - is_sm75 = major == 7 and minor == 5 - is_sm8x = major == 8 and minor >= 0 - is_sm90 = major == 9 and minor == 0 - - supported = is_sm75 or is_sm8x or is_sm90 - if not supported: - FLASH_ATT_ERROR_MESSAGE = ( - "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. " - "No compatible CUDA device found." - ) - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) - - from text_generation_server.models.flash_rw import FlashRWSharded - from text_generation_server.models.flash_neox import FlashNeoXSharded - from text_generation_server.models.flash_llama import ( - FlashLlama, - ) - from text_generation_server.models.flash_santacoder import ( - FlashSantacoderSharded, - ) - - FLASH_ATTENTION = True - else: - FLASH_ATTENTION = False -except ImportError: - logger.opt(exception=True).warning( - "Could not import Flash Attention enabled models" + from text_generation_server.models.flash_rw import FlashRWSharded + from text_generation_server.models.flash_neox import FlashNeoXSharded + from text_generation_server.models.flash_llama import ( + FlashLlama, ) + from text_generation_server.models.flash_santacoder import ( + FlashSantacoderSharded, + ) + +except ImportError as e: + logger.warning(f"Could not import Flash Attention enabled models: {e}") FLASH_ATTENTION = False if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 626404e6..039fe2bf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -23,25 +23,77 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple # Flash attention imports -import flash_attn_cuda import dropout_layer_norm # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, TensorParallelHead, + get_linear, ) +class LlamaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_scaling=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_scaling = rope_scaling + + super().__init__( + 
pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + class LlamaRMSNorm(nn.Module): def __init__(self, prefix, weights, eps=1e-6): """ @@ -59,7 +111,8 @@ class LlamaRMSNorm(nn.Module): hidden_states += residual residual = hidden_states - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt( variance + self.variance_epsilon ) @@ -94,6 +147,27 @@ class LlamaRMSNorm(nn.Module): return normed_hidden_states, res +def _load_gqa(config, prefix: str, weights): + w = [ + weights.get_sharded(f"{prefix}.q_proj.weight", dim=0), + weights.get_sharded(f"{prefix}.k_proj.weight", dim=0), + weights.get_sharded(f"{prefix}.v_proj.weight", dim=0), + ] + weight = torch.cat(w, dim=0) + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + bias = None + assert config.hidden_size % config.num_attention_heads == 0 + head_size = config.hidden_size // config.num_attention_heads + assert config.num_attention_heads % weights.process_group.size() == 0 + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + + class FlashLlamaAttention(torch.nn.Module): def __init__( self, @@ -118,22 +192,29 @@ class FlashLlamaAttention(torch.nn.Module): f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() - self.query_key_value = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=False, + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() ) + if config.num_attention_heads != config.num_key_value_heads: + self.query_key_value = _load_gqa(config, prefix, weights) + else: + self.query_key_value = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, bias=False, ) + self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( - 0, self.num_heads, dtype=torch.int32, device=weights.device - ) + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) def forward( self, @@ -148,38 +229,37 @@ class FlashLlamaAttention(torch.nn.Module): max_s, ): qkv = self.query_key_value(hidden_states) - qkv = qkv.view(-1, 3, self.num_heads, self.head_size) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - # Inplace rotary - self.rotary_emb(qkv[:, 0], cos, sin) - self.rotary_emb(qkv[:, 1], cos, sin) + self.rotary_emb(query, cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, 
sin) vllm_cache_ops.reshape_and_cache( - qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots ) # output tensor - attn_output = torch.empty_like(qkv[:, 0]) + attn_output = torch.empty_like(query) # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn_cuda.fwd( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], + attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: @@ -187,7 +267,7 @@ class FlashLlamaAttention(torch.nn.Module): block_size = kv_cache[1].shape[3] vllm_attention_ops.single_query_cached_kv_attention( attn_output, - qkv[:, 0], + query, kv_cache[0], kv_cache[1], self.kv_head_mapping, @@ -323,6 +403,7 @@ class FlashLlamaModel(torch.nn.Module): self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads def forward( self, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b2dce226..e7c8ced4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -27,13 +27,11 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig from typing import Optional, List, Tuple -# Flash attention imports -import flash_attn_cuda - # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -153,22 +151,14 @@ class FlashNeoxAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn_cuda.fwd( + attention( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index acac2744..1e9539c4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -6,13 +6,11 @@ from transformers.modeling_utils import PreTrainedModel from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -# Flash attention imports -import flash_attn_cuda - # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -182,27 +180,15 @@ class FlashRWAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - if self.num_heads_kv == 1: - # Expand to query shape - kv = kv.expand(-1, 2, self.num_heads, self.head_size) - # flash attention - flash_attn_cuda.fwd( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - 
False, - 0, - None, ) # Decode else: @@ -314,30 +300,15 @@ class FlashRWLargeAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # Expand to query shape - kv = ( - kv.unsqueeze(2) - .expand(-1, self.num_groups, self.num_heads, 2, self.head_size) - .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) - ) - # flash attention - flash_attn_cuda.fwd( + attention( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 43dc3606..4dd76360 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -5,13 +5,11 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -# Flash attention imports -import flash_attn_cuda - # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -265,26 +263,15 @@ class FlashMQAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # Expand from 1 to num_heads - key_value = key_value.expand(-1, 2, self.num_heads, self.head_size) - # flash attention - flash_attn_cuda.fwd( + attention( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 78db35f0..56c21463 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -712,14 +712,14 @@ class FlashCausalLM(Model): def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch - def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int): + def warmup(self, batch: FlashCausalLMBatch): global CACHE_MANAGER torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(self.device) try: CACHE_MANAGER = CacheManager( - # Adds some wiggle room - math.ceil(max_total_tokens / BLOCK_SIZE) + 10, + batch.blocks, self.num_layers, self.num_kv_heads, self.head_size, @@ -729,11 +729,43 @@ class FlashCausalLM(Model): _, batch = self.generate_token(batch) except Exception as e: raise RuntimeError( - f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} " - f"prefill tokens. " - f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`" + f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " + f"You need to decrease `--max-batch-prefill-tokens`" ) from e + + # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. 
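# Illustrative sketch only (not part of this diff): the arithmetic behind the block
# count computed below, using assumed example values rather than real ones
# (BLOCK_SIZE = 16, 8 KV heads, head size 128, 32 layers, fp16 KV cache).
dtype_size = 2                                              # bytes per fp16 element
cache_block_size = 16 * 8 * 128                             # BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = 32 * cache_block_size * 2 * dtype_size   # K and V for every layer, per block
total_gpu_memory = 80 * 1024**3                             # hypothetical 80 GiB device
peak_memory = 20 * 1024**3                                  # hypothetical profiled peak usage
num_blocks = int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
# ≈ 29 900 blocks here, i.e. num_blocks * BLOCK_SIZE ≈ 478 000 KV-cache token slots
# (the real code below also adds back batch.blocks, which were allocated before profiling).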
+ torch.cuda.synchronize(self.device) + peak_memory = torch.cuda.max_memory_reserved(self.device) + + dtype_size = torch.tensor([], dtype=self.dtype).element_size() + cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size + total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size + + total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory + + # 0.98 to add some wiggle room + num_blocks = ( + int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size) + # Add batch.blocks as we allocated it above, so it is included in the peak memory. + + batch.blocks + ) + + del CACHE_MANAGER del batch + torch.cuda.empty_cache() + + CACHE_MANAGER = CacheManager( + num_blocks, + self.num_layers, + self.num_kv_heads, + self.head_size, + self.dtype, + self.device, + ) + + return int(num_blocks * BLOCK_SIZE) def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: return self.tokenizer.decode( diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 3fd17a73..77450cbb 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -2,13 +2,13 @@ import torch import torch.distributed from opentelemetry import trace -from transformers import AutoConfig from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast from typing import Optional from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, + LlamaConfig, ) from text_generation_server.utils import ( initialize_torch_distributed, @@ -52,7 +52,7 @@ class FlashLlama(FlashCausalLM): trust_remote_code=trust_remote_code, ) - config = AutoConfig.from_pretrained( + config = LlamaConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) @@ -70,7 +70,7 @@ class FlashLlama(FlashCausalLM): tokenizer=tokenizer, config=config, num_layers=len(model.model.layers), - num_kv_heads=model.model.num_heads, + num_kv_heads=model.model.num_key_value_heads, head_size=model.model.head_size, dtype=dtype, device=device, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index c686506b..8ceac511 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -107,8 +107,9 @@ class Model(ABC): def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: raise NotImplementedError - def warmup(self, batch: B, max_total_tokens: int): + def warmup(self, batch: B) -> Optional[int]: self.generate_token(batch) + return None def decode_token( self, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 7bc62ce6..e0efbcf5 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -51,21 +51,17 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): filtered_batch = batch.filter(request.request_ids) self.cache.set(filtered_batch) - if torch.cuda.is_available(): - torch.cuda.empty_cache() - return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) - self.model.warmup(batch, request.max_total_tokens) + max_supported_total_tokens = self.model.warmup(batch) - if 
torch.cuda.is_available(): - torch.cuda.empty_cache() - - return generate_pb2.WarmupResponse() + return generate_pb2.WarmupResponse( + max_supported_total_tokens=max_supported_total_tokens + ) async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( @@ -96,8 +92,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): if len(batches) > 1: batch = self.model.batch_type.concatenate(batches) - if torch.cuda.is_available(): - torch.cuda.empty_cache() else: batch = batches[0] diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py index 305263ba..8d414eca 100644 --- a/server/text_generation_server/utils/convert.py +++ b/server/text_generation_server/utils/convert.py @@ -94,6 +94,14 @@ def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: Lis # We do this instead of using tqdm because we want to parse the logs with the launcher for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)): + # Skip blacklisted files + if ( + "arguments" in pt_file.name + or "args" in pt_file.name + or "training" in pt_file.name + ): + continue + start = datetime.datetime.now() convert_file(pt_file, sf_file, discard_names) elapsed = datetime.datetime.now() - start diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py new file mode 100644 index 00000000..c472d1fc --- /dev/null +++ b/server/text_generation_server/utils/flash_attn.py @@ -0,0 +1,124 @@ +import os +import torch + +from loguru import logger + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + raise ImportError("`USE_FLASH_ATTENTION` is false.") + +if not torch.cuda.is_available(): + raise ImportError("CUDA is not available") + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 +is_sm8x = major == 8 and minor >= 0 +is_sm90 = major == 9 and minor == 0 + +HAS_FLASH_ATTN = False +HAS_FLASH_ATTN_V2 = False +try: + try: + import flash_attn_2_cuda + except ImportError: + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention v2 with `cd server && make install install-flash-attention-v2`" + ) + if not (is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported for " + "Flash Attention V2" + ) + HAS_FLASH_ATTN_V2 = True +except ImportError as e: + try: + import flash_attn_cuda + except ImportError: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + + if not (is_sm75 or is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e + logger.warning(f"Unable to use Flash Attention V2: {e}") + HAS_FLASH_ATTN = True + + +def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, +): + if HAS_FLASH_ATTN_V2: + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + None, + ) + + if HAS_FLASH_ATTN: + # Flash attention v1 requires q, k and v to have the same number of heads + if k.shape[1] != q.shape[1]: + # MQA expand + if k.shape[1] == 1: + k = k.expand(-1, q.shape[1], -1) + # Grouped attention 
reshape + else: + original_shape = k.shape + k = ( + k.unsqueeze(2) + .expand(-1, -1, q.shape[1] // k.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + if v.shape[1] != q.shape[1]: + # MQA expand + if v.shape[1] == 1: + v = v.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = v.shape + v = ( + v.unsqueeze(2) + .expand(-1, -1, q.shape[1] // v.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + + return flash_attn_cuda.fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + 0, + None, + ) + + raise NotImplementedError("flash attention is not installed") diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py index 5a4ed8da..d182456f 100644 --- a/server/text_generation_server/utils/gptq/quantize.py +++ b/server/text_generation_server/utils/gptq/quantize.py @@ -13,6 +13,9 @@ import transformers from huggingface_hub import HfApi import numpy as np import torch +from accelerate import init_empty_weights +from text_generation_server.utils import initialize_torch_distributed, Weights +from text_generation_server.utils.hub import weight_files from text_generation_server.utils.gptq.quant_linear import QuantLinear from loguru import logger from typing import Optional @@ -38,7 +41,6 @@ class Quantizer(nn.Module): maxshrink=0.8, trits=False, ): - self.maxq = torch.tensor(2**bits - 1) self.perchannel = perchannel self.sym = sym @@ -600,6 +602,8 @@ def sequential( nsamples, bits, groupsize, + *, + hooks, percdamp=0.01, sym: bool = False, act_order: bool = False, @@ -637,7 +641,7 @@ def sequential( layers[0] = Catcher(layers[0]) for batch in dataloader: try: - model(batch[0]) + model(batch[0].cuda()) except ValueError: pass layers[0] = layers[0].module @@ -646,6 +650,8 @@ def sequential( # model.model.embed_tokens = model.model.embed_tokens.cpu() # model.model.norm = model.model.norm.cpu() torch.cuda.empty_cache() + for hook in hooks: + hook.remove() outs = torch.zeros_like(inps) @@ -662,10 +668,8 @@ def sequential( print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |") print("+==================+==============+============+===========+=======+") - from accelerate.hooks import remove_hook_from_submodules - - layer = layers[i].to(dev) - remove_hook_from_submodules(layer) + layer = layers[i] + layer.load() full = find_layers(layer) sequential = [list(full.keys())] @@ -677,6 +681,7 @@ def sequential( gptq[name].quantizer.configure( bits, perchannel=True, sym=sym, mse=False ) + pass def add_batch(name): def tmp(_, inp, out): @@ -688,7 +693,6 @@ def sequential( for name in subset: handles.append(subset[name].register_forward_hook(add_batch(name))) for j in range(nsamples): - outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] for h in handles: h.remove() @@ -714,7 +718,7 @@ def sequential( for j in range(nsamples): outs[j] = layer(inps[j].unsqueeze(0), **extra)[0] - layers[i] = layer.cpu() + layer.unload() del layer del gptq torch.cuda.empty_cache() @@ -768,24 +772,136 @@ def pack(model, quantizers, bits, groupsize): return model +def setdeepattr(module, full_name, tensor): + current = module + tokens = full_name.split(".") + for token in tokens[:-1]: + current = getattr(current, token) + setattr(current, tokens[-1], tensor) + + +def getdeepattr(module, full_name): + current = module + tokens = full_name.split(".") + for token in tokens: + current = getattr(current, token) + return current + + 
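# Usage sketch for the two helpers above (illustrative only, not part of this diff):
# they walk dotted names exactly as `named_parameters()` emits them, which is what
# lets the hooks below swap individual tensors between the meta device and the GPU.
import torch
from torch import nn

toy = nn.Sequential(nn.Linear(2, 2))                        # toy module tree
weight = getdeepattr(toy, "0.weight")                       # toy -> toy[0] -> weight
setdeepattr(toy, "0.weight", nn.Parameter(torch.zeros_like(weight)))
assert getdeepattr(toy, "0.weight").sum().item() == 0.0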
+def load_weights_pre_hook(module_name, weights, recursive=False): + def inner(module, args): + print(f"Pre hook {module_name}") + local_params = {} + for k, v in module.named_parameters(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for k, v in module.named_buffers(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + + for local_param in local_params: + current_tensor = getdeepattr(module, local_param) + if current_tensor.device == torch.device("meta"): + # print(f"Loading {local_param}") + if module_name: + tensor_name = f"{module_name}.{local_param}" + else: + tensor_name = local_param + tensor = weights.get_tensor(tensor_name) + setdeepattr(module, local_param, nn.Parameter(tensor)) + else: + setdeepattr( + module, + local_param, + nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))), + ) + + return inner + + +def load_weights_post_hook(module_name, weights, recursive=False): + def inner(module, args, output): + print(f"Post hook {module_name}") + local_params = {} + for k, v in module.named_parameters(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for k, v in module.named_buffers(): + if not recursive and k.count(".") != 1: + continue + local_params[k] = v + for local_param in local_params: + # print(f"Unloading {local_param}") + current_tensor = getdeepattr(module, local_param) + setdeepattr( + module, + local_param, + nn.Parameter(current_tensor.to(device=torch.device("cpu"))), + ) + return output + + return inner + + def quantize( model_id: str, bits: int, groupsize: int, output_dir: str, + revision: str, trust_remote_code: bool, upload_to_model_id: Optional[str], percdamp: float, act_order: bool, ): print("loading model") - model = AutoModelForCausalLM.from_pretrained( + config = AutoConfig.from_pretrained( model_id, - torch_dtype=torch.float16, - device_map="balanced_low_0", trust_remote_code=trust_remote_code, ) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16) + model = model.eval() + print("LOADED model") + files = weight_files(model_id, revision, extension=".safetensors") + process_group, _, _ = initialize_torch_distributed() + weights = Weights( + files, + device=torch.device("cuda:0"), + dtype=torch.float16, + process_group=process_group, + aliases={"embed_tokens.weight": ["lm_head.weight"]}, + ) + hooks = [] + for name, module in model.named_modules(): + + def load(module, name): + def _load(): + load_weights_pre_hook(name, weights, recursive=True)(module, None) + + return _load + + def unload(module, name): + def _unload(): + load_weights_post_hook(name, weights, recursive=True)( + module, None, None + ) + + return _unload + + module.load = load(module, name) + module.unload = unload(module, name) + hooks.append( + module.register_forward_pre_hook(load_weights_pre_hook(name, weights)) + ) + hooks.append( + module.register_forward_hook(load_weights_post_hook(name, weights)) + ) model.seqlen = 2048 dataset = "wikitext2" @@ -806,6 +922,7 @@ def quantize( groupsize, percdamp=percdamp, act_order=act_order, + hooks=hooks, ) print(time.time() - tick) @@ -858,7 +975,6 @@ def quantize( logger.info("Saved tokenizer") if upload_to_model_id: - api = HfApi() api.upload_folder(
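# Closing sketch (illustrative only, not from this diff): the hook-based loading used
# by quantize() above amounts to moving a module onto the GPU just before its forward
# pass and back to the CPU right after, so only one layer is resident at a time.
# The device name and module here are hypothetical stand-ins.
import torch
from torch import nn

def load_pre_hook(module, args):
    module.to("cuda:0")                 # bring this layer's weights onto the GPU

def offload_post_hook(module, args, output):
    module.to("cpu")                    # release GPU memory once the layer has run
    return output

layer = nn.Linear(4, 4)
layer.register_forward_pre_hook(load_pre_hook)
layer.register_forward_hook(offload_post_hook)
# layer(torch.randn(1, 4, device="cuda:0"))  # weights land on the GPU for the call, then offload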